Merge ~kajiya/+git/google-guest-agent:kajiya/upstream into ~ubuntu-core-dev/+git/google-guest-agent:upstream

Proposed by Chloé Smith
Status: Merged
Merged at revision: 382d1f5333b0eb73bae0fe74efc5534c295219f6
Proposed branch: ~kajiya/+git/google-guest-agent:kajiya/upstream
Merge into: ~ubuntu-core-dev/+git/google-guest-agent:upstream
Diff against target: 10201 lines (+6466/-1274)
70 files modified
.gitignore (+11/-1)
OWNERS (+2/-2)
THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE (+2/-1)
THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE (+2/-2)
THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE (+22/-0)
THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE (+27/-0)
THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE (+1/-1)
dev/null (+0/-93)
gce_workload_cert_refresh/main.go (+169/-152)
gce_workload_cert_refresh/main_test.go (+495/-0)
go.mod (+5/-2)
go.sum (+9/-0)
google_authorized_keys/main.go (+14/-49)
google_authorized_keys/main_test.go (+103/-44)
google_guest_agent/addresses.go (+21/-245)
google_guest_agent/agentcrypto/mtls_mds.go (+1/-1)
google_guest_agent/agentcrypto/mtls_mds_linux.go (+35/-14)
google_guest_agent/agentcrypto/mtls_mds_linux_test.go (+21/-4)
google_guest_agent/agentcrypto/mtls_mds_windows.go (+25/-0)
google_guest_agent/cfg/cfg.go (+30/-8)
google_guest_agent/cfg/cfg_test.go (+3/-1)
google_guest_agent/command/Readme.md (+24/-0)
google_guest_agent/command/command.go (+146/-0)
google_guest_agent/command/command_linux.go (+140/-0)
google_guest_agent/command/command_monitor.go (+228/-0)
google_guest_agent/command/command_test.go (+209/-0)
google_guest_agent/command/command_windows.go (+104/-0)
google_guest_agent/command/command_windows_test.go (+73/-0)
google_guest_agent/diagnostics.go (+2/-1)
google_guest_agent/events/events.go (+308/-138)
google_guest_agent/events/events_test.go (+372/-107)
google_guest_agent/events/metadata/metadata.go (+0/-6)
google_guest_agent/events/metadata/metadata_test.go (+4/-0)
google_guest_agent/fakes/fake_mds.go (+5/-0)
google_guest_agent/instance_setup.go (+18/-7)
google_guest_agent/main.go (+22/-20)
google_guest_agent/network/manager/common.go (+90/-0)
google_guest_agent/network/manager/dhclient_linux.go (+292/-0)
google_guest_agent/network/manager/dhclient_linux_test.go (+511/-0)
google_guest_agent/network/manager/manager.go (+362/-0)
google_guest_agent/network/manager/manager_test.go (+353/-0)
google_guest_agent/network/manager/systemd_networkd_linux.go (+295/-0)
google_guest_agent/network/manager/systemd_networkd_linux_test.go (+549/-0)
google_guest_agent/non_windows_accounts.go (+3/-1)
google_guest_agent/oslogin.go (+91/-0)
google_guest_agent/ps/ps.go (+37/-0)
google_guest_agent/ps/ps_linux.go (+121/-0)
google_guest_agent/ps/ps_linux_test.go (+157/-0)
google_guest_agent/ps/ps_windows.go (+21/-0)
google_guest_agent/run/run.go (+51/-4)
google_guest_agent/scheduler/logger.go (+2/-2)
google_guest_agent/scheduler/scheduler.go (+1/-0)
google_guest_agent/scheduler/scheduler_test.go (+8/-8)
google_guest_agent/snapshot_listener.go (+7/-5)
google_guest_agent/sshca/sshca.go (+9/-11)
google_guest_agent/windows_accounts.go (+2/-1)
google_guest_agent/windows_accounts_test.go (+49/-42)
google_metadata_script_runner/main.go (+63/-78)
google_metadata_script_runner/main_test.go (+78/-11)
metadata/metadata.go (+16/-0)
metadata/metadata_test.go (+30/-0)
packaging/genlicense.sh (+24/-0)
retry/retry.go (+107/-0)
retry/retry_test.go (+193/-0)
utils/file.go (+80/-0)
utils/file_test.go (+87/-0)
utils/serial_port_logger.go (+35/-0)
utils/ssh.go (+19/-108)
utils/ssh_test.go (+32/-104)
utils/test.go (+38/-0)
Reviewer Review Type Date Requested Status
Utkarsh Gupta Approve
Review via email: mp+460883@code.launchpad.net

Commit message

New upstream version 20240213.00

To post a comment you must log in.
Revision history for this message
Utkarsh Gupta (utkarsh) wrote :

$ dput ubuntu ../google-guest-agent_20240213.00-0ubuntu1_source.changes
Uploading google-guest-agent using ftp to ubuntu (host: upload.ubuntu.com; directory: /ubuntu)
running supported-distribution: check whether the target distribution is currently supported (using distro-info)
{'allowed': ['release', 'proposed', 'backports', 'security'], 'known': ['release', 'proposed', 'updates', 'backports', 'security']}
running required-fields: check whether a field is present and non-empty in the changes file
running checksum: verify checksums before uploading
running suite-mismatch: check the target distribution for common errors
running check-debs: makes sure the upload contains a binary package
running gpg: check GnuPG signatures before the upload
Uploading google-guest-agent_20240213.00-0ubuntu1.dsc
Uploading google-guest-agent_20240213.00.orig.tar.gz
Uploading google-guest-agent_20240213.00-0ubuntu1.debian.tar.xz
Uploading google-guest-agent_20240213.00-0ubuntu1_source.buildinfo
Uploading google-guest-agent_20240213.00-0ubuntu1_source.changes

review: Approve

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1diff --git a/.gitignore b/.gitignore
2index 2ceee21..85eb02e 100644
3--- a/.gitignore
4+++ b/.gitignore
5@@ -1,5 +1,15 @@
6-# ignore all built binaries
7+# Ignore all built binaries.
8 **/gce_workload_cert_refresh
9+**/gce_workload_cert_refresh.exe
10 **/google_authorized_keys
11+**/google_authorized_keys.exe
12 **/google_guest_agent
13+**/google_guest_agent.exe
14 **/google_metadata_script_runner
15+**/google_metadata_script_runner.exe
16+
17+# Don't ignore new content to directories.
18+!**/gce_workload_cert_refresh/
19+!**/google_authorized_keys/
20+!**/google_guest_agent/
21+!**/google_metadata_script_runner/
22\ No newline at end of file
23diff --git a/OWNERS b/OWNERS
24index edde10e..59fca61 100644
25--- a/OWNERS
26+++ b/OWNERS
27@@ -3,14 +3,14 @@
28
29 approvers:
30 - a-crate
31+ - ajorg
32 - bkatyl
33 - chaitanyakulkarni28
34 - dorileo
35 - drewhli
36 - elicriffield
37+ - gaughen
38 - jjerger
39 - karnvadaliya
40 - koln67
41- - quintonamore
42- - vorakl
43 - zmarano
44diff --git a/THIRD_PARTY_LICENSES/cloud.google.com/go/LICENSE b/THIRD_PARTY_LICENSES/cloud.google.com/go/iam/LICENSE
45similarity index 100%
46rename from THIRD_PARTY_LICENSES/cloud.google.com/go/LICENSE
47rename to THIRD_PARTY_LICENSES/cloud.google.com/go/iam/LICENSE
48diff --git a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE b/THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE
49similarity index 100%
50rename from THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE
51rename to THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE
52index 65fb971..d645695 100644
53--- a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE
54+++ b/THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE
55@@ -1,3 +1,4 @@
56+
57 Apache License
58 Version 2.0, January 2004
59 http://www.apache.org/licenses/
60@@ -186,7 +187,7 @@
61 same "printed page" as the copyright notice for easier
62 identification within third-party archives.
63
64- Copyright 2020 Google Inc.
65+ Copyright [yyyy] [name of copyright owner]
66
67 Licensed under the Apache License, Version 2.0 (the "License");
68 you may not use this file except in compliance with the License.
69diff --git a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE b/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE
70index 65fb971..f49a4e1 100644
71--- a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE
72+++ b/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE
73@@ -186,7 +186,7 @@
74 same "printed page" as the copyright notice for easier
75 identification within third-party archives.
76
77- Copyright 2020 Google Inc.
78+ Copyright [yyyy] [name of copyright owner]
79
80 Licensed under the Apache License, Version 2.0 (the "License");
81 you may not use this file except in compliance with the License.
82@@ -198,4 +198,4 @@
83 distributed under the License is distributed on an "AS IS" BASIS,
84 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85 See the License for the specific language governing permissions and
86- limitations under the License.
87+ limitations under the License.
88\ No newline at end of file
89diff --git a/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE b/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE
90new file mode 100644
91index 0000000..b8b569d
92--- /dev/null
93+++ b/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE
94@@ -0,0 +1,22 @@
95+The MIT License (MIT)
96+
97+Copyright (c) 2015 Microsoft
98+
99+Permission is hereby granted, free of charge, to any person obtaining a copy
100+of this software and associated documentation files (the "Software"), to deal
101+in the Software without restriction, including without limitation the rights
102+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
103+copies of the Software, and to permit persons to whom the Software is
104+furnished to do so, subject to the following conditions:
105+
106+The above copyright notice and this permission notice shall be included in all
107+copies or substantial portions of the Software.
108+
109+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
110+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
111+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
112+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
113+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
114+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
115+SOFTWARE.
116+
117diff --git a/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE b/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE
118new file mode 100644
119index 0000000..e4a47e1
120--- /dev/null
121+++ b/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE
122@@ -0,0 +1,27 @@
123+Copyright (c) 2019 The Go Authors. All rights reserved.
124+
125+Redistribution and use in source and binary forms, with or without
126+modification, are permitted provided that the following conditions are
127+met:
128+
129+ * Redistributions of source code must retain the above copyright
130+notice, this list of conditions and the following disclaimer.
131+ * Redistributions in binary form must reproduce the above
132+copyright notice, this list of conditions and the following disclaimer
133+in the documentation and/or other materials provided with the
134+distribution.
135+ * Neither the name of Google Inc. nor the names of its
136+contributors may be used to endorse or promote products derived from
137+this software without specific prior written permission.
138+
139+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
140+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
141+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
142+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
143+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
144+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
145+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
146+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
147+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
148+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
149+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
150diff --git a/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE b/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE
151deleted file mode 100644
152index d645695..0000000
153--- a/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE
154+++ /dev/null
155@@ -1,202 +0,0 @@
156-
157- Apache License
158- Version 2.0, January 2004
159- http://www.apache.org/licenses/
160-
161- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
162-
163- 1. Definitions.
164-
165- "License" shall mean the terms and conditions for use, reproduction,
166- and distribution as defined by Sections 1 through 9 of this document.
167-
168- "Licensor" shall mean the copyright owner or entity authorized by
169- the copyright owner that is granting the License.
170-
171- "Legal Entity" shall mean the union of the acting entity and all
172- other entities that control, are controlled by, or are under common
173- control with that entity. For the purposes of this definition,
174- "control" means (i) the power, direct or indirect, to cause the
175- direction or management of such entity, whether by contract or
176- otherwise, or (ii) ownership of fifty percent (50%) or more of the
177- outstanding shares, or (iii) beneficial ownership of such entity.
178-
179- "You" (or "Your") shall mean an individual or Legal Entity
180- exercising permissions granted by this License.
181-
182- "Source" form shall mean the preferred form for making modifications,
183- including but not limited to software source code, documentation
184- source, and configuration files.
185-
186- "Object" form shall mean any form resulting from mechanical
187- transformation or translation of a Source form, including but
188- not limited to compiled object code, generated documentation,
189- and conversions to other media types.
190-
191- "Work" shall mean the work of authorship, whether in Source or
192- Object form, made available under the License, as indicated by a
193- copyright notice that is included in or attached to the work
194- (an example is provided in the Appendix below).
195-
196- "Derivative Works" shall mean any work, whether in Source or Object
197- form, that is based on (or derived from) the Work and for which the
198- editorial revisions, annotations, elaborations, or other modifications
199- represent, as a whole, an original work of authorship. For the purposes
200- of this License, Derivative Works shall not include works that remain
201- separable from, or merely link (or bind by name) to the interfaces of,
202- the Work and Derivative Works thereof.
203-
204- "Contribution" shall mean any work of authorship, including
205- the original version of the Work and any modifications or additions
206- to that Work or Derivative Works thereof, that is intentionally
207- submitted to Licensor for inclusion in the Work by the copyright owner
208- or by an individual or Legal Entity authorized to submit on behalf of
209- the copyright owner. For the purposes of this definition, "submitted"
210- means any form of electronic, verbal, or written communication sent
211- to the Licensor or its representatives, including but not limited to
212- communication on electronic mailing lists, source code control systems,
213- and issue tracking systems that are managed by, or on behalf of, the
214- Licensor for the purpose of discussing and improving the Work, but
215- excluding communication that is conspicuously marked or otherwise
216- designated in writing by the copyright owner as "Not a Contribution."
217-
218- "Contributor" shall mean Licensor and any individual or Legal Entity
219- on behalf of whom a Contribution has been received by Licensor and
220- subsequently incorporated within the Work.
221-
222- 2. Grant of Copyright License. Subject to the terms and conditions of
223- this License, each Contributor hereby grants to You a perpetual,
224- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
225- copyright license to reproduce, prepare Derivative Works of,
226- publicly display, publicly perform, sublicense, and distribute the
227- Work and such Derivative Works in Source or Object form.
228-
229- 3. Grant of Patent License. Subject to the terms and conditions of
230- this License, each Contributor hereby grants to You a perpetual,
231- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
232- (except as stated in this section) patent license to make, have made,
233- use, offer to sell, sell, import, and otherwise transfer the Work,
234- where such license applies only to those patent claims licensable
235- by such Contributor that are necessarily infringed by their
236- Contribution(s) alone or by combination of their Contribution(s)
237- with the Work to which such Contribution(s) was submitted. If You
238- institute patent litigation against any entity (including a
239- cross-claim or counterclaim in a lawsuit) alleging that the Work
240- or a Contribution incorporated within the Work constitutes direct
241- or contributory patent infringement, then any patent licenses
242- granted to You under this License for that Work shall terminate
243- as of the date such litigation is filed.
244-
245- 4. Redistribution. You may reproduce and distribute copies of the
246- Work or Derivative Works thereof in any medium, with or without
247- modifications, and in Source or Object form, provided that You
248- meet the following conditions:
249-
250- (a) You must give any other recipients of the Work or
251- Derivative Works a copy of this License; and
252-
253- (b) You must cause any modified files to carry prominent notices
254- stating that You changed the files; and
255-
256- (c) You must retain, in the Source form of any Derivative Works
257- that You distribute, all copyright, patent, trademark, and
258- attribution notices from the Source form of the Work,
259- excluding those notices that do not pertain to any part of
260- the Derivative Works; and
261-
262- (d) If the Work includes a "NOTICE" text file as part of its
263- distribution, then any Derivative Works that You distribute must
264- include a readable copy of the attribution notices contained
265- within such NOTICE file, excluding those notices that do not
266- pertain to any part of the Derivative Works, in at least one
267- of the following places: within a NOTICE text file distributed
268- as part of the Derivative Works; within the Source form or
269- documentation, if provided along with the Derivative Works; or,
270- within a display generated by the Derivative Works, if and
271- wherever such third-party notices normally appear. The contents
272- of the NOTICE file are for informational purposes only and
273- do not modify the License. You may add Your own attribution
274- notices within Derivative Works that You distribute, alongside
275- or as an addendum to the NOTICE text from the Work, provided
276- that such additional attribution notices cannot be construed
277- as modifying the License.
278-
279- You may add Your own copyright statement to Your modifications and
280- may provide additional or different license terms and conditions
281- for use, reproduction, or distribution of Your modifications, or
282- for any such Derivative Works as a whole, provided Your use,
283- reproduction, and distribution of the Work otherwise complies with
284- the conditions stated in this License.
285-
286- 5. Submission of Contributions. Unless You explicitly state otherwise,
287- any Contribution intentionally submitted for inclusion in the Work
288- by You to the Licensor shall be under the terms and conditions of
289- this License, without any additional terms or conditions.
290- Notwithstanding the above, nothing herein shall supersede or modify
291- the terms of any separate license agreement you may have executed
292- with Licensor regarding such Contributions.
293-
294- 6. Trademarks. This License does not grant permission to use the trade
295- names, trademarks, service marks, or product names of the Licensor,
296- except as required for reasonable and customary use in describing the
297- origin of the Work and reproducing the content of the NOTICE file.
298-
299- 7. Disclaimer of Warranty. Unless required by applicable law or
300- agreed to in writing, Licensor provides the Work (and each
301- Contributor provides its Contributions) on an "AS IS" BASIS,
302- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
303- implied, including, without limitation, any warranties or conditions
304- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
305- PARTICULAR PURPOSE. You are solely responsible for determining the
306- appropriateness of using or redistributing the Work and assume any
307- risks associated with Your exercise of permissions under this License.
308-
309- 8. Limitation of Liability. In no event and under no legal theory,
310- whether in tort (including negligence), contract, or otherwise,
311- unless required by applicable law (such as deliberate and grossly
312- negligent acts) or agreed to in writing, shall any Contributor be
313- liable to You for damages, including any direct, indirect, special,
314- incidental, or consequential damages of any character arising as a
315- result of this License or out of the use or inability to use the
316- Work (including but not limited to damages for loss of goodwill,
317- work stoppage, computer failure or malfunction, or any and all
318- other commercial damages or losses), even if such Contributor
319- has been advised of the possibility of such damages.
320-
321- 9. Accepting Warranty or Additional Liability. While redistributing
322- the Work or Derivative Works thereof, You may choose to offer,
323- and charge a fee for, acceptance of support, warranty, indemnity,
324- or other liability obligations and/or rights consistent with this
325- License. However, in accepting such obligations, You may act only
326- on Your own behalf and on Your sole responsibility, not on behalf
327- of any other Contributor, and only if You agree to indemnify,
328- defend, and hold each Contributor harmless for any liability
329- incurred by, or claims asserted against, such Contributor by reason
330- of your accepting any such warranty or additional liability.
331-
332- END OF TERMS AND CONDITIONS
333-
334- APPENDIX: How to apply the Apache License to your work.
335-
336- To apply the Apache License to your work, attach the following
337- boilerplate notice, with the fields enclosed by brackets "[]"
338- replaced with your own identifying information. (Don't include
339- the brackets!) The text should be enclosed in the appropriate
340- comment syntax for the file format. We also recommend that a
341- file or class name and description of purpose be included on the
342- same "printed page" as the copyright notice for easier
343- identification within third-party archives.
344-
345- Copyright [yyyy] [name of copyright owner]
346-
347- Licensed under the Apache License, Version 2.0 (the "License");
348- you may not use this file except in compliance with the License.
349- You may obtain a copy of the License at
350-
351- http://www.apache.org/licenses/LICENSE-2.0
352-
353- Unless required by applicable law or agreed to in writing, software
354- distributed under the License is distributed on an "AS IS" BASIS,
355- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
356- See the License for the specific language governing permissions and
357- limitations under the License.
358diff --git a/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE b/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE
359index 6ac6b11..bcecd3d 100644
360--- a/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE
361+++ b/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE
362@@ -25,4 +25,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
363 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
364 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
365 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
366-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
367\ No newline at end of file
368+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
369diff --git a/gce_workload_cert_refresh/main.go b/gce_workload_cert_refresh/main.go
370index adb83aa..c3c0b7d 100644
371--- a/gce_workload_cert_refresh/main.go
372+++ b/gce_workload_cert_refresh/main.go
373@@ -16,11 +16,15 @@
374 package main
375
376 import (
377+ "bytes"
378 "context"
379 "encoding/json"
380 "fmt"
381 "io"
382 "os"
383+ "path"
384+ "path/filepath"
385+ "strings"
386 "time"
387
388 "github.com/GoogleCloudPlatform/guest-agent/metadata"
389@@ -28,15 +32,28 @@ import (
390 )
391
392 const (
393- contentDirPrefix = "/run/secrets/workload-spiffe-contents"
394+ // trustAnchorsKey endpoint contains a set of trusted certificates for peer X.509 certificate chain validation.
395+ trustAnchorsKey = "instance/gce-workload-certificates/trust-anchors"
396+ // workloadIdentitiesKey endpoint contains identities managed by the GCE control plane. This contains the X.509 certificate and the private key for the VM's trust domain.
397+ workloadIdentitiesKey = "instance/gce-workload-certificates/workload-identities"
398+ // configStatusKey contains status and any errors in the config values provided via the VM metadata.
399+ configStatusKey = "instance/gce-workload-certificates/config-status"
400+ // enableWorkloadCertsKey is set to true as custom metadata to enable automatic provisioning of credentials.
401+ enableWorkloadCertsKey = "instance/attributes/enable-workload-certificate"
402+ // contentDirPrefix is used as prefx to create certificate directories on refresh as contentDirPrefix-<time>.
403+ contentDirPrefix = "/run/secrets/workload-spiffe-contents"
404+ // tempSymlinkPrefix is used as prefix to create temporary symlinks on refresh as tempSymlinkPrefix-<time> to content directories.
405 tempSymlinkPrefix = "/run/secrets/workload-spiffe-symlink"
406- symlink = "/run/secrets/workload-spiffe-credentials"
407- programName = "gce_workload_certs_refresh"
408+ // symlink points to the directory with current GCE workload certificates and is always expected to be present.
409+ symlink = "/run/secrets/workload-spiffe-credentials"
410 )
411
412 var (
413 // mdsClient is the client used to query Metadata server.
414- mdsClient *metadata.Client
415+ mdsClient metadata.MDSClientInterface
416+ programName = path.Base(os.Args[0])
417+ // timeNow returns current time, defining as variable allows the time to be stubbed during testing.
418+ timeNow = func() string { return time.Now().Format(time.RFC3339) }
419 )
420
421 func init() {
422@@ -48,138 +65,86 @@ func logFormat(e logger.LogEntry) string {
423 return fmt.Sprintf("%s: %s", now, e.Message)
424 }
425
426+// isEnabled returns true only if enable-workload-certificate metadata attribute is present and set to true.
427+func isEnabled(ctx context.Context) bool {
428+ resp, err := getMetadata(ctx, enableWorkloadCertsKey)
429+ if err != nil {
430+ logger.Debugf("Failed to get %q from MDS with error: %v", enableWorkloadCertsKey, err)
431+ return false
432+ }
433+
434+ return bytes.EqualFold(resp, []byte("true"))
435+}
436+
437 func getMetadata(ctx context.Context, key string) ([]byte, error) {
438 // GCE Workload Certificate endpoints return 412 Precondition failed if the VM was
439 // never configured with valid config values at least once. Without valid config
440 // values GCE cannot provision the workload certificates.
441 resp, err := mdsClient.GetKey(ctx, key, nil)
442 if err != nil {
443- return []byte{}, fmt.Errorf("failed to GET %q from MDS with error: %w", key, err)
444+ return nil, fmt.Errorf("failed to GET %q from MDS with error: %w", key, err)
445 }
446 return []byte(resp), nil
447 }
448
449 /*
450 metadata key instance/gce-workload-certificates/workload-identities
451-
452- {
453- "status": "OK",
454- "workloadCredentials": {
455- "PROJECT_ID.svc.id.goog": {
456- "metadata": {
457- "workload_creds_dir_path": "/var/run/secrets/workload-spiffe-credentials"
458- },
459- "certificatePem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----",
460- "privateKeyPem": "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----"
461- }
462- }
463- }
464-*/
465-
466-// WorkloadIdentities represents Workload Identities in metadata.
467-type WorkloadIdentities struct {
468- Status string
469- WorkloadCredentials map[string]WorkloadCredential
470-}
471-
472-// UnmarshalJSON is a custom JSON unmarshaller for WorkloadIdentities.
473-func (wi *WorkloadIdentities) UnmarshalJSON(b []byte) error {
474- tmp := map[string]json.RawMessage{}
475- err := json.Unmarshal(b, &tmp)
476- if err != nil {
477- return err
478- }
479-
480- if err := json.Unmarshal(tmp["status"], &wi.Status); err != nil {
481- return err
482- }
483-
484- wi.WorkloadCredentials = map[string]WorkloadCredential{}
485- wcs := map[string]json.RawMessage{}
486- if err := json.Unmarshal(tmp["workloadCredentials"], &wcs); err != nil {
487- return err
488- }
489-
490- for domain, value := range wcs {
491- wc := WorkloadCredential{}
492- err := json.Unmarshal(value, &wc)
493- if err != nil {
494- return err
495+MANAGED_WORKLOAD_IDENTITY_SPIFFE is of the format:
496+spiffe://POOL_ID.global.PROJECT_NUMBER.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID
497+
498+{
499+ "status": "OK", // Status of the response,
500+ "workloadCredentials": { // Credentials for the VM's trust domains
501+ "MANAGED_WORKLOAD_IDENTITY_SPIFFE": {
502+ "certificatePem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----",
503+ "privateKeyPem": "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----"
504 }
505- wi.WorkloadCredentials[domain] = wc
506 }
507-
508- return nil
509 }
510+*/
511
512 // WorkloadCredential represents Workload Credentials in metadata.
513 type WorkloadCredential struct {
514- Metadata Metadata
515- CertificatePem string
516- PrivateKeyPem string
517+ CertificatePem string `json:"certificatePem"`
518+ PrivateKeyPem string `json:"privateKeyPem"`
519 }
520
521-/*
522-metadata key instance/gce-workload-certificates/root-certs
523-
524- {
525- "status": "OK",
526- "rootCertificates": {
527- "PROJECT.svc.id.goog": {
528- "metadata": {
529- "workload_creds_dir_path": "/var/run/secrets/workload-spiffe-credentials"
530- },
531- "rootCertificatesPem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----"
532- }
533- }
534- }
535-*/
536-
537-// WorkloadTrustedRootCerts represents Workload Trusted Root Certs in metadata.
538-type WorkloadTrustedRootCerts struct {
539- Status string
540- RootCertificates map[string]RootCertificate
541+// WorkloadIdentities represents Workload Identities in metadata.
542+type WorkloadIdentities struct {
543+ Status string `json:"status"`
544+ WorkloadCredentials map[string]WorkloadCredential `json:"workloadCredentials"`
545 }
546
547-// UnmarshalJSON is a custom JSON unmarshaller for WorkloadTrustedRootCerts
548-func (wtrc *WorkloadTrustedRootCerts) UnmarshalJSON(b []byte) error {
549- tmp := map[string]json.RawMessage{}
550- err := json.Unmarshal(b, &tmp)
551- if err != nil {
552- return err
553- }
554-
555- if err := json.Unmarshal(tmp["status"], &wtrc.Status); err != nil {
556- return err
557- }
558-
559- wtrc.RootCertificates = map[string]RootCertificate{}
560- rcs := map[string]json.RawMessage{}
561- if err := json.Unmarshal(tmp["rootCertificates"], &rcs); err != nil {
562- return err
563- }
564-
565- for domain, value := range rcs {
566- rc := RootCertificate{}
567- err := json.Unmarshal(value, &rc)
568- if err != nil {
569- return err
570- }
571- wtrc.RootCertificates[domain] = rc
572- }
573+/*
574+metadata key instance/gce-workload-certificates/trust-anchors
575+
576+{
577+ "status": "<status string>" // Status of the response,
578+ "trustAnchors": { // Trust bundle for the VM's trust domains
579+ "PEER_SPIFFE_TRUST_DOMAIN_1": {
580+ "trustAnchorsPem" : "<Trust bundle containing the X.509 roots certificates>",
581+ },
582+ "PEER_SPIFFE_TRUST_DOMAIN_2": {
583+ "trustAnchorsPem" : "<Trust bundle containing the X.509 roots certificates>",
584+ }
585+ }
586+}
587+*/
588
589- return nil
590+// TrustAnchor represents one or more certificates in an arbitrary order in the metadata.
591+type TrustAnchor struct {
592+ TrustAnchorsPem string `json:"trustAnchorsPem"`
593 }
594
595-// RootCertificate represents a Root Certificate in metadata
596-type RootCertificate struct {
597- Metadata Metadata
598- RootCertificatesPem string
599+// WorkloadTrustedAnchors represents Workload Trusted Root Certs in metadata.
600+type WorkloadTrustedAnchors struct {
601+ Status string `json:"status"`
602+ TrustAnchors map[string]TrustAnchor `json:"trustAnchors"`
603 }
604
605-// Metadata represents Metadata in metadata
606-type Metadata struct {
607- WorkloadCredsDirPath string
608+// outputOpts is a struct for output directory name and symlink templates.
609+type outputOpts struct {
610+ contentDirPrefix, tempSymlinkPrefix, symlink string
611 }
612
613 func main() {
614@@ -193,37 +158,102 @@ func main() {
615 }
616
617 opts.Writers = []io.Writer{os.Stderr}
618- logger.Init(ctx, opts)
619+
620+ if err := logger.Init(ctx, opts); err != nil {
621+ fmt.Printf("Error initializing logger: %v", err)
622+ os.Exit(1)
623+ }
624+
625 defer logger.Infof("Done")
626
627- // TODO: prune old dirs
628- if err := refreshCreds(ctx); err != nil {
629+ if !isEnabled(ctx) {
630+ logger.Debugf("GCE Workload Certificate refresh is not enabled, skipping cert generation.")
631+ return
632+ }
633+
634+ out := outputOpts{contentDirPrefix, tempSymlinkPrefix, symlink}
635+ if err := refreshCreds(ctx, out); err != nil {
636 logger.Fatalf("Error refreshCreds: %v", err.Error())
637 }
638
639 }
640
641-func refreshCreds(ctx context.Context) error {
642- project, err := getMetadata(ctx, "project/project-id")
643+// findDomain finds the anchor matching with the domain from spiffeID.
644+// spiffeID is of the form -
645+// spiffe://POOL_ID.global.PROJECT_NUMBER.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID
646+// where domain is POOL_ID.global.PROJECT_NUMBER.workload.id.goog
647+// anchors is a map of various domains and their corresponding trust PEMs.
648+// However, if anchor map contains single entry it returns that without any check.
649+func findDomain(anchors map[string]TrustAnchor, spiffeID string) (string, error) {
650+ c := len(anchors)
651+ for k := range anchors {
652+ if c == 1 {
653+ return k, nil
654+ }
655+ if strings.Contains(spiffeID, k) {
656+ return k, nil
657+ }
658+ }
659+
660+ return "", fmt.Errorf("no matching trust anchor found")
661+}
662+
663+// writeTrustAnchors parses the input data, finds the domain from spiffeID and writes ca_certificate.pem
664+// in the destDir for that domain.
665+func writeTrustAnchors(wtrcsMd []byte, destDir, spiffeID string) error {
666+ wtrcs := WorkloadTrustedAnchors{}
667+ if err := json.Unmarshal(wtrcsMd, &wtrcs); err != nil {
668+ return fmt.Errorf("error unmarshaling workload trusted root certs: %v", err)
669+ }
670+
671+ // Currently there's only one trust anchor but there could be multipe trust anchors in future.
672+ // In either case we want the trust anchor with domain matching with the one in SPIFFE ID.
673+ domain, err := findDomain(wtrcs.TrustAnchors, spiffeID)
674 if err != nil {
675- return fmt.Errorf("error getting project ID: %v", err)
676+ return err
677 }
678
679+ return os.WriteFile(fmt.Sprintf("%s/ca_certificates.pem", destDir), []byte(wtrcs.TrustAnchors[domain].TrustAnchorsPem), 0644)
680+}
681+
682+// writeWorkloadIdentities parses the input data, writes the certificates.pem, private_key.pem files in the
683+// destDir, and returns the SPIFFE ID for which it wrote the certificates.
684+func writeWorkloadIdentities(destDir string, wisMd []byte) (string, error) {
685+ var spiffeID string
686+ wis := WorkloadIdentities{}
687+ if err := json.Unmarshal(wisMd, &wis); err != nil {
688+ return "", fmt.Errorf("error unmarshaling workload identities response: %w", err)
689+ }
690+
691+ // Its guaranteed to have single entry in workload credentials map.
692+ for k := range wis.WorkloadCredentials {
693+ spiffeID = k
694+ break
695+ }
696+
697+ if err := os.WriteFile(filepath.Join(destDir, "certificates.pem"), []byte(wis.WorkloadCredentials[spiffeID].CertificatePem), 0644); err != nil {
698+ return "", fmt.Errorf("error writing certificates.pem: %w", err)
699+ }
700+
701+ if err := os.WriteFile(filepath.Join(destDir, "private_key.pem"), []byte(wis.WorkloadCredentials[spiffeID].PrivateKeyPem), 0644); err != nil {
702+ return "", fmt.Errorf("error writing private_key.pem: %w", err)
703+ }
704+ return spiffeID, nil
705+}
706+
707+func refreshCreds(ctx context.Context, opts outputOpts) error {
708+ now := timeNow()
709+ contentDir := fmt.Sprintf("%s-%s", opts.contentDirPrefix, now)
710+ tempSymlink := fmt.Sprintf("%s-%s", opts.tempSymlinkPrefix, now)
711+
712 // Get status first so it can be written even when other endpoints are empty.
713- certConfigStatus, err := getMetadata(ctx, "instance/gce-workload-certificates/config-status")
714+ certConfigStatus, err := getMetadata(ctx, configStatusKey)
715 if err != nil {
716 // Return success when certs are not configured to avoid unnecessary systemd failed units.
717 logger.Infof("Error getting config status, workload certificates may not be configured: %v", err)
718 return nil
719 }
720
721- domain := fmt.Sprintf("%s.svc.id.goog", project)
722- logger.Infof("Rotating workload credentials for trust domain %s", domain)
723-
724- now := time.Now().Format(time.RFC3339)
725- contentDir := fmt.Sprintf("%s-%s", contentDirPrefix, now)
726- tempSymlink := fmt.Sprintf("%s-%s", tempSymlinkPrefix, now)
727-
728 logger.Infof("Creating timestamp contents dir %s", contentDir)
729
730 if err := os.MkdirAll(contentDir, 0755); err != nil {
731@@ -231,72 +261,59 @@ func refreshCreds(ctx context.Context) error {
732 }
733
734 // Write config_status first even if remaining endpoints are empty.
735- if err := os.WriteFile(fmt.Sprintf("%s/config_status", contentDir), certConfigStatus, 0644); err != nil {
736+ if err := os.WriteFile(filepath.Join(contentDir, "config_status"), certConfigStatus, 0644); err != nil {
737 return fmt.Errorf("error writing config_status: %v", err)
738 }
739
740 // Handles the edge case where the config values provided for the first time may be invalid. This ensures
741- // that the symlink directory alwasys exists and contains the config_status to surface config errors to the VM.
742- if _, err := os.Stat(symlink); os.IsNotExist(err) {
743+ // that the symlink directory always exists and contains the config_status to surface config errors to the VM.
744+ if _, err := os.Stat(opts.symlink); os.IsNotExist(err) {
745 logger.Infof("Creating new symlink %s", symlink)
746
747- if err := os.Symlink(contentDir, symlink); err != nil {
748+ if err := os.Symlink(contentDir, opts.symlink); err != nil {
749 return fmt.Errorf("error creating symlink: %v", err)
750 }
751 }
752
753 // Now get the rest of the content.
754- wisMd, err := getMetadata(ctx, "instance/gce-workload-certificates/workload-identities")
755+ wisMd, err := getMetadata(ctx, workloadIdentitiesKey)
756 if err != nil {
757 return fmt.Errorf("error getting workload-identities: %v", err)
758 }
759
760- wtrcsMd, err := getMetadata(ctx, "instance/gce-workload-certificates/root-certs")
761+ spiffeID, err := writeWorkloadIdentities(contentDir, wisMd)
762 if err != nil {
763- return fmt.Errorf("error getting workload-trusted-root-certs: %v", err)
764+ return fmt.Errorf("failed to write workload identities with error: %w", err)
765 }
766
767- wis := WorkloadIdentities{}
768- if err := json.Unmarshal(wisMd, &wis); err != nil {
769- return fmt.Errorf("error unmarshaling workload identities response: %v", err)
770- }
771-
772- wtrcs := WorkloadTrustedRootCerts{}
773- if err := json.Unmarshal(wtrcsMd, &wtrcs); err != nil {
774- return fmt.Errorf("error unmarshaling workload trusted root certs: %v", err)
775- }
776-
777- if err := os.WriteFile(fmt.Sprintf("%s/certificates.pem", contentDir), []byte(wis.WorkloadCredentials[domain].CertificatePem), 0644); err != nil {
778- return fmt.Errorf("error writing certificates.pem: %v", err)
779- }
780-
781- if err := os.WriteFile(fmt.Sprintf("%s/private_key.pem", contentDir), []byte(wis.WorkloadCredentials[domain].PrivateKeyPem), 0644); err != nil {
782- return fmt.Errorf("error writing private_key.pem: %v", err)
783+ wtrcsMd, err := getMetadata(ctx, trustAnchorsKey)
784+ if err != nil {
785+ return fmt.Errorf("error getting workload-trust-anchors: %v", err)
786 }
787
788- if err := os.WriteFile(fmt.Sprintf("%s/ca_certificates.pem", contentDir), []byte(wtrcs.RootCertificates[domain].RootCertificatesPem), 0644); err != nil {
789- return fmt.Errorf("error writing ca_certificates.pem: %v", err)
790+ if err := writeTrustAnchors(wtrcsMd, contentDir, spiffeID); err != nil {
791+ return fmt.Errorf("failed to write trust anchors: %w", err)
792 }
793
794 if err := os.Symlink(contentDir, tempSymlink); err != nil {
795 return fmt.Errorf("error creating temporary link: %v", err)
796 }
797
798- oldTarget, err := os.Readlink(symlink)
799+ oldTarget, err := os.Readlink(opts.symlink)
800 if err != nil {
801 logger.Infof("Error reading existing symlink: %v\n", err)
802 oldTarget = ""
803 }
804
805 // Only rotate on success of all steps above.
806- logger.Infof("Rotating symlink %s", symlink)
807+ logger.Infof("Rotating symlink %s", opts.symlink)
808
809- if err := os.Rename(tempSymlink, symlink); err != nil {
810+ if err := os.Rename(tempSymlink, opts.symlink); err != nil {
811 return fmt.Errorf("error rotating target link: %v", err)
812 }
813
814 // Clean up previous contents dir.
815- newTarget, err := os.Readlink(symlink)
816+ newTarget, err := os.Readlink(opts.symlink)
817 if err != nil {
818 return fmt.Errorf("error reading new symlink: %v, unable to remove old symlink target", err)
819 }
820diff --git a/gce_workload_cert_refresh/main_test.go b/gce_workload_cert_refresh/main_test.go
821new file mode 100644
822index 0000000..5edb1c8
823--- /dev/null
824+++ b/gce_workload_cert_refresh/main_test.go
825@@ -0,0 +1,495 @@
826+// Copyright 2023 Google LLC
827+
828+// Licensed under the Apache License, Version 2.0 (the "License");
829+// you may not use this file except in compliance with the License.
830+// You may obtain a copy of the License at
831+
832+// https://www.apache.org/licenses/LICENSE-2.0
833+
834+// Unless required by applicable law or agreed to in writing, software
835+// distributed under the License is distributed on an "AS IS" BASIS,
836+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
837+// See the License for the specific language governing permissions and
838+// limitations under the License.
839+
840+package main
841+
842+import (
843+ "context"
844+ "encoding/json"
845+ "fmt"
846+ "os"
847+ "path/filepath"
848+ "testing"
849+
850+ "github.com/GoogleCloudPlatform/guest-agent/metadata"
851+ "github.com/google/go-cmp/cmp"
852+)
853+
854+const (
855+ workloadRespTpl = `
856+ {
857+ "status": "OK",
858+ "workloadCredentials": {
859+ "%s": {
860+ "certificatePem": "%s",
861+ "privateKeyPem": "%s"
862+ }
863+ }
864+ }
865+ `
866+ trustAnchorRespTpl = `
867+ {
868+ "status": "Ok",
869+ "trustAnchors": {
870+ "%s": {
871+ "trustAnchorsPem": "%s"
872+ },
873+ "%s": {
874+ "trustAnchorsPem": "%s"
875+ }
876+ }
877+ }
878+ `
879+ testConfigStatusResp = `
880+ {
881+ "status": "Ok",
882+ }
883+ `
884+)
885+
886+func TestWorkloadIdentitiesUnmarshal(t *testing.T) {
887+ certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----"
888+ pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----"
889+ spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
890+
891+ resp := fmt.Sprintf(workloadRespTpl, spiffe, certPem, pvtPem)
892+ want := WorkloadIdentities{
893+ Status: "OK",
894+ WorkloadCredentials: map[string]WorkloadCredential{
895+ spiffe: {
896+ CertificatePem: certPem,
897+ PrivateKeyPem: pvtPem,
898+ },
899+ },
900+ }
901+
902+ got := WorkloadIdentities{}
903+ if err := json.Unmarshal([]byte(resp), &got); err != nil {
904+ t.Errorf("WorkloadIdentities.UnmarshalJSON(%s) failed unexpectedly with error: %v", resp, err)
905+ }
906+
907+ if diff := cmp.Diff(want, got); diff != "" {
908+ t.Errorf("Workload identities diff (-want +got):\n%s", diff)
909+ }
910+}
911+
912+func TestTrustAnchorsUnmarshal(t *testing.T) {
913+ domain1 := "12345.global.67890.workload.id.goog"
914+ pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----"
915+ domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2"
916+ pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----"
917+
918+ resp := fmt.Sprintf(trustAnchorRespTpl, domain1, pem1, domain2, pem2)
919+ want := WorkloadTrustedAnchors{
920+ Status: "Ok",
921+ TrustAnchors: map[string]TrustAnchor{
922+ domain1: {
923+ TrustAnchorsPem: pem1,
924+ },
925+ domain2: {
926+ TrustAnchorsPem: pem2,
927+ },
928+ },
929+ }
930+
931+ got := WorkloadTrustedAnchors{}
932+ if err := json.Unmarshal([]byte(resp), &got); err != nil {
933+ t.Errorf("WorkloadTrustedRootCerts.UnmarshalJSON(%s) failed unexpectedly with error: %v", resp, err)
934+ }
935+
936+ if diff := cmp.Diff(want, got); diff != "" {
937+ t.Errorf("Workload trusted anchors diff (-want +got):\n%s", diff)
938+ }
939+}
940+
941+func TestWriteTrustAnchors(t *testing.T) {
942+ spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
943+ domain1 := "12345.global.67890.workload.id.goog"
944+ pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----"
945+ domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2"
946+ pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----"
947+
948+ resp := fmt.Sprintf(trustAnchorRespTpl, domain1, pem1, domain2, pem2)
949+ dir := t.TempDir()
950+ if err := writeTrustAnchors([]byte(resp), dir, spiffe); err != nil {
951+ t.Errorf("writeTrustAnchors(%s,%s,%s) failed unexpectedly with error %v", resp, dir, spiffe, err)
952+ }
953+
954+ got, err := os.ReadFile(filepath.Join(dir, "ca_certificates.pem"))
955+ if err != nil {
956+ t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "ca_certificates.pem"), err)
957+ }
958+ if string(got) != pem1 {
959+ t.Errorf("writeTrustAnchors(%s,%s,%s) wrote %q, expected to write %q", resp, dir, spiffe, string(got), pem1)
960+ }
961+}
962+
963+func TestWriteWorkloadIdentities(t *testing.T) {
964+ certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----"
965+ pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----"
966+ spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
967+
968+ resp := fmt.Sprintf(workloadRespTpl, spiffe, certPem, pvtPem)
969+ dir := t.TempDir()
970+
971+ gotID, err := writeWorkloadIdentities(dir, []byte(resp))
972+ if err != nil {
973+ t.Errorf("writeWorkloadIdentities(%s,%s) failed unexpectedly with error %v", dir, resp, err)
974+ }
975+ if gotID != spiffe {
976+ t.Errorf("writeWorkloadIdentities(%s,%s) = %s, want %s", dir, resp, gotID, spiffe)
977+ }
978+
979+ gotCertPem, err := os.ReadFile(filepath.Join(dir, "certificates.pem"))
980+ if err != nil {
981+ t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "certificates.pem"), err)
982+ }
983+ if string(gotCertPem) != certPem {
984+ t.Errorf("writeWorkloadIdentities(%s,%s) wrote %q, expected to write %q", dir, resp, string(gotCertPem), certPem)
985+ }
986+
987+ gotPvtPem, err := os.ReadFile(filepath.Join(dir, "private_key.pem"))
988+ if err != nil {
989+ t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "private_key.pem"), err)
990+ }
991+ if string(gotPvtPem) != pvtPem {
992+ t.Errorf("writeWorkloadIdentities(%s,%s) wrote %q, expected to write %q", dir, resp, string(gotPvtPem), pvtPem)
993+ }
994+}
995+
996+func TestFindDomainError(t *testing.T) {
997+ anchors := map[string]TrustAnchor{
998+ "67890.global.12345.workload.id.goog": {},
999+ "55555.global.67890.workload.id.goog": {},
1000+ }
1001+ spiffeID := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
1002+
1003+ if _, err := findDomain(anchors, spiffeID); err == nil {
1004+ t.Errorf("findDomain(%+v, %s) succeded for unknown anchors, want error", anchors, spiffeID)
1005+ }
1006+}
1007+
1008+func TestFindDomain(t *testing.T) {
1009+ tests := []struct {
1010+ desc string
1011+ anchors map[string]TrustAnchor
1012+ spiffeID string
1013+ want string
1014+ }{
1015+ {
1016+ desc: "single_trust_anchor",
1017+ anchors: map[string]TrustAnchor{"12345.global.67890.workload.id.goog": {}},
1018+ spiffeID: "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID",
1019+ want: "12345.global.67890.workload.id.goog",
1020+ },
1021+ {
1022+ desc: "multiple_trust_anchor",
1023+ anchors: map[string]TrustAnchor{
1024+ "67890.global.12345.workload.id.goog": {},
1025+ "12345.global.67890.workload.id.goog": {},
1026+ },
1027+ spiffeID: "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID",
1028+ want: "12345.global.67890.workload.id.goog",
1029+ },
1030+ }
1031+
1032+ for _, test := range tests {
1033+ t.Run(test.desc, func(t *testing.T) {
1034+ got, err := findDomain(test.anchors, test.spiffeID)
1035+ if err != nil {
1036+ t.Errorf("findDomain(%+v, %s) failed unexpectedly with error: %v", test.anchors, test.spiffeID, err)
1037+ }
1038+ if got != test.want {
1039+ t.Errorf("findDomain(%+v, %s) = %s, want %s", test.anchors, test.spiffeID, got, test.want)
1040+ }
1041+ })
1042+ }
1043+}
1044+
1045+func TestIsEnabled(t *testing.T) {
1046+ ctx := context.Background()
1047+
1048+ tests := []struct {
1049+ desc string
1050+ enabled string
1051+ want bool
1052+ err string
1053+ }{
1054+ {
1055+ desc: "attr_correctly_added",
1056+ enabled: "true",
1057+ want: true,
1058+ },
1059+ {
1060+ desc: "attr_incorrectly_added",
1061+ enabled: "blaah",
1062+ want: false,
1063+ },
1064+ {
1065+ desc: "attr_not_added",
1066+ want: false,
1067+ err: enableWorkloadCertsKey,
1068+ },
1069+ }
1070+
1071+ for _, test := range tests {
1072+ t.Run(test.desc, func(t *testing.T) {
1073+ mdsClient = &mdsTestClient{enabled: test.enabled, throwErrOn: test.err}
1074+ if got := isEnabled(ctx); got != test.want {
1075+ t.Errorf("isEnabled(ctx) = %t, want %t", got, test.want)
1076+ }
1077+ })
1078+ }
1079+}
1080+
1081+// mdsTestClient is fake client to stub MDS response in unit tests.
1082+type mdsTestClient struct {
1083+ // Is credential generation enabled.
1084+ enabled string
1085+ // Workload template.
1086+ spiffe, certPem, pvtPem string
1087+ // Trust Anchor template.
1088+ domain1, pem1, domain2, pem2 string
1089+ // Throw error on MDS request for "key".
1090+ throwErrOn string
1091+}
1092+
1093+func (mds *mdsTestClient) Get(ctx context.Context) (*metadata.Descriptor, error) {
1094+ return nil, fmt.Errorf("Get() not yet implemented")
1095+}
1096+
1097+func (mds *mdsTestClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) {
1098+ if mds.throwErrOn == key {
1099+ return "", fmt.Errorf("this is fake error for testing")
1100+ }
1101+
1102+ switch key {
1103+ case enableWorkloadCertsKey:
1104+ return mds.enabled, nil
1105+ case configStatusKey:
1106+ return testConfigStatusResp, nil
1107+ case workloadIdentitiesKey:
1108+ return fmt.Sprintf(workloadRespTpl, mds.spiffe, mds.certPem, mds.pvtPem), nil
1109+ case trustAnchorsKey:
1110+ return fmt.Sprintf(trustAnchorRespTpl, mds.domain1, mds.pem1, mds.domain2, mds.pem2), nil
1111+ default:
1112+ return "", fmt.Errorf("unknown key %q", key)
1113+ }
1114+}
1115+
1116+func (mds *mdsTestClient) GetKeyRecursive(ctx context.Context, key string) (string, error) {
1117+ return "", fmt.Errorf("GetKeyRecursive() not yet implemented")
1118+}
1119+
1120+func (mds *mdsTestClient) Watch(ctx context.Context) (*metadata.Descriptor, error) {
1121+ return nil, fmt.Errorf("Watch() not yet implemented")
1122+}
1123+
1124+func (mds *mdsTestClient) WriteGuestAttributes(ctx context.Context, key string, value string) error {
1125+ return fmt.Errorf("WriteGuestattributes() not yet implemented")
1126+}
1127+
1128+func TestRefreshCreds(t *testing.T) {
1129+ ctx := context.Background()
1130+ tmp := t.TempDir()
1131+
1132+ // Templates to use in iterations.
1133+ spiffeTpl := "spiffe://12345.global.67890.workload.id.goog.%d/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
1134+ domain1Tpl := "12345.global.67890.workload.id.goog.%d"
1135+ pem1Tpl := "-----BEGIN CERTIFICATE-----datahere1.%d-----END CERTIFICATE-----"
1136+ domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2_IGNORE"
1137+ pem2Tpl := "-----BEGIN CERTIFICATE-----datahere2.%d-----END CERTIFICATE-----"
1138+ certPemTpl := "-----BEGIN CERTIFICATE-----datahere.%d-----END CERTIFICATE-----"
1139+ pvtPemTpl := "-----BEGIN PRIVATE KEY-----datahere.%d-----END PRIVATE KEY-----"
1140+
1141+ contentPrefix := filepath.Join(tmp, "workload-spiffe-contents")
1142+ tmpSymlinkPrefix := filepath.Join(tmp, "workload-spiffe-symlink")
1143+ link := filepath.Join(tmp, "workload-spiffe-credentials")
1144+ out := outputOpts{contentPrefix, tmpSymlinkPrefix, link}
1145+
1146+ // Run refresh creds thrice to test updates.
1147+ // Link (workload-spiffe-credentials) should always refer to the updated content
1148+ // and previous directories should be removed.
1149+ for i := 1; i <= 3; i++ {
1150+ timeNow = func() string { return fmt.Sprintf("%d", i) }
1151+ spiffe := fmt.Sprintf(spiffeTpl, i)
1152+ domain1 := fmt.Sprintf(domain1Tpl, i)
1153+ pem1 := fmt.Sprintf(pem1Tpl, i)
1154+ pem2 := fmt.Sprintf(pem2Tpl, i)
1155+ certPem := fmt.Sprintf(certPemTpl, i)
1156+ pvtPem := fmt.Sprintf(pvtPemTpl, i)
1157+
1158+ mdsClient = &mdsTestClient{
1159+ spiffe: spiffe,
1160+ certPem: certPem,
1161+ pvtPem: pvtPem,
1162+ domain1: domain1,
1163+ pem1: pem1,
1164+ domain2: domain2,
1165+ pem2: pem2,
1166+ }
1167+
1168+ if err := refreshCreds(ctx, out); err != nil {
1169+ t.Errorf("refreshCreds(ctx, %+v) failed unexpectedly with error: %v", out, err)
1170+ }
1171+
1172+ // Verify all files are created with the content as expected.
1173+ tests := []struct {
1174+ path string
1175+ content string
1176+ }{
1177+ {
1178+ path: filepath.Join(link, "ca_certificates.pem"),
1179+ content: pem1,
1180+ },
1181+ {
1182+ path: filepath.Join(link, "certificates.pem"),
1183+ content: certPem,
1184+ },
1185+ {
1186+ path: filepath.Join(link, "private_key.pem"),
1187+ content: pvtPem,
1188+ },
1189+ {
1190+ path: filepath.Join(link, "config_status"),
1191+ content: testConfigStatusResp,
1192+ },
1193+ }
1194+
1195+ for _, test := range tests {
1196+ t.Run(test.path, func(t *testing.T) {
1197+ got, err := os.ReadFile(test.path)
1198+ if err != nil {
1199+ t.Errorf("failed to read expected file %q and content %q with error: %v", test.path, test.content, err)
1200+ }
1201+ if string(got) != test.content {
1202+ t.Errorf("refreshCreds(ctx, %+v) wrote %q, want content %q", out, string(got), test.content)
1203+ }
1204+ })
1205+ }
1206+
1207+ // Verify the symlink was created and references the right destination directory.
1208+ want := fmt.Sprintf("%s-%d", contentPrefix, i)
1209+ got, err := os.Readlink(link)
1210+ if err != nil {
1211+ t.Errorf("os.Readlink(%s) failed unexpectedly with error %v", link, err)
1212+ }
1213+ if got != want {
1214+ t.Errorf("os.Readlink(%s) = %s, want %s", link, got, want)
1215+ }
1216+
1217+ // If its not first run make sure prev creds are deleted.
1218+ if i > 1 {
1219+ prevDir := fmt.Sprintf("%s-%d", contentPrefix, i-1)
1220+ if _, err := os.Stat(prevDir); err == nil {
1221+ t.Errorf("os.Stat(%s) succeeded on prev content directory, want error", prevDir)
1222+ }
1223+ }
1224+ }
1225+}
1226+
1227+func TestRefreshCredsError(t *testing.T) {
1228+ ctx := context.Background()
1229+ tmp := t.TempDir()
1230+
1231+ // Templates to use in iterations.
1232+ spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID"
1233+ domain1 := "12345.global.67890.workload.id.goog"
1234+ pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----"
1235+ domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2_IGNORE"
1236+ pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----"
1237+ certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----"
1238+ pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----"
1239+
1240+ contentPrefix := filepath.Join(tmp, "workload-spiffe-contents")
1241+ tmpSymlinkPrefix := filepath.Join(tmp, "workload-spiffe-symlink")
1242+ link := filepath.Join(tmp, "workload-spiffe-credentials")
1243+ out := outputOpts{contentPrefix, tmpSymlinkPrefix, link}
1244+
1245+ client := &mdsTestClient{
1246+ spiffe: spiffe,
1247+ certPem: certPem,
1248+ pvtPem: pvtPem,
1249+ domain1: domain1,
1250+ pem1: pem1,
1251+ domain2: domain2,
1252+ pem2: pem2,
1253+ }
1254+
1255+ mdsClient = client
1256+
1257+ // Run refresh creds twice. First run would succeed and second would fail. Verify all
1258+ // creds generated on the first run are present as is after failed second run.
1259+ for i := 1; i <= 2; i++ {
1260+ timeNow = func() string { return fmt.Sprintf("%d", i) }
1261+
1262+ if i == 1 {
1263+ // First run should succeed.
1264+ if err := refreshCreds(ctx, out); err != nil {
1265+ t.Errorf("refreshCreds(ctx, %+v) failed unexpectedly with error: %v", out, err)
1266+ }
1267+ } else if i == 2 {
1268+ // Second run should fail. Fail in getting last metadata entry.
1269+ client.throwErrOn = trustAnchorsKey
1270+ if err := refreshCreds(ctx, out); err == nil {
1271+ t.Errorf("refreshCreds(ctx, %+v) succeeded for fake metadata error, should've failed", out)
1272+ }
1273+ }
1274+
1275+ // Verify all files are created and are still present with the content as expected.
1276+ tests := []struct {
1277+ path string
1278+ content string
1279+ }{
1280+ {
1281+ path: filepath.Join(link, "ca_certificates.pem"),
1282+ content: pem1,
1283+ },
1284+ {
1285+ path: filepath.Join(link, "certificates.pem"),
1286+ content: certPem,
1287+ },
1288+ {
1289+ path: filepath.Join(link, "private_key.pem"),
1290+ content: pvtPem,
1291+ },
1292+ {
1293+ path: filepath.Join(link, "config_status"),
1294+ content: testConfigStatusResp,
1295+ },
1296+ }
1297+
1298+ for _, test := range tests {
1299+ t.Run(test.path, func(t *testing.T) {
1300+ got, err := os.ReadFile(test.path)
1301+ if err != nil {
1302+ t.Errorf("failed to read expected file %q and content %q with error: %v", test.path, test.content, err)
1303+ }
1304+ if string(got) != test.content {
1305+ t.Errorf("refreshCreds(ctx, %+v) wrote %q, want content %q", out, string(got), test.content)
1306+ }
1307+ })
1308+ }
1309+
1310+ // Verify the symlink was created and references the same destination directory.
1311+ want := fmt.Sprintf("%s-%d", contentPrefix, 1)
1312+ got, err := os.Readlink(link)
1313+ if err != nil {
1314+ t.Errorf("os.Readlink(%s) failed unexpectedly with error %v", link, err)
1315+ }
1316+ if got != want {
1317+ t.Errorf("os.Readlink(%s) = %s, want %s", link, got, want)
1318+ }
1319+ }
1320+}
1321diff --git a/go.mod b/go.mod
1322index d5f8c28..1e4ea08 100644
1323--- a/go.mod
1324+++ b/go.mod
1325@@ -10,12 +10,14 @@ require (
1326 github.com/go-ini/ini v1.66.6
1327 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
1328 github.com/golang/protobuf v1.5.3
1329+ github.com/google/go-cmp v0.5.9
1330 github.com/google/go-tpm v0.9.0
1331 github.com/google/go-tpm-tools v0.4.0
1332 github.com/google/tink/go v1.7.0
1333 github.com/kardianos/service v1.2.1
1334 github.com/robfig/cron/v3 v3.0.1
1335 github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07
1336+ golang.org/x/crypto v0.11.0
1337 golang.org/x/sys v0.11.0
1338 google.golang.org/grpc v1.57.0
1339 google.golang.org/protobuf v1.31.0
1340@@ -29,7 +31,7 @@ require (
1341 cloud.google.com/go/iam v1.1.1 // indirect
1342 cloud.google.com/go/logging v1.7.0 // indirect
1343 cloud.google.com/go/longrunning v0.5.1 // indirect
1344- github.com/google/go-cmp v0.5.9 // indirect
1345+ github.com/Microsoft/go-winio v0.6.1 // indirect
1346 github.com/google/go-sev-guest v0.7.0 // indirect
1347 github.com/google/logger v1.1.1 // indirect
1348 github.com/google/s2a-go v0.1.4 // indirect
1349@@ -39,11 +41,12 @@ require (
1350 github.com/pborman/uuid v1.2.1 // indirect
1351 github.com/pkg/errors v0.9.1 // indirect
1352 go.opencensus.io v0.24.0 // indirect
1353- golang.org/x/crypto v0.11.0 // indirect
1354+ golang.org/x/mod v0.8.0 // indirect
1355 golang.org/x/net v0.12.0 // indirect
1356 golang.org/x/oauth2 v0.10.0 // indirect
1357 golang.org/x/sync v0.3.0 // indirect
1358 golang.org/x/text v0.11.0 // indirect
1359+ golang.org/x/tools v0.6.0 // indirect
1360 golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
1361 google.golang.org/api v0.134.0 // indirect
1362 google.golang.org/appengine v1.6.7 // indirect
1363diff --git a/go.sum b/go.sum
1364index 55c71c0..9c5c73b 100644
1365--- a/go.sum
1366+++ b/go.sum
1367@@ -17,6 +17,8 @@ cloud.google.com/go/storage v1.31.0/go.mod h1:81ams1PrhW16L4kF7qg+4mTq7SRs5HsbDT
1368 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
1369 github.com/GoogleCloudPlatform/guest-logging-go v0.0.0-20230710215706-450679fd88a9 h1:b3geIwOPAShYtR4F0XFt+2NJXTHVTfbxUFmrpiZXHdQ=
1370 github.com/GoogleCloudPlatform/guest-logging-go v0.0.0-20230710215706-450679fd88a9/go.mod h1:6ZqSUIZRAPR5dNMWJ+FwIarFFQ9t5qalaKQs20o6h+I=
1371+github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
1372+github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
1373 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
1374 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
1375 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
1376@@ -94,8 +96,10 @@ github.com/googleapis/enterprise-certificate-proxy v0.2.5/go.mod h1:RxW0N9901Cko
1377 github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
1378 github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU=
1379 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
1380+github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
1381 github.com/kardianos/service v1.2.1 h1:AYndMsehS+ywIS6RB9KOlcXzteWUzxgMgBymJD7+BYk=
1382 github.com/kardianos/service v1.2.1/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
1383+github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
1384 github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw=
1385 github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
1386 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
1387@@ -134,6 +138,8 @@ golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTk
1388 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
1389 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
1390 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
1391+golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
1392+golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
1393 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
1394 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
1395 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
1396@@ -176,6 +182,7 @@ golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
1397 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
1398 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
1399 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
1400+golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
1401 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
1402 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
1403 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
1404@@ -191,6 +198,8 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3
1405 golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
1406 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
1407 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
1408+golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
1409+golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
1410 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
1411 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
1412 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
1413diff --git a/google_authorized_keys/main.go b/google_authorized_keys/main.go
1414index 4961397..634ab5a 100644
1415--- a/google_authorized_keys/main.go
1416+++ b/google_authorized_keys/main.go
1417@@ -20,24 +20,27 @@ import (
1418 "encoding/json"
1419 "fmt"
1420 "io"
1421- "net/http"
1422 "os"
1423+ "path"
1424 "runtime"
1425 "strconv"
1426 "strings"
1427 "time"
1428
1429+ "github.com/GoogleCloudPlatform/guest-agent/metadata"
1430 "github.com/GoogleCloudPlatform/guest-agent/utils"
1431 "github.com/GoogleCloudPlatform/guest-logging-go/logger"
1432 )
1433
1434 var (
1435- programName = "GoogleAuthorizedKeysCommand"
1436- metadataURL = "http://169.254.169.254/computeMetadata/v1/"
1437- metadataRecursive = "/?recursive=true&alt=json"
1438- defaultTimeout = 2 * time.Second
1439+ client metadata.MDSClientInterface
1440+ programName = path.Base(os.Args[0])
1441 )
1442
1443+func init() {
1444+ client = metadata.New()
1445+}
1446+
1447 func logFormat(e logger.LogEntry) string {
1448 now := time.Now().Format("2006/01/02 15:04:05")
1449 return fmt.Sprintf("%s %s: %s", now, programName, e.Message)
1450@@ -49,45 +52,6 @@ func logFormatWindows(e logger.LogEntry) string {
1451 return fmt.Sprintf("%s %s: %s", now, programName, e.Message)
1452 }
1453
1454-func getMetadata(key string, recurse bool) ([]byte, error) {
1455- client := &http.Client{
1456- Timeout: defaultTimeout,
1457- }
1458-
1459- url := metadataURL + key
1460- if recurse {
1461- url += metadataRecursive
1462- }
1463- req, err := http.NewRequest("GET", url, nil)
1464- if err != nil {
1465- return nil, err
1466- }
1467- req.Header.Add("Metadata-Flavor", "Google")
1468-
1469- var res *http.Response
1470- // Retry forever, increase sleep between retries (up to 5 times) in order
1471- // to wait for slow network initialization.
1472- var rt time.Duration
1473- for i := 1; ; i++ {
1474- res, err = client.Do(req)
1475- if err == nil {
1476- break
1477- }
1478- if i < 6 {
1479- rt = time.Duration(3*i) * time.Second
1480- }
1481- logger.Errorf("error connecting to metadata server, retrying in %s, error: %v", rt, err)
1482- time.Sleep(rt)
1483- }
1484- defer res.Body.Close()
1485-
1486- md, err := io.ReadAll(res.Body)
1487- if err != nil {
1488- return nil, err
1489- }
1490- return md, nil
1491-}
1492-
1493 func parseSSHKeys(username string, keys []string) []string {
1494 var keyList []string
1495 for _, key := range keys {
1496@@ -143,7 +107,7 @@ type attributes struct {
1497 SSHKeys []string
1498 }
1499
1500-func getMetadataAttributes(metadataKey string) (*attributes, error) {
1501+func getMetadataAttributes(ctx context.Context, metadataKey string) (*attributes, error) {
1502 var a attributes
1503 type jsonAttributes struct {
1504 EnableWindowsSSH string `json:"enable-windows-ssh"`
1505@@ -151,11 +115,12 @@ func getMetadataAttributes(metadataKey string) (*attributes, error) {
1506 SSHKeys string `json:"ssh-keys"`
1507 }
1508 var ja jsonAttributes
1509- metadata, err := getMetadata(metadataKey, true)
1510+ metadata, err := client.GetKeyRecursive(ctx, metadataKey)
1511 if err != nil {
1512 return nil, err
1513 }
1514- if err := json.Unmarshal(metadata, &ja); err != nil {
1515+
1516+ if err := json.Unmarshal([]byte(metadata), &ja); err != nil {
1517 return nil, err
1518 }
1519
1520@@ -191,12 +156,12 @@ func main() {
1521 }
1522 logger.Init(ctx, opts)
1523
1524- instanceAttributes, err := getMetadataAttributes("instance/attributes/")
1525+ instanceAttributes, err := getMetadataAttributes(ctx, "instance/attributes/")
1526 if err != nil {
1527 logger.Errorf("Cannot read instance metadata attributes: %v", err)
1528 os.Exit(1)
1529 }
1530- projectAttributes, err := getMetadataAttributes("project/attributes/")
1531+ projectAttributes, err := getMetadataAttributes(ctx, "project/attributes/")
1532 if err != nil {
1533 logger.Errorf("Cannot read project metadata attributes: %v", err)
1534 os.Exit(1)
1535diff --git a/google_authorized_keys/main_test.go b/google_authorized_keys/main_test.go
1536index 49315f1..ea739a1 100644
1537--- a/google_authorized_keys/main_test.go
1538+++ b/google_authorized_keys/main_test.go
1539@@ -15,14 +15,15 @@
1540 package main
1541
1542 import (
1543+ "context"
1544 "fmt"
1545- "net/http"
1546- "net/http/httptest"
1547 "reflect"
1548 "strconv"
1549 "strings"
1550 "testing"
1551- "time"
1552+
1553+ "github.com/GoogleCloudPlatform/guest-agent/metadata"
1554+ "github.com/GoogleCloudPlatform/guest-agent/utils"
1555 )
1556
1557 func stringSliceEqual(a, b []string) bool {
1558@@ -50,16 +51,20 @@ var truebool *bool = &t
1559 var falsebool *bool = &f
1560
1561 func TestParseSSHKeys(t *testing.T) {
1562+ pubKeyA := utils.MakeRandRSAPubKey(t)
1563+ pubKeyB := utils.MakeRandRSAPubKey(t)
1564+ pubKey := utils.MakeRandRSAPubKey(t)
1565+
1566 keys := []string{
1567 "# Here is some random data in the file.",
1568- "usera:ssh-rsa AAAA1234USERA",
1569- "userb:ssh-rsa AAAA1234USERB",
1570- `usera:ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`,
1571- `usera:ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2020-04-23T12:34:56+0000"}`,
1572+ fmt.Sprintf("usera:ssh-rsa %s", pubKeyA),
1573+ fmt.Sprintf("userb:ssh-rsa %s", pubKeyB),
1574+ fmt.Sprintf(`usera:ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, pubKey),
1575+ fmt.Sprintf(`usera:ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2020-04-23T12:34:56+0000"}`, pubKey),
1576 }
1577 expected := []string{
1578- "ssh-rsa AAAA1234USERA",
1579- `ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`,
1580+ fmt.Sprintf("ssh-rsa %s", pubKeyA),
1581+ fmt.Sprintf(`ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, pubKey),
1582 }
1583
1584 user := "usera"
1585@@ -117,6 +122,8 @@ func TestCheckWinSSHEnabled(t *testing.T) {
1586 }
1587
1588 func TestGetUserKeysNew(t *testing.T) {
1589+ pubKey := utils.MakeRandRSAPubKey(t)
1590+
1591 tests := []struct {
1592 userName string
1593 instanceMetadata attributes
1594@@ -125,102 +132,114 @@ func TestGetUserKeysNew(t *testing.T) {
1595 }{
1596 {
1597 userName: "name",
1598- instanceMetadata: attributes{BlockProjectSSHKeys: false,
1599- SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"},
1600+ instanceMetadata: attributes{
1601+ BlockProjectSSHKeys: false,
1602+ SSHKeys: []string{
1603+ fmt.Sprintf("name:ssh-rsa %s instance1", pubKey),
1604+ fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey),
1605+ },
1606 },
1607 projectMetadata: attributes{
1608- SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"},
1609+ SSHKeys: []string{
1610+ fmt.Sprintf("name:ssh-rsa %s project1", pubKey),
1611+ fmt.Sprintf("othername:ssh-rsa %s project2", pubKey),
1612+ },
1613+ },
1614+ expectedKeys: []string{
1615+ fmt.Sprintf("ssh-rsa %s instance1", pubKey),
1616+ fmt.Sprintf("ssh-rsa %s project1", pubKey),
1617 },
1618- expectedKeys: []string{"ssh-rsa [KEY] instance1", "ssh-rsa [KEY] project1"},
1619 },
1620 {
1621 userName: "name",
1622- instanceMetadata: attributes{BlockProjectSSHKeys: true,
1623- SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"},
1624+ instanceMetadata: attributes{
1625+ BlockProjectSSHKeys: true,
1626+ SSHKeys: []string{
1627+ fmt.Sprintf("name:ssh-rsa %s instance1", pubKey),
1628+ fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey),
1629+ },
1630 },
1631 projectMetadata: attributes{
1632- SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"},
1633+ SSHKeys: []string{
1634+ fmt.Sprintf("name:ssh-rsa %s project1", pubKey),
1635+ fmt.Sprintf("othername:ssh-rsa %s project2", pubKey),
1636+ },
1637 },
1638- expectedKeys: []string{"ssh-rsa [KEY] instance1"},
1639+ expectedKeys: []string{fmt.Sprintf("ssh-rsa %s instance1", pubKey)},
1640 },
1641 {
1642 userName: "name",
1643- instanceMetadata: attributes{BlockProjectSSHKeys: false,
1644- SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"},
1645+ instanceMetadata: attributes{
1646+ BlockProjectSSHKeys: false,
1647+ SSHKeys: []string{
1648+ fmt.Sprintf("name:ssh-rsa %s instance1", pubKey),
1649+ fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey),
1650+ },
1651 },
1652 projectMetadata: attributes{
1653 SSHKeys: nil,
1654 },
1655- expectedKeys: []string{"ssh-rsa [KEY] instance1"},
1656+ expectedKeys: []string{fmt.Sprintf("ssh-rsa %s instance1", pubKey)},
1657 },
1658 {
1659 userName: "name",
1660- instanceMetadata: attributes{BlockProjectSSHKeys: false,
1661- SSHKeys: nil,
1662+ instanceMetadata: attributes{
1663+ BlockProjectSSHKeys: false,
1664+ SSHKeys: nil,
1665 },
1666 projectMetadata: attributes{
1667- SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"},
1668+ SSHKeys: []string{
1669+ fmt.Sprintf("name:ssh-rsa %s project1", pubKey),
1670+ fmt.Sprintf("othername:ssh-rsa %s project2", pubKey),
1671+ },
1672 },
1673- expectedKeys: []string{"ssh-rsa [KEY] project1"},
1674+ expectedKeys: []string{fmt.Sprintf("ssh-rsa %s project1", pubKey)},
1675 },
1676 }
1677
1678 for count, tt := range tests {
1679- if got, want := getUserKeys(tt.userName, &tt.instanceMetadata, &tt.projectMetadata), tt.expectedKeys; !stringSliceEqual(got, want) {
1680- t.Errorf("getUserKeys[%d] incorrect return: got %v, want %v", count, got, want)
1681- }
1682+ t.Run(fmt.Sprintf("test-%d", count), func(t *testing.T) {
1683+ if got, want := getUserKeys(tt.userName, &tt.instanceMetadata, &tt.projectMetadata), tt.expectedKeys; !stringSliceEqual(got, want) {
1684+ t.Errorf("getUserKeys[%d] incorrect return: got %v, want %v", count, got, want)
1685+ }
1686+ })
1687 }
1688 }
1689
1690 func TestGetMetadataAttributes(t *testing.T) {
1691 tests := []struct {
1692- metadata string
1693 att *attributes
1694 expectErr bool
1695 }{
1696 {
1697- metadata: `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`,
1698 att: &attributes{EnableWindowsSSH: truebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false},
1699 expectErr: false,
1700 },
1701 {
1702- metadata: `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"true","other-metadata":"foo"}`,
1703 att: &attributes{EnableWindowsSSH: truebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: true},
1704 expectErr: false,
1705 },
1706 {
1707- metadata: `{"ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`,
1708 att: &attributes{EnableWindowsSSH: nil, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false},
1709 expectErr: false,
1710 },
1711 {
1712- metadata: `{"enable-windows-ssh":"false","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","other-metadata":"foo"}`,
1713 att: &attributes{EnableWindowsSSH: falsebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false},
1714 expectErr: false,
1715 },
1716 {
1717- metadata: `BADJSON`,
1718 att: nil,
1719 expectErr: true,
1720 },
1721 }
1722
1723- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1724- // Get test number from request path
1725- tnum, _ := strconv.Atoi(strings.Split(r.URL.Path, "/")[2])
1726- fmt.Fprintf(w, tests[tnum].metadata)
1727- }))
1728-
1729- defer ts.Close()
1730-
1731- metadataURL = ts.URL
1732- defaultTimeout = 1 * time.Second
1733+ client = &mdsClient{}
1734
1735 for count, tt := range tests {
1736 want := tt.att
1737 hasErr := false
1738 reqStr := fmt.Sprintf("/attributes/%d", count)
1739- got, err := getMetadataAttributes(reqStr)
1740+ got, err := getMetadataAttributes(context.Background(), reqStr)
1741 if err != nil {
1742 hasErr = true
1743 }
1744@@ -230,3 +249,43 @@ func TestGetMetadataAttributes(t *testing.T) {
1745 }
1746 }
1747 }
1748+
1749+type mdsClient struct{}
1750+
1751+func (mds *mdsClient) Get(ctx context.Context) (*metadata.Descriptor, error) {
1752+ return nil, fmt.Errorf("Get() not yet implemented")
1753+}
1754+
1755+func (mds *mdsClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) {
1756+ return "", fmt.Errorf("GetKey() not yet implemented")
1757+}
1758+
1759+func (mds *mdsClient) GetKeyRecursive(ctx context.Context, key string) (string, error) {
1760+ i, err := strconv.Atoi(key[strings.LastIndex(key, "/")+1:])
1761+ if err != nil {
1762+ return "", err
1763+ }
1764+
1765+ switch i {
1766+ case 0:
1767+ return `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, nil
1768+ case 1:
1769+ return `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"true","other-metadata":"foo"}`, nil
1770+ case 2:
1771+ return `{"ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, nil
1772+ case 3:
1773+ return `{"enable-windows-ssh":"false","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","other-metadata":"foo"}`, nil
1774+ case 4:
1775+ return "BADJSON", nil
1776+ default:
1777+ return "", fmt.Errorf("unknown key %q", key)
1778+ }
1779+}
1780+
1781+func (mds *mdsClient) Watch(ctx context.Context) (*metadata.Descriptor, error) {
1782+ return nil, fmt.Errorf("Watch() not yet implemented")
1783+}
1784+
1785+func (mds *mdsClient) WriteGuestAttributes(ctx context.Context, key string, value string) error {
1786+ return fmt.Errorf("WriteGuestattributes() not yet implemented")
1787+}
1788diff --git a/google_guest_agent/addresses.go b/google_guest_agent/addresses.go
1789index 7c0c741..6f13797 100644
1790--- a/google_guest_agent/addresses.go
1791+++ b/google_guest_agent/addresses.go
1792@@ -19,25 +19,21 @@ import (
1793 "errors"
1794 "fmt"
1795 "net"
1796- "os"
1797 "reflect"
1798 "runtime"
1799+ "slices"
1800 "strings"
1801- "time"
1802
1803 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
1804+ network "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/network/manager"
1805 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run"
1806- "github.com/GoogleCloudPlatform/guest-agent/metadata"
1807- "github.com/GoogleCloudPlatform/guest-agent/utils"
1808 "github.com/GoogleCloudPlatform/guest-logging-go/logger"
1809 )
1810
1811 var (
1812- addressKey = regKeyBase + `\ForwardedIps`
1813- oldWSFCAddresses string
1814- oldWSFCEnable bool
1815- interfacesEnabled bool
1816- interfaces []net.Interface
1817+ addressKey = regKeyBase + `\ForwardedIps`
1818+ oldWSFCAddresses string
1819+ oldWSFCEnable bool
1820 )
1821
1822 type addressMgr struct{}
1823@@ -78,7 +74,9 @@ func getForwardsFromRegistry(mac string) ([]string, error) {
1824 oldName := strings.Replace(mac, ":", "", -1)
1825 regFwdIPs, err = readRegMultiString(addressKey, oldName)
1826 if err == nil {
1827- deleteRegKey(addressKey, oldName)
1828+ if err = deleteRegKey(addressKey, oldName); err != nil {
1829+ logger.Warningf("Failed to delete key: %q, name: %q from registry", addressKey, oldName)
1830+ }
1831 }
1832 } else if err != nil {
1833 return nil, err
1834@@ -88,13 +86,13 @@ func getForwardsFromRegistry(mac string) ([]string, error) {
1835
1836 func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []string) {
1837 for _, desiredRoute := range desiredRoutes {
1838- if !utils.ContainsString(desiredRoute, configuredRoutes) {
1839+ if !slices.Contains(configuredRoutes, desiredRoute) {
1840 toAdd = append(toAdd, desiredRoute)
1841 }
1842 }
1843
1844 for _, configuredRoute := range configuredRoutes {
1845- if !utils.ContainsString(configuredRoute, desiredRoutes) {
1846+ if !slices.Contains(desiredRoutes, configuredRoute) {
1847 toRm = append(toRm, configuredRoute)
1848 }
1849 }
1850@@ -103,20 +101,6 @@ func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []stri
1851
1852 var badMAC []string
1853
1854-func getInterfaceByMAC(mac string) (net.Interface, error) {
1855- hwaddr, err := net.ParseMAC(mac)
1856- if err != nil {
1857- return net.Interface{}, err
1858- }
1859-
1860- for _, iface := range interfaces {
1861- if iface.HardwareAddr.String() == hwaddr.String() {
1862- return iface, nil
1863- }
1864- }
1865- return net.Interface{}, fmt.Errorf("no interface found with MAC %s", mac)
1866-}
1867-
1868 // https://www.ietf.org/rfc/rfc1354.txt
1869 // Only fields that we currently care about.
1870 type ipForwardEntry struct {
1871@@ -221,7 +205,7 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) {
1872 for idx := range interfaces {
1873 var filteredForwardedIps []string
1874 for _, ip := range interfaces[idx].ForwardedIps {
1875- if !utils.ContainsString(ip, wsfcAddrs) {
1876+ if !slices.Contains(wsfcAddrs, ip) {
1877 filteredForwardedIps = append(filteredForwardedIps, ip)
1878 }
1879 }
1880@@ -229,7 +213,7 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) {
1881
1882 var filteredTargetInstanceIps []string
1883 for _, ip := range interfaces[idx].TargetInstanceIps {
1884- if !utils.ContainsString(ip, wsfcAddrs) {
1885+ if !slices.Contains(wsfcAddrs, ip) {
1886 filteredTargetInstanceIps = append(filteredTargetInstanceIps, ip)
1887 }
1888 }
1889@@ -290,28 +274,10 @@ func (a *addressMgr) Set(ctx context.Context) error {
1890 a.applyWSFCFilter(config)
1891 }
1892
1893- var err error
1894- interfaces, err = net.Interfaces()
1895+ // Setup network interfaces.
1896+ err := network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces)
1897 if err != nil {
1898- return fmt.Errorf("error populating interfaces: %v", err)
1899- }
1900-
1901- if config.NetworkInterfaces.Setup {
1902- if runtime.GOOS != "windows" {
1903- logger.Debugf("Configure IPv6")
1904- if err := configureIPv6(ctx); err != nil {
1905- // Continue through IPv6 configuration errors.
1906- logger.Errorf("Error configuring IPv6: %v", err)
1907- }
1908- }
1909-
1910- if runtime.GOOS != "windows" && !interfacesEnabled {
1911- logger.Debugf("Enable network interfaces")
1912- if err := enableNetworkInterfaces(ctx, config); err != nil {
1913- return err
1914- }
1915- interfacesEnabled = true
1916- }
1917+ return fmt.Errorf("failed to setup network interfaces: %v", err)
1918 }
1919
1920 if !config.NetworkInterfaces.IPForwarding {
1921@@ -321,9 +287,9 @@ func (a *addressMgr) Set(ctx context.Context) error {
1922 logger.Debugf("Add routes for aliases, forwarded IP and target-instance IPs")
1923 // Add routes for IP aliases, forwarded and target-instance IPs.
1924 for _, ni := range newMetadata.Instance.NetworkInterfaces {
1925- iface, err := getInterfaceByMAC(ni.Mac)
1926+ iface, err := network.GetInterfaceByMAC(ni.Mac)
1927 if err != nil {
1928- if !utils.ContainsString(ni.Mac, badMAC) {
1929+ if !slices.Contains(badMAC, ni.Mac) {
1930 logger.Errorf("Error getting interface: %s", err)
1931 badMAC = append(badMAC, ni.Mac)
1932 }
1933@@ -356,7 +322,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
1934 }
1935 for _, ip := range configuredIPs {
1936 // Only add to `forwardedIPs` if it is recorded in the registry.
1937- if utils.ContainsString(ip, regFwdIPs) {
1938+ if slices.Contains(regFwdIPs, ip) {
1939 forwardedIPs = append(forwardedIPs, ip)
1940 }
1941 }
1942@@ -399,14 +365,14 @@ func (a *addressMgr) Set(ctx context.Context) error {
1943 var registryEntries []string
1944 for _, ip := range wantIPs {
1945 // If the IP is not in toAdd, add to registry list and continue.
1946- if !utils.ContainsString(ip, toAdd) {
1947+ if !slices.Contains(toAdd, ip) {
1948 registryEntries = append(registryEntries, ip)
1949 continue
1950 }
1951 var err error
1952 if runtime.GOOS == "windows" {
1953 // Don't addAddress if this is already configured.
1954- if !utils.ContainsString(ip, configuredIPs) {
1955+ if !slices.Contains(configuredIPs, ip) {
1956 err = addAddress(net.ParseIP(ip), net.IPv4Mask(255, 255, 255, 255), uint32(iface.Index))
1957 }
1958 } else {
1959@@ -422,7 +388,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
1960 for _, ip := range toRm {
1961 var err error
1962 if runtime.GOOS == "windows" {
1963- if !utils.ContainsString(ip, configuredIPs) {
1964+ if !slices.Contains(configuredIPs, ip) {
1965 continue
1966 }
1967 err = removeAddress(net.ParseIP(ip), uint32(iface.Index))
1968@@ -445,193 +411,3 @@ func (a *addressMgr) Set(ctx context.Context) error {
1969
1970 return nil
1971 }
1972-
1973-// Enables or disables IPv6 on network interfaces.
1974-func configureIPv6(ctx context.Context) error {
1975- var newNi, oldNi metadata.NetworkInterfaces
1976- if len(newMetadata.Instance.NetworkInterfaces) == 0 {
1977- return fmt.Errorf("no interfaces found in metadata")
1978- }
1979- newNi = newMetadata.Instance.NetworkInterfaces[0]
1980- if len(oldMetadata.Instance.NetworkInterfaces) > 0 {
1981- oldNi = oldMetadata.Instance.NetworkInterfaces[0]
1982- }
1983- iface, err := getInterfaceByMAC(newNi.Mac)
1984- if err != nil {
1985- return err
1986- }
1987- switch {
1988- case oldNi.DHCPv6Refresh != "" && newNi.DHCPv6Refresh == "",
1989- newNi.DHCPv6Refresh == "" && len(oldMetadata.Instance.NetworkInterfaces) == 0:
1990- // disable
1991- // uses empty old interface slice to indicate this is first-run.
1992-
1993- // Before obtaining or releasing an IPv6 lease, we wait for
1994- // 'tentative' IPs as part of SLAAC. We wait up to 5 seconds
1995- // for this condition to automatically resolve.
1996- tentative := []string{"-6", "-o", "a", "s", "dev", iface.Name, "scope", "link", "tentative"}
1997- for i := 0; i < 5; i++ {
1998- res := run.WithOutput(ctx, "ip", tentative...)
1999- if res.ExitCode == 0 && res.StdOut == "" {
2000- break
2001- }
2002- time.Sleep(1 * time.Second)
2003- }
2004- if err := run.Quiet(ctx, "dhclient", "-r", "-6", "-1", "-v", iface.Name); err != nil {
2005- return err
2006- }
2007- case oldNi.DHCPv6Refresh == "" && newNi.DHCPv6Refresh != "":
2008- // enable
2009- tentative := []string{"-6", "-o", "a", "s", "dev", iface.Name, "scope", "link", "tentative"}
2010- for i := 0; i < 5; i++ {
2011- res := run.WithOutput(ctx, "ip", tentative...)
2012- if res.ExitCode == 0 && res.StdOut == "" {
2013- break
2014- }
2015- time.Sleep(1 * time.Second)
2016- }
2017- val := fmt.Sprintf("net.ipv6.conf.%s.accept_ra_rt_info_max_plen=128", iface.Name)
2018- if err := run.Quiet(ctx, "sysctl", val); err != nil {
2019- return err
2020- }
2021- if err := run.Quiet(ctx, "dhclient", "-1", "-6", "-v", iface.Name); err != nil {
2022- return err
2023- }
2024- }
2025- return nil
2026-}
2027-
2028-// enableNetworkInterfaces runs `dhclient eth1 eth2 ... ethN`
2029-// and `dhclient -6 eth1 eth2 ... ethN`.
2030-// On RHEL7, it also calls disableNM for each interface.
2031-// On SLES, it calls enableSLESInterfaces instead of dhclient.
2032-func enableNetworkInterfaces(ctx context.Context, config *cfg.Sections) error {
2033- if len(newMetadata.Instance.NetworkInterfaces) < 2 {
2034- return nil
2035- }
2036- var googleInterfaces []string
2037- // The primary (first) interface is managed by the OS, we only handle
2038- // secondary interfaces in this code.
2039- for _, ni := range newMetadata.Instance.NetworkInterfaces[1:] {
2040- iface, err := getInterfaceByMAC(ni.Mac)
2041- if err != nil {
2042- if !utils.ContainsString(ni.Mac, badMAC) {
2043- logger.Errorf("Error getting interface: %s", err)
2044- badMAC = append(badMAC, ni.Mac)
2045- }
2046- continue
2047- }
2048- googleInterfaces = append(googleInterfaces, iface.Name)
2049- }
2050- var googleIpv6Interfaces []string
2051- for _, ni := range newMetadata.Instance.NetworkInterfaces[1:] {
2052- if ni.DHCPv6Refresh == "" {
2053- // This interface is not IPv6 enabled
2054- continue
2055- }
2056- iface, err := getInterfaceByMAC(ni.Mac)
2057- if err != nil {
2058- if !utils.ContainsString(ni.Mac, badMAC) {
2059- logger.Errorf("Error getting interface: %s", err)
2060- badMAC = append(badMAC, ni.Mac)
2061- }
2062- continue
2063- }
2064- googleIpv6Interfaces = append(googleIpv6Interfaces, iface.Name)
2065- }
2066-
2067- switch {
2068- case osInfo.OS == "sles":
2069- return enableSLESInterfaces(ctx, googleInterfaces)
2070- case (osInfo.OS == "rhel" || osInfo.OS == "centos") && osInfo.Version.Major >= 7:
2071- for _, iface := range googleInterfaces {
2072- err := disableNM(iface)
2073- if err != nil {
2074- return err
2075- }
2076- }
2077- fallthrough
2078- default:
2079- dhcpCommand := config.NetworkInterfaces.DHCPCommand
2080- if dhcpCommand != "" {
2081- tokens := strings.Split(dhcpCommand, " ")
2082- return run.Quiet(ctx, tokens[0], tokens[1:]...)
2083- }
2084-
2085- // Try IPv4 first as it's higher priority.
2086- if err := run.Quiet(ctx, "dhclient", googleInterfaces...); err != nil {
2087- return err
2088- }
2089-
2090- if len(googleIpv6Interfaces) == 0 {
2091- return nil
2092- }
2093- for _, iface := range googleIpv6Interfaces {
2094- // Enable kernel to accept to route advertisements.
2095- val := fmt.Sprintf("net.ipv6.conf.%s.accept_ra_rt_info_max_plen=128", iface)
2096- if err := run.Quiet(ctx, "sysctl", val); err != nil {
2097- return err
2098- }
2099- }
2100-
2101- var dhclientArgs6 []string
2102- dhclientArgs6 = append([]string{"-6"}, googleIpv6Interfaces...)
2103- return run.Quiet(ctx, "dhclient", dhclientArgs6...)
2104- }
2105-}
2106-
2107-// enableSLESInterfaces writes one ifcfg file for each interface, then
2108-// runs `wicked ifup eth1 eth2 ... ethN`
2109-func enableSLESInterfaces(ctx context.Context, interfaces []string) error {
2110- var err error
2111- var priority = 10100
2112- for _, iface := range interfaces {
2113- logger.Debugf("write enabling ifcfg-%s config", iface)
2114-
2115- var ifcfg *os.File
2116- ifcfg, err = os.Create("/etc/sysconfig/network/ifcfg-" + iface)
2117- if err != nil {
2118- return err
2119- }
2120- defer closer(ifcfg)
2121- contents := []string{
2122- googleComment,
2123- "STARTMODE=hotplug",
2124- // NOTE: 'dhcp' is the dhcp4+dhcp6 option.
2125- "BOOTPROTO=dhcp",
2126- fmt.Sprintf("DHCLIENT_ROUTE_PRIORITY=%d", priority),
2127- }
2128- _, err = ifcfg.WriteString(strings.Join(contents, "\n"))
2129- if err != nil {
2130- return err
2131- }
2132- priority += 100
2133- }
2134- args := append([]string{"ifup", "--timeout", "1"}, interfaces...)
2135- return run.Quiet(ctx, "/usr/sbin/wicked", args...)
2136-}
2137-
2138-// disableNM writes an ifcfg file with DHCP and NetworkManager disabled.
2139-func disableNM(iface string) error {
2140- logger.Debugf("write disabling ifcfg-%s config", iface)
2141- filename := "/etc/sysconfig/network-scripts/ifcfg-" + iface
2142- ifcfg, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
2143- if err == nil {
2144- defer closer(ifcfg)
2145- contents := []string{
2146- googleComment,
2147- fmt.Sprintf("DEVICE=%s", iface),
2148- "BOOTPROTO=none",
2149- "DEFROUTE=no",
2150- "IPV6INIT=no",
2151- "NM_CONTROLLED=no",
2152- "NOZEROCONF=yes",
2153- }
2154- _, err = ifcfg.WriteString(strings.Join(contents, "\n"))
2155- return err
2156- }
2157- if os.IsExist(err) {
2158- return nil
2159- }
2160- return err
2161-}
2162diff --git a/google_guest_agent/addresses_integ_test.go b/google_guest_agent/addresses_integ_test.go
2163deleted file mode 100644
2164index 1486986..0000000
2165--- a/google_guest_agent/addresses_integ_test.go
2166+++ /dev/null
2167@@ -1,99 +0,0 @@
2168-// Copyright 2021 Google LLC
2169-
2170-// Licensed under the Apache License, Version 2.0 (the "License");
2171-// you may not use this file except in compliance with the License.
2172-// You may obtain a copy of the License at
2173-
2174-// https://www.apache.org/licenses/LICENSE-2.0
2175-
2176-// Unless required by applicable law or agreed to in writing, software
2177-// distributed under the License is distributed on an "AS IS" BASIS,
2178-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2179-// See the License for the specific language governing permissions and
2180-// limitations under the License.
2181-
2182-//go:build integration
2183-// +build integration
2184-
2185-package main
2186-
2187-import (
2188- "context"
2189- "fmt"
2190- "strings"
2191- "testing"
2192-)
2193-
2194-const testIp = "192.168.0.0"
2195-
2196-func TestAddAndRemoveLocalRoute(t *testing.T) {
2197- metdata, err := getMetadata(context.Context(), false)
2198- if err != nil {
2199- t.Fatalf("failed to get metadata, err %v", err)
2200- }
2201- iface, err := getInterfaceByMAC(metdata.Instance.NetworkInterfaces[0].Mac)
2202- if err != nil {
2203- t.Fatalf("failed to get interface from mac, err %v", err)
2204- }
2205- // test add local route
2206- if err := removeLocalRoute(testIp, iface.Name); err != nil {
2207- t.Fatalf("failed to remove local route, err %v", err)
2208- }
2209- if err := addLocalRoute(testIp, iface.Name); err != nil {
2210- t.Fatalf("add test local route should not failed, err %v", err)
2211- }
2212-
2213- res, err := getLocalRoutes(iface.Name)
2214- if err != nil {
2215- t.Fatalf("get local route should not failed, err %v", err)
2216- }
2217- exist := false
2218- for _, route := range res {
2219- if strings.Contains(route, fmt.Sprintf("local %s/24", testIp)) {
2220- exist = true
2221- }
2222- }
2223- if !exist {
2224- t.Fatalf("route %s is not added", testIp)
2225- }
2226-
2227- // test remove local route
2228- if err := removeLocalRoute(testIp, iface.Name); err != nil {
2229- t.Fatalf("add test local route should not failed")
2230- }
2231- res, err := getLocalRoutes(iface.Name)
2232- if err != nil {
2233- t.Fatalf("ip route list should not failed, err %s", res.err)
2234- }
2235-
2236- for _, route := range res {
2237- if strings.Contains(route, fmt.Sprintf("local %s/24", testIp)) {
2238- t.Fatalf("route %s should be removed but exist", testIp)
2239- }
2240- }
2241-}
2242-
2243-func TestGetLocalRoute(t *testing.T) {
2244- metdata, err := getMetadata(context.Context(), false)
2245- if err != nil {
2246- t.Fatalf("failed to get metadata, err %v", err)
2247- }
2248- iface, err := getInterfaceByMAC(metdata.Instance.NetworkInterfaces[0].Mac)
2249- if err != nil {
2250- t.Fatalf("failed to get interface from mac, err %v", err)
2251- }
2252-
2253- if err := addLocalRoute(testIp, iface.Name); err != nil {
2254- t.Fatalf("add test local route should not failed, err %v", err)
2255- }
2256- routes, err := getLocalRoutes(iface.Name)
2257- if err != nil {
2258- t.Fatalf("get local routes should not failed, err %v", err)
2259- }
2260- if len(routes) != 1 {
2261- t.Fatal("find unexpected local route %s.", routes[0])
2262- }
2263- if routes[0] != testIp {
2264- t.Fatal("find unexpected local route %s.", routes[0])
2265- }
2266-}
2267diff --git a/google_guest_agent/agentcrypto/mtls_mds.go b/google_guest_agent/agentcrypto/mtls_mds.go
2268index ad61634..ca3dfe1 100644
2269--- a/google_guest_agent/agentcrypto/mtls_mds.go
2270+++ b/google_guest_agent/agentcrypto/mtls_mds.go
2271@@ -35,7 +35,7 @@ import (
2272 const (
2273 // UEFI variables are of format {VariableName}-{VendorGUID}
2274 // googleGUID is Google's (vendors/variable owners) GUID used to prevent name collision with other vendors.
2275- googleGUID = "8be4df61-93ca-11d2-aa0d-00e098032b8c"
2276+ googleGUID = "a2858e46-a37f-456a-8c79-0c1fe48b65ff"
2277 // googleRootCACertEFIVarName is predefined string part of the UEFI variable name that holds Root CA cert.
2278 googleRootCACertEFIVarName = "InstanceRootCACertificate"
2279 // clientCertsKey is the metadata server key at which client identity certificate is exposed.
2280diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux.go b/google_guest_agent/agentcrypto/mtls_mds_linux.go
2281index ee64ff1..3a75ff0 100644
2282--- a/google_guest_agent/agentcrypto/mtls_mds_linux.go
2283+++ b/google_guest_agent/agentcrypto/mtls_mds_linux.go
2284@@ -17,6 +17,7 @@ package agentcrypto
2285 import (
2286 "context"
2287 "fmt"
2288+ "os"
2289 "os/exec"
2290 "path/filepath"
2291
2292@@ -35,8 +36,25 @@ const (
2293 clientCredsFileName = "client.key"
2294 )
2295
// certUpdaters maps every known CA trust store updater binary to the local
// directories it picks extra CA certificates up from. Candidate directories
// are listed in priority order. It is a variable so unit tests can replace it.
var certUpdaters = map[string][]string{
	// CentOS, Fedora, RedHat distributions.
	// https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST
	"update-ca-trust": {"/etc/pki/ca-trust/source/anchors"},
	// SUSE, Debian and Ubuntu distributions.
	// https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html
	// https://github.com/openSUSE/ca-certificates
	"update-ca-certificates": {
		"/usr/local/share/ca-certificates",
		"/usr/share/pki/trust/anchors",
	},
}
2308+
2309 // writeRootCACert writes Root CA cert from UEFI variable to output file.
2310 func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFile string) error {
2311+ // The directory should be executable, but the file does not need to be.
2312+ if err := os.MkdirAll(filepath.Dir(outputFile), 0655); err != nil {
2313+ return err
2314+ }
2315 if err := utils.SaferWriteFile(content, outputFile, 0644); err != nil {
2316 return err
2317 }
2318@@ -51,39 +69,42 @@ func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFi
2319
2320 // writeClientCredentials stores client credentials (certificate and private key).
2321 func (j *CredsJob) writeClientCredentials(plaintext []byte, outputFile string) error {
2322+ // The directory should be executable, but the file does not need to be.
2323+ if err := os.MkdirAll(filepath.Dir(outputFile), 0655); err != nil {
2324+ return err
2325+ }
2326 return utils.SaferWriteFile(plaintext, outputFile, 0644)
2327 }
2328
2329 // getCAStoreUpdater interates over known system trust store updaters and returns the first found.
2330 func getCAStoreUpdater() (string, error) {
2331- knownUpdaters := []string{"update-ca-certificates", "update-ca-trust"}
2332 var errs []string
2333
2334- for _, u := range knownUpdaters {
2335+ for u := range certUpdaters {
2336 _, err := exec.LookPath(u)
2337 if err == nil {
2338 return u, nil
2339 }
2340- errs = append(errs, err.Error())
2341+ errs = append(errs, fmt.Sprintf("lookup for %q failed with error: %v", u, err))
2342 }
2343
2344- return "", fmt.Errorf("no known trust updaters %v were found: %v", knownUpdaters, errs)
2345+ return "", fmt.Errorf("no known trust updaters were found: %v", errs)
2346 }
2347
2348 // certificateDirFromUpdater returns directory of local CA certificates for the given updater tool.
2349 func certificateDirFromUpdater(updater string) (string, error) {
2350- switch updater {
2351- // SUSE, Debian and Ubuntu distributions.
2352- // https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html
2353- case "update-ca-certificates":
2354- return "/usr/local/share/ca-certificates/", nil
2355- // CentOS, Fedora, RedHat distributions.
2356- // https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST/
2357- case "update-ca-trust":
2358- return "/etc/pki/ca-trust/source/anchors/", nil
2359- default:
2360+ dirs, ok := certUpdaters[updater]
2361+ if !ok {
2362 return "", fmt.Errorf("unknown updater %q, no local trusted CA certificate directory found", updater)
2363 }
2364+
2365+ for _, dir := range dirs {
2366+ fi, err := os.Stat(dir)
2367+ if err == nil && fi.IsDir() {
2368+ return dir, nil
2369+ }
2370+ }
2371+ return "", fmt.Errorf("no of the known directories %v found for updater %q", dirs, updater)
2372 }
2373
2374 // updateSystemStore updates the local system store with the cert.
2375diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go
2376index 020af48..831e02f 100644
2377--- a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go
2378+++ b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go
2379@@ -132,17 +132,24 @@ func TestShouldEnableError(t *testing.T) {
2380 }
2381
2382 func TestCertificateDirFromUpdater(t *testing.T) {
2383+ updater1Dir := t.TempDir()
2384+ updater2Dir := t.TempDir()
2385+ certUpdaters = map[string][]string{
2386+ "updater1": {updater1Dir},
2387+ "updater2": {"/does/not/exist", updater2Dir},
2388+ }
2389+
2390 tests := []struct {
2391 updater string
2392 want string
2393 }{
2394 {
2395- updater: "update-ca-certificates",
2396- want: "/usr/local/share/ca-certificates/",
2397+ updater: "updater1",
2398+ want: updater1Dir,
2399 },
2400 {
2401- updater: "update-ca-trust",
2402- want: "/etc/pki/ca-trust/source/anchors/",
2403+ updater: "updater2",
2404+ want: updater2Dir,
2405 },
2406 }
2407
2408@@ -160,8 +167,18 @@ func TestCertificateDirFromUpdater(t *testing.T) {
2409 }
2410
2411 func TestCertificateDirFromUpdaterError(t *testing.T) {
2412+ // Fail for unknown updater.
2413 _, err := certificateDirFromUpdater("unknown")
2414 if err == nil {
2415 t.Errorf("certificateDirFromUpdater(unknown) succeeded for unknown updater, want error")
2416 }
2417+
2418+ // Fail for missing known cert dir.
2419+ certUpdaters = map[string][]string{
2420+ "updater1": {"/no/dir/exist"},
2421+ }
2422+ _, err = certificateDirFromUpdater("updater1")
2423+ if err == nil {
2424+ t.Errorf("certificateDirFromUpdater(unknown) succeeded for missing cert dir, want error")
2425+ }
2426 }
2427diff --git a/google_guest_agent/agentcrypto/mtls_mds_windows.go b/google_guest_agent/agentcrypto/mtls_mds_windows.go
2428index a11f113..3f70954 100644
2429--- a/google_guest_agent/agentcrypto/mtls_mds_windows.go
2430+++ b/google_guest_agent/agentcrypto/mtls_mds_windows.go
2431@@ -59,6 +59,12 @@ var (
2432
2433 // writeRootCACert writes Root CA cert from UEFI variable to output file.
2434 func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile string) error {
2435+ // Try to fetch previous certificate's serial number before it gets overwritten.
2436+ num, err := serialNumber(outputFile)
2437+ if err != nil {
2438+ logger.Debugf("No previous MDS root certificate was found, will skip cleanup: %v", err)
2439+ }
2440+
2441 if err := utils.SaferWriteFile(cacert, outputFile, 0644); err != nil {
2442 return err
2443 }
2444@@ -84,6 +90,25 @@ func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile
2445 return fmt.Errorf("failed to store root cert ctx in store: %w", err)
2446 }
2447
2448+ // MDS root cert was not refreshed or there's no previous cert, nothing to do, return.
2449+ if num == "" || fmt.Sprintf("%x", x509Cert.SerialNumber) == num {
2450+ return nil
2451+ }
2452+
2453+ // Certificate is refreshed. Best effort to find the certcontext and delete it.
2454+ // Don't throw error here, it would skip client credential generation which
2455+ // may be about to expire.
2456+ oldCtx, err := findCert(root, certificateIssuer, num)
2457+ if err != nil {
2458+ logger.Warningf("Failed to find previous MDS root certificate with error: %v", err)
2459+ return nil
2460+ }
2461+
2462+ if err := deleteCert(oldCtx, root); err != nil {
2463+ logger.Warningf("Failed to delete previous MDS root certificate(%s) with error: %v", num, err)
2464+ return nil
2465+ }
2466+
2467 return nil
2468 }
2469
2470diff --git a/google_guest_agent/cfg/cfg.go b/google_guest_agent/cfg/cfg.go
2471index 3a54c47..c59ec14 100644
2472--- a/google_guest_agent/cfg/cfg.go
2473+++ b/google_guest_agent/cfg/cfg.go
2474@@ -27,6 +27,10 @@ var (
2475 // should always return it.
2476 instance *Sections
2477
2478+ // configFile is a pointer to a function which takes the current OS name and returns
2479+ // an appropriate config file name. Replaceable by unit tests.
2480+ configFile = defaultConfigFile
2481+
2482 // dataSource is a pointer to a data source loading/defining function, unit tests will
2483 // want to change this pointer to whatever makes sense to its implementation.
2484 dataSources = defaultDataSources
2485@@ -87,6 +91,9 @@ setup = true
2486 [OSLogin]
2487 cert_authentication = true
2488
2489+[MDS]
2490+mtls_bootstrapping_enabled = true
2491+
2492 [Snapshots]
2493 enabled = false
2494 snapshot_service_ip = 169.254.169.254
2495@@ -94,7 +101,10 @@ snapshot_service_port = 8081
2496 timeout_in_seconds = 60
2497
2498 [Unstable]
2499-mds_mtls = false
2500+command_monitor_enabled = false
2501+command_pipe_mode = 0770
2502+command_pipe_group =
2503+command_request_timeout = 10s
2504 `
2505 )
2506
2507@@ -144,6 +154,9 @@ type Sections struct {
2508 // OSLogin defines the OS Login configuration options.
2509 OSLogin *OSLogin `ini:"OSLogin,omitempty"`
2510
2511+ // MDS defines the MDS configuration options.
2512+ MDS *MDS `ini:"MDS,omitempty"`
2513+
2514 // Snpashots defines the snapshot listener configuration and behavior i.e. the server address and port.
2515 Snapshots *Snapshots `ini:"Snapshots,omitempty"`
2516
2517@@ -237,6 +250,12 @@ type OSLogin struct {
2518 CertAuthentication bool `ini:"cert_authentication,omitempty"`
2519 }
2520
// MDS holds the options read from the [MDS] configuration section.
type MDS struct {
	// MTLSBootstrappingEnabled turns the mTLS credential refresher on or off.
	MTLSBootstrappingEnabled bool `ini:"mtls_bootstrapping_enabled,omitempty"`
}
2526+
2527 // NetworkInterfaces contains the configurations of NetworkInterfaces section.
2528 type NetworkInterfaces struct {
2529 DHCPCommand string `ini:"dhcp_command,omitempty"`
2530@@ -256,7 +275,11 @@ type Snapshots struct {
2531 // is guaranteed for configurations defined in the Unstable section. By default all flags defined
2532 // in this section is disabled and is intended to isolate under development features.
type Unstable struct {
	// CommandMonitorEnabled toggles the command monitor; the default
	// configuration sets it to false.
	CommandMonitorEnabled bool `ini:"command_monitor_enabled,omitempty"`
	// CommandPipePath overrides the command pipe location; empty means the
	// platform default.
	CommandPipePath string `ini:"command_pipe_path,omitempty"`
	// CommandRequestTimeout is a duration string such as "10s".
	CommandRequestTimeout string `ini:"command_request_timeout,omitempty"`
	// CommandPipeMode is the pipe's permission mode written in octal, e.g. "0770".
	CommandPipeMode string `ini:"command_pipe_mode,omitempty"`
	// CommandPipeGroup is the group granted access to the pipe.
	CommandPipeGroup string `ini:"command_pipe_group,omitempty"`
}
2541
2542 // WSFC contains the configurations of WSFC section.
2543@@ -274,18 +297,17 @@ func defaultConfigFile(osName string) string {
2544 }
2545
2546 func defaultDataSources(extraDefaults []byte) []interface{} {
2547- var res []interface{}
2548- configFile := defaultConfigFile(runtime.GOOS)
2549+ var res = []interface{}{[]byte(defaultConfig)}
2550+ config := configFile(runtime.GOOS)
2551
2552 if len(extraDefaults) > 0 {
2553 res = append(res, extraDefaults)
2554 }
2555
2556 return append(res, []interface{}{
2557- []byte(defaultConfig),
2558- configFile,
2559- configFile + ".distro",
2560- configFile + ".template",
2561+ config,
2562+ config + ".distro",
2563+ config + ".template",
2564 }...)
2565 }
2566
2567diff --git a/google_guest_agent/cfg/cfg_test.go b/google_guest_agent/cfg/cfg_test.go
2568index efea06f..20321c1 100644
2569--- a/google_guest_agent/cfg/cfg_test.go
2570+++ b/google_guest_agent/cfg/cfg_test.go
2571@@ -14,7 +14,9 @@
2572
2573 package cfg
2574
2575-import "testing"
2576+import (
2577+ "testing"
2578+)
2579
2580 func TestLoad(t *testing.T) {
2581 if err := Load(nil); err != nil {
2582diff --git a/google_guest_agent/command/Readme.md b/google_guest_agent/command/Readme.md
2583new file mode 100644
2584index 0000000..4fabd96
2585--- /dev/null
2586+++ b/google_guest_agent/command/Readme.md
2587@@ -0,0 +1,24 @@
2588+# Guest Agent Command Monitor
2589+## Overview
2590+The Guest Agent command monitor is a system used for executing commands in the guest agent on behalf of components in the guest OS.
2591+
2592+The command layer is formed of a **Monitor**, a **Server** and a **Handler**, where the **Monitor** handles command registration for guest agent components, the **Server** is the component which listens for commands from the guest OS, and the **Handler** is the function executed by the agent.
2593+
2594+Each **Handler** is identified by a string ID, provided when sending commands to the server. Requests to and responses from the server are JSON objects. A request must contain the `Command` field, specifying the handler to be executed. A request may contain arbitrary other fields to be passed to the handler. An example request is below:
2595+
2596+```
2597+{"Command":"agent.ExampleCommand","ArbitraryArgument":123}
2598+```
2599+
2600+A response is valid JSON with two required fields: Status and StatusMessage. Status is an int which follows Unix exit status conventions (i.e. zero is success; non-zero codes are arbitrary and their meaning is defined by the handler called), and StatusMessage is an explanatory string accompanying the Status. Status codes 101–106 are reserved for errors reported by the command server itself (for example, 101 when no handler is registered for the requested command). Two example responses are below.
2601+
2602+```
2603+{"Status":0,"StatusMessage":""}
2604+
2605+{"Status":7,"StatusMessage":"Failure message"}
2606+```
2607+
2608+By default, the Server listens on a Unix socket or a named pipe, depending on the platform. Permissions for the pipe and the pipe path can be set in the guest-agent [configuration](https://github.com/GoogleCloudPlatform/guest-agent#configuration). The default pipe path is `\\.\pipe\google-guest-agent-commands` on Windows and `/run/google-guest-agent/commands.sock` on Linux.
2609+
2610+## Implementing a command handler
2611+Registering a command handler will expose the handler function to be called by anyone with write permission to the underlying socket. To do so, call `command.Get().RegisterHandler(name, handlerFunc)` to get the current command monitor and register the handlerFunc with it. Note that if the command system is disabled by user configuration, handler registration will succeed but the server will not be available for callers to send commands to.
2612diff --git a/google_guest_agent/command/command.go b/google_guest_agent/command/command.go
2613new file mode 100644
2614index 0000000..527d20f
2615--- /dev/null
2616+++ b/google_guest_agent/command/command.go
2617@@ -0,0 +1,146 @@
2618+// Copyright 2023 Google Inc. All Rights Reserved.
2619+//
2620+// Licensed under the Apache License, Version 2.0 (the "License");
2621+// you may not use this file except in compliance with the License.
2622+// You may obtain a copy of the License at
2623+//
2624+// http://www.apache.org/licenses/LICENSE-2.0
2625+//
2626+// Unless required by applicable law or agreed to in writing, software
2627+// distributed under the License is distributed on an "AS IS" BASIS,
2628+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2629+// See the License for the specific language governing permissions and
2630+// limitations under the License.
2631+
2632+// Package command facilitates calling commands within the guest-agent.
2633+package command
2634+
2635+import (
2636+ "context"
2637+ "encoding/json"
2638+ "fmt"
2639+ "io"
2640+
2641+ "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
2642+)
2643+
2644+// Get returns the current command monitor which can be used to register command handlers.
2645+func Get() *Monitor {
2646+ return cmdMonitor
2647+}
2648+
// Handler implements a single command. It receives the raw JSON request (the
// Command field plus any caller-defined fields) and returns the raw JSON
// response, which should carry Status and StatusMessage and may carry extra
// fields. A non-nil error is reported back to the requester in place of the
// returned response.
type Handler func([]byte) ([]byte, error)
2655+
// Request holds the request fields the command server itself reads: Command
// names the handler to route to. Callers may add arbitrary other fields, which
// the handler receives in the raw request bytes.
type Request struct {
	Command string
}
2661+
// Response holds the fields every command response carries. Handlers may add
// arbitrary extra fields.
type Response struct {
	// Status is the request's status code. The handler defines its meaning,
	// but by convention zero means success.
	Status int
	// StatusMessage optionally explains Status to a human reader.
	StatusMessage string
}
2672+
2673+var (
2674+ // CmdNotFoundError is return when there is no handler for the request command
2675+ CmdNotFoundError = Response{
2676+ Status: 101,
2677+ StatusMessage: "Could not find a handler for the requested command",
2678+ }
2679+ // BadRequestError is returned for invalid or unparseable JSON
2680+ BadRequestError = Response{
2681+ Status: 102,
2682+ StatusMessage: "Could not parse valid JSON from request",
2683+ }
2684+ // ConnError is returned for errors from the underlying communication protocol
2685+ ConnError = Response{
2686+ Status: 103,
2687+ StatusMessage: "Connection error",
2688+ }
2689+ // TimeoutError is returned when the timeout period elapses before valid JSON is receieved
2690+ TimeoutError = Response{
2691+ Status: 104,
2692+ StatusMessage: "Connection timeout before reading valid request",
2693+ }
2694+ // HandlerError is returned when the handler function returns an non-nil error. The status message will be replaced with the returnd error string.
2695+ HandlerError = Response{
2696+ Status: 105,
2697+ StatusMessage: "The command handler encountered an error processing your request",
2698+ }
2699+ // InternalErrorCode is the error code for internal command server errors. Returned when failing to marshal a response.
2700+ InternalErrorCode = 106
2701+ internalError = []byte(`{"Status":106,"StatusMessage":"The command server encountered an internal error trying to respond to your request"}`)
2702+)
2703+
2704+// RegisterHandler registers f as the handler for cmd. If a command.Server has
2705+// been initialized, it will be signalled to start listening for commands.
2706+func (m *Monitor) RegisterHandler(cmd string, f Handler) error {
2707+ m.handlersMu.Lock()
2708+ defer m.handlersMu.Unlock()
2709+ if _, ok := m.handlers[cmd]; ok {
2710+ return fmt.Errorf("cmd %s is already handled", cmd)
2711+ }
2712+ m.handlers[cmd] = f
2713+ return nil
2714+}
2715+
2716+// UnregisterHandler clears the handlers for cmd. If a command.Server has been
2717+// intialized and there are no more handlers registered, the server will be
2718+// signalled to stop listening for commands.
2719+func (m *Monitor) UnregisterHandler(cmd string) error {
2720+ m.handlersMu.Lock()
2721+ defer m.handlersMu.Unlock()
2722+ if _, ok := m.handlers[cmd]; !ok {
2723+ return fmt.Errorf("cmd %s is not registered", cmd)
2724+ }
2725+ delete(m.handlers, cmd)
2726+ return nil
2727+}
2728+
2729+// SendCommand sends a command request over the configured pipe.
2730+func SendCommand(ctx context.Context, req []byte) []byte {
2731+ pipe := cfg.Get().Unstable.CommandPipePath
2732+ if pipe == "" {
2733+ pipe = DefaultPipePath
2734+ }
2735+ return SendCmdPipe(ctx, pipe, req)
2736+}
2737+
2738+// SendCmdPipe sends a command request over a specific pipe. Most callers
2739+// should use SendCommand() instead.
2740+func SendCmdPipe(ctx context.Context, pipe string, req []byte) []byte {
2741+ conn, err := dialPipe(ctx, pipe)
2742+ if err != nil {
2743+ if b, err := json.Marshal(ConnError); err != nil {
2744+ return b
2745+ }
2746+ return internalError
2747+ }
2748+ i, err := conn.Write(req)
2749+ if err != nil || i != len(req) {
2750+ if b, err := json.Marshal(ConnError); err != nil {
2751+ return b
2752+ }
2753+ return internalError
2754+ }
2755+ data, err := io.ReadAll(conn)
2756+ if err != nil {
2757+ if b, err := json.Marshal(ConnError); err != nil {
2758+ return b
2759+ }
2760+ return internalError
2761+ }
2762+ return data
2763+}
2764diff --git a/google_guest_agent/command/command_linux.go b/google_guest_agent/command/command_linux.go
2765new file mode 100644
2766index 0000000..cf7dab1
2767--- /dev/null
2768+++ b/google_guest_agent/command/command_linux.go
2769@@ -0,0 +1,140 @@
2770+// Copyright 2023 Google Inc. All Rights Reserved.
2771+//
2772+// Licensed under the Apache License, Version 2.0 (the "License");
2773+// you may not use this file except in compliance with the License.
2774+// You may obtain a copy of the License at
2775+//
2776+// http://www.apache.org/licenses/LICENSE-2.0
2777+//
2778+// Unless required by applicable law or agreed to in writing, software
2779+// distributed under the License is distributed on an "AS IS" BASIS,
2780+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2781+// See the License for the specific language governing permissions and
2782+// limitations under the License.
2783+
2784+package command
2785+
2786+import (
2787+ "context"
2788+ "fmt"
2789+ "net"
2790+ "os"
2791+ "os/user"
2792+ "path"
2793+ "runtime"
2794+ "strconv"
2795+ "syscall"
2796+
2797+ "github.com/GoogleCloudPlatform/guest-logging-go/logger"
2798+)
2799+
// DefaultPipePath is the unix socket the command server uses on Linux when no
// command_pipe_path is configured.
const DefaultPipePath = "/run/google-guest-agent/commands.sock"
2802+
2803+func mkdirpWithPerms(dir string, p os.FileMode, uid, gid int) error {
2804+ stat, err := os.Stat(dir)
2805+ if err == nil {
2806+ statT, ok := stat.Sys().(*syscall.Stat_t)
2807+ if !ok {
2808+ return fmt.Errorf("could not determine owner of %s", dir)
2809+ }
2810+ if !stat.IsDir() {
2811+ return fmt.Errorf("%s exists and is not a directory", dir)
2812+ }
2813+ if morePermissive(int(stat.Mode()), int(p)) {
2814+ if err := os.Chmod(dir, p); err != nil {
2815+ return fmt.Errorf("could not correct %s permissions to %d: %v", dir, p, err)
2816+ }
2817+ }
2818+ if statT.Uid != 0 && statT.Uid != uint32(uid) {
2819+ if err := os.Chown(dir, uid, -1); err != nil {
2820+ return fmt.Errorf("could not correct %s owner to %d: %v", dir, uid, err)
2821+ }
2822+ }
2823+ if statT.Gid != 0 && statT.Gid != uint32(gid) {
2824+ if err := os.Chown(dir, -1, gid); err != nil {
2825+ return fmt.Errorf("could not correct %s group to %d: %v", dir, gid, err)
2826+ }
2827+ }
2828+ } else {
2829+ parent, _ := path.Split(dir)
2830+ if parent != "/" && parent != "" {
2831+ if err := mkdirpWithPerms(parent, p, uid, gid); err != nil {
2832+ return err
2833+ }
2834+ }
2835+ if err := os.Mkdir(dir, p); err != nil {
2836+ return err
2837+ }
2838+ }
2839+ return nil
2840+}
2841+
// morePermissive reports whether mode i grants more than mode j in any of the
// other/group/owner permission digits. Each octal digit is compared
// numerically (e.g. 5 > 4); bits above the low nine permission bits (such as
// os.ModeDir) are ignored.
func morePermissive(i, j int) bool {
	for k := 0; k < 3; k++ {
		// Both operands must be reduced in octal; j % 10 compared decimal
		// digits and misreported equal modes as more permissive.
		if i%010 > j%010 {
			return true
		}
		i /= 010
		j /= 010
	}
	return false
}
2852+
2853+func listen(ctx context.Context, pipe string, filemode int, grp string) (net.Listener, error) {
2854+ // If grp is an int, use it as a GID
2855+ gid, err := strconv.Atoi(grp)
2856+ if err != nil {
2857+ // Otherwise lookup GID
2858+ group, err := user.LookupGroup(grp)
2859+ if err != nil {
2860+ logger.Errorf("guest agent command pipe group %s is not a GID nor a valid group, not changing socket ownership", grp)
2861+ gid = -1
2862+ } else {
2863+ gid, err = strconv.Atoi(group.Gid)
2864+ if err != nil {
2865+ logger.Errorf("os reported group %s has gid %s which is not a valid int, not changing socket ownership. this should never happen", grp, group.Gid)
2866+ gid = -1
2867+ }
2868+ }
2869+ }
2870+ // socket owner group does not need to have permissions to everything in the directory containing it, whatever user and group we are should own that
2871+ user, err := user.Current()
2872+ if err != nil {
2873+ return nil, fmt.Errorf("could not lookup current user")
2874+ }
2875+ currentuid, err := strconv.Atoi(user.Uid)
2876+ if err != nil {
2877+ return nil, fmt.Errorf("os reported user %s has uid %s which is not a valid int, can't determine directory owner. this should never happen", user.Username, user.Uid)
2878+ }
2879+ currentgid, err := strconv.Atoi(user.Gid)
2880+ if err != nil {
2881+ return nil, fmt.Errorf("os reported user %s has gid %s which is not a valid int, can't determine directory owner. this should never happen", user.Username, user.Gid)
2882+ }
2883+ if err := mkdirpWithPerms(path.Dir(pipe), os.FileMode(filemode), currentuid, currentgid); err != nil {
2884+ return nil, err
2885+ }
2886+ // Mutating the umask of the process for this is not ideal, but tightening permissions with chown after creation is not really secure.
2887+ // Lock OS thread while mutating umask so we don't lose a thread with a mutated mask.
2888+ runtime.LockOSThread()
2889+ oldmask := syscall.Umask(777 - filemode)
2890+ var lc net.ListenConfig
2891+ l, err := lc.Listen(ctx, "unix", pipe)
2892+ syscall.Umask(oldmask)
2893+ runtime.UnlockOSThread()
2894+ if err != nil {
2895+ return nil, err
2896+ }
2897+ // But we need to chown anyway to loosen permissions to include whatever group the user has configured
2898+ err = os.Chown(pipe, int(currentuid), gid)
2899+ if err != nil {
2900+ l.Close()
2901+ return nil, err
2902+ }
2903+ return l, nil
2904+}
2905+
// dialPipe connects to the unix socket at pipe; ctx bounds the connection
// attempt.
func dialPipe(ctx context.Context, pipe string) (net.Conn, error) {
	d := &net.Dialer{}
	return d.DialContext(ctx, "unix", pipe)
}
2910diff --git a/google_guest_agent/command/command_monitor.go b/google_guest_agent/command/command_monitor.go
2911new file mode 100644
2912index 0000000..6e1ee86
2913--- /dev/null
2914+++ b/google_guest_agent/command/command_monitor.go
2915@@ -0,0 +1,228 @@
2916+// Copyright 2023 Google Inc. All Rights Reserved.
2917+//
2918+// Licensed under the Apache License, Version 2.0 (the "License");
2919+// you may not use this file except in compliance with the License.
2920+// You may obtain a copy of the License at
2921+//
2922+// http://www.apache.org/licenses/LICENSE-2.0
2923+//
2924+// Unless required by applicable law or agreed to in writing, software
2925+// distributed under the License is distributed on an "AS IS" BASIS,
2926+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2927+// See the License for the specific language governing permissions and
2928+// limitations under the License.
2929+
2930+/*
2931+ * This file contains the details of command's internal communication protocol
2932+ * listener. Most callers should not need to call anything in this file. The
2933+ * command handler and caller API is contained in command.go.
2934+ */
2935+
2936+package command
2937+
2938+import (
2939+ "bufio"
2940+ "context"
2941+ "encoding/json"
2942+ "errors"
2943+ "net"
2944+ "os"
2945+ "strconv"
2946+ "sync"
2947+ "time"
2948+
2949+ "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
2950+ "github.com/GoogleCloudPlatform/guest-logging-go/logger"
2951+)
2952+
2953+var cmdMonitor *Monitor = &Monitor{
2954+ handlersMu: new(sync.RWMutex),
2955+ handlers: make(map[string]Handler),
2956+}
2957+
2958+// Init starts an internally managed command server. The agent configuration
2959+// will decide the server options. Returns a reference to the internally managed
2960+// command monitor which the caller can Close() when appropriate.
2961+func Init(ctx context.Context) {
2962+ if cmdMonitor.srv != nil {
2963+ return
2964+ }
2965+ pipe := cfg.Get().Unstable.CommandPipePath
2966+ if pipe == "" {
2967+ pipe = DefaultPipePath
2968+ }
2969+ to, err := time.ParseDuration(cfg.Get().Unstable.CommandRequestTimeout)
2970+ if err != nil {
2971+ logger.Errorf("commmand request timeout configuration is not a valid duration string, falling back to 10s timeout")
2972+ to = time.Duration(10) * time.Second
2973+ }
2974+ var pipemode int64 = 0770
2975+ pipemode, err = strconv.ParseInt(cfg.Get().Unstable.CommandPipeMode, 8, 32)
2976+ if err != nil {
2977+ logger.Errorf("could not parse command_pipe_mode as octal integer: %v falling back to mode 0770", err)
2978+ }
2979+ cmdMonitor.srv = &Server{
2980+ pipe: pipe,
2981+ pipeMode: int(pipemode),
2982+ pipeGroup: cfg.Get().Unstable.CommandPipeGroup,
2983+ timeout: to,
2984+ monitor: cmdMonitor,
2985+ }
2986+ err = cmdMonitor.srv.start(ctx)
2987+ if err != nil {
2988+ logger.Errorf("failed to start command server: %s", err)
2989+ }
2990+}
2991+
2992+// Close will close the internally managed command server, if it was initialized.
2993+func Close() error {
2994+ if cmdMonitor.srv != nil {
2995+ return cmdMonitor.srv.Close()
2996+ }
2997+ return nil
2998+}
2999+
3000+// Monitor is the structure which handles command registration and deregistration.
3001+type Monitor struct {
3002+ srv *Server
3003+ handlersMu *sync.RWMutex
3004+ handlers map[string]Handler
3005+}
3006+
3007+// Close stops the server from listening to commands.
3008+func (m *Monitor) Close() error { return m.srv.Close() }
3009+
3010+// Start begins listening for commands.
3011+func (m *Monitor) Start(ctx context.Context) error { return m.srv.start(ctx) }
3012+
3013+// Server is the server structure which will listen for command requests and
3014+// route them to handlers. Most callers should not interact with this directly.
3015+type Server struct {
3016+ pipe string
3017+ pipeMode int
3018+ pipeGroup string
3019+ timeout time.Duration
3020+ srv net.Listener
3021+ monitor *Monitor
3022+}
3023+
3024+// Close signals the server to stop listening for commands and stop waiting to
3025+// listen.
3026+func (c *Server) Close() error {
3027+ if c.srv != nil {
3028+ return c.srv.Close()
3029+ }
3030+ return nil
3031+}
3032+
3033+func (c *Server) start(ctx context.Context) error {
3034+ if c.srv != nil {
3035+ return errors.New("server already listening")
3036+ }
3037+ srv, err := listen(ctx, c.pipe, c.pipeMode, c.pipeGroup)
3038+ if err != nil {
3039+ return err
3040+ }
3041+ go func() {
3042+ defer srv.Close()
3043+ for {
3044+ if ctx.Err() != nil {
3045+ return
3046+ }
3047+ conn, err := srv.Accept()
3048+ if err != nil {
3049+ if err == net.ErrClosed {
3050+ break
3051+ }
3052+ logger.Infof("error on connection to pipe %s: %v", c.pipe, err)
3053+ continue
3054+ }
3055+ go func(conn net.Conn) {
3056+ defer conn.Close()
3057+ // Go has lots of helpers to do this for us but none of them return the byte
3058+ // slice afterwards, and we need it for the handler
3059+ var b []byte
3060+ r := bufio.NewReader(conn)
3061+ var depth int
3062+ deadline := time.Now().Add(c.timeout)
3063+ e := conn.SetReadDeadline(deadline)
3064+ if e != nil {
3065+ logger.Infof("could not set read deadline on command request: %v", e)
3066+ return
3067+ }
3068+ for {
3069+ if time.Now().After(deadline) {
3070+ if b, err := json.Marshal(TimeoutError); err != nil {
3071+ conn.Write(internalError)
3072+ } else {
3073+ conn.Write(b)
3074+ }
3075+ return
3076+ }
3077+ rune, _, err := r.ReadRune()
3078+ if err != nil {
3079+ logger.Debugf("connection read error: %v", err)
3080+ if errors.Is(err, os.ErrDeadlineExceeded) {
3081+ if b, err := json.Marshal(TimeoutError); err != nil {
3082+ conn.Write(internalError)
3083+ } else {
3084+ conn.Write(b)
3085+ }
3086+ } else {
3087+ if b, err := json.Marshal(ConnError); err != nil {
3088+ conn.Write(internalError)
3089+ } else {
3090+ conn.Write(b)
3091+ }
3092+ }
3093+ return
3094+ }
3095+ b = append(b, byte(rune))
3096+ switch rune {
3097+ case '{':
3098+ depth++
3099+ case '}':
3100+ depth--
3101+ }
3102+ // Must check here because the first pass always depth = 0
3103+ if depth == 0 {
3104+ break
3105+ }
3106+ }
3107+ var req Request
3108+ err := json.Unmarshal(b, &req)
3109+ if err != nil {
3110+ if b, err := json.Marshal(BadRequestError); err != nil {
3111+ conn.Write(internalError)
3112+ } else {
3113+ conn.Write(b)
3114+ }
3115+ return
3116+ }
3117+ c.monitor.handlersMu.RLock()
3118+ defer c.monitor.handlersMu.RUnlock()
3119+ handler, ok := c.monitor.handlers[req.Command]
3120+ if !ok {
3121+ if b, err := json.Marshal(CmdNotFoundError); err != nil {
3122+ conn.Write(internalError)
3123+ } else {
3124+ conn.Write(b)
3125+ }
3126+ return
3127+ }
3128+ resp, err := handler(b)
3129+ if err != nil {
3130+ re := Response{Status: HandlerError.Status, StatusMessage: err.Error()}
3131+ if b, err := json.Marshal(re); err != nil {
3132+ resp = internalError
3133+ } else {
3134+ resp = b
3135+ }
3136+ }
3137+ conn.Write(resp)
3138+ }(conn)
3139+ }
3140+ }()
3141+ c.srv = srv
3142+ return nil
3143+}
3144diff --git a/google_guest_agent/command/command_test.go b/google_guest_agent/command/command_test.go
3145new file mode 100644
3146index 0000000..96c2d9f
3147--- /dev/null
3148+++ b/google_guest_agent/command/command_test.go
3149@@ -0,0 +1,209 @@
3150+// Copyright 2023 Google Inc. All Rights Reserved.
3151+//
3152+// Licensed under the Apache License, Version 2.0 (the "License");
3153+// you may not use this file except in compliance with the License.
3154+// You may obtain a copy of the License at
3155+//
3156+// http://www.apache.org/licenses/LICENSE-2.0
3157+//
3158+// Unless required by applicable law or agreed to in writing, software
3159+// distributed under the License is distributed on an "AS IS" BASIS,
3160+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3161+// See the License for the specific language governing permissions and
3162+// limitations under the License.
3163+
3164+package command
3165+
3166+import (
3167+ "context"
3168+ "encoding/json"
3169+ "fmt"
3170+ "io"
3171+ "math/rand"
3172+ "os/user"
3173+ "path"
3174+ "runtime"
3175+ "sync"
3176+ "testing"
3177+ "time"
3178+
3179+ "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
3180+)
3181+
3182+func cmdServerForTest(t *testing.T, pipeMode int, pipeGroup string, timeout time.Duration) *Server {
3183+ cs := &Server{
3184+ pipe: getTestPipePath(t),
3185+ pipeMode: pipeMode,
3186+ pipeGroup: pipeGroup,
3187+ timeout: timeout,
3188+ monitor: &Monitor{
3189+ handlersMu: new(sync.RWMutex),
3190+ handlers: make(map[string]Handler),
3191+ },
3192+ }
3193+ cs.monitor.srv = cs
3194+ err := cs.start(testctx(t))
3195+ if err != nil {
3196+ t.Fatal(err)
3197+ }
3198+ t.Cleanup(func() {
3199+ err := cs.Close()
3200+ if err != nil {
3201+ t.Errorf("error closing command server: %v", err)
3202+ }
3203+ })
3204+ return cs
3205+}
3206+
// getTestPipePath returns a pipe path unique to t: a named pipe derived from
// the test name on Windows, or a path under the test's temporary directory on
// every other platform.
func getTestPipePath(t *testing.T) string {
	if runtime.GOOS != "windows" {
		return path.Join(t.TempDir(), "run", "pipe")
	}
	return `\\.\pipe\google-guest-agent-commands-test-` + t.Name()
}
3213+
// testctx returns a context tied to t's lifetime. It is cancelled during t's
// cleanup and, when the test binary has a deadline, also expires at that
// deadline.
func testctx(t *testing.T) context.Context {
	var (
		ctx    context.Context
		cancel context.CancelFunc
	)
	if deadline, hasDeadline := t.Deadline(); hasDeadline {
		ctx, cancel = context.WithDeadline(context.Background(), deadline)
	} else {
		ctx, cancel = context.WithCancel(context.Background())
	}
	t.Cleanup(cancel)
	return ctx
}
3225+
// testRequest mirrors the JSON request sent in TestListen. Command names the
// handler to route to, and ArbitraryData is an extra field the handler checks
// to confirm it received the full request bytes.
type testRequest struct {
	Command       string
	ArbitraryData int
}
3230+
3231+func TestInit(t *testing.T) {
3232+ cfg.Load(nil)
3233+ cfg.Get().Unstable.CommandPipePath = getTestPipePath(t)
3234+ if cmdMonitor.srv != nil {
3235+ t.Fatal("internal command server already exists")
3236+ }
3237+ Init(testctx(t))
3238+ if cmdMonitor.srv == nil {
3239+ t.Errorf("could not start internally managed command server")
3240+ }
3241+ if err := Close(); err != nil {
3242+ t.Errorf("could not close managed command server: %s", err)
3243+ }
3244+}
3245+
3246+func TestListen(t *testing.T) {
3247+ cu, err := user.Current()
3248+ if err != nil {
3249+ t.Fatalf("could not get current user: %v", err)
3250+ }
3251+ ug, err := cu.GroupIds()
3252+ if err != nil {
3253+ t.Fatalf("could not get user groups for %s: %v", cu.Name, err)
3254+ }
3255+ resp := []byte(`{"Status":0,"StatusMessage":"OK"}`)
3256+ errresp := []byte(`{"Status":1,"StatusMessage":"ERR"}`)
3257+ req := []byte(`{"ArbitraryData":1234,"Command":"TestListen"}`)
3258+ h := func(b []byte) ([]byte, error) {
3259+ var r testRequest
3260+ err := json.Unmarshal(b, &r)
3261+ if err != nil || r.ArbitraryData != 1234 {
3262+ return errresp, nil
3263+ }
3264+ return resp, nil
3265+ }
3266+
3267+ testcases := []struct {
3268+ name string
3269+ filemode int
3270+ group string
3271+ }{
3272+ {
3273+ name: "world read/writeable",
3274+ filemode: 0777,
3275+ group: "-1",
3276+ },
3277+ {
3278+ name: "group read/writeable",
3279+ filemode: 0770,
3280+ group: "-1",
3281+ },
3282+ {
3283+ name: "user read/writeable",
3284+ filemode: 0700,
3285+ group: "-1",
3286+ },
3287+ {
3288+ name: "additional user group as group owner",
3289+ filemode: 0770,
3290+ group: ug[rand.Intn(len(ug))],
3291+ },
3292+ }
3293+ for _, tc := range testcases {
3294+ t.Run(tc.name, func(t *testing.T) {
3295+ cs := cmdServerForTest(t, tc.filemode, tc.group, time.Second)
3296+ err := cs.monitor.RegisterHandler("TestListen", h)
3297+ if err != nil {
3298+ t.Errorf("could not register handler: %v", err)
3299+ }
3300+ d := SendCmdPipe(testctx(t), cs.pipe, req)
3301+ var r Response
3302+ err = json.Unmarshal(d, &r)
3303+ if err != nil {
3304+ t.Error(err)
3305+ }
3306+ if r.Status != 0 || r.StatusMessage != "OK" {
3307+ t.Errorf("unexpected status from test-cmd, want 0, \"OK\" but got %d, %q", r.Status, r.StatusMessage)
3308+ }
3309+ })
3310+ }
3311+}
3312+
3313+func TestHandlerFailure(t *testing.T) {
3314+ req := []byte(`{"Command":"TestHandlerFailure"}`)
3315+ h := func(b []byte) ([]byte, error) {
3316+ return nil, fmt.Errorf("always fail")
3317+ }
3318+
3319+ cs := cmdServerForTest(t, 0777, "-1", time.Second)
3320+ cs.monitor.RegisterHandler("TestHandlerFailure", h)
3321+ d := SendCmdPipe(testctx(t), cs.pipe, req)
3322+ var r Response
3323+ err := json.Unmarshal(d, &r)
3324+ if err != nil {
3325+ t.Error(err)
3326+ }
3327+ if r.Status != HandlerError.Status || r.StatusMessage != "always fail" {
3328+ t.Errorf("unexpected status from TestHandlerFailure, want %d, \"always fail\" but got %d, %q", HandlerError.Status, r.Status, r.StatusMessage)
3329+ }
3330+}
3331+
3332+func TestListenTimeout(t *testing.T) {
3333+ expect, err := json.Marshal(TimeoutError)
3334+ if err != nil {
3335+ t.Fatal(err)
3336+ }
3337+ if runtime.GOOS == "windows" {
3338+ // winio library does not surface timeouts from the underlying net.Conn as
3339+ // timeouts, but as generic errors. Timeouts still work they just can't be
3340+ // detected as timeouts, so they are generic connErrors here.
3341+ expect, err = json.Marshal(ConnError)
3342+ if err != nil {
3343+ t.Fatal(err)
3344+ }
3345+ }
3346+ cs := cmdServerForTest(t, 0770, "-1", time.Millisecond)
3347+ conn, err := dialPipe(testctx(t), cs.pipe)
3348+ if err != nil {
3349+ t.Errorf("could not connect to command server: %v", err)
3350+ }
3351+ data, err := io.ReadAll(conn)
3352+ if err != nil {
3353+ t.Errorf("error reading response from command server: %v", err)
3354+ }
3355+ if string(data) != string(expect) {
3356+ t.Errorf("unexpected response from timed out connection, got %s but want %s", data, expect)
3357+ }
3358+}
3359diff --git a/google_guest_agent/command/command_windows.go b/google_guest_agent/command/command_windows.go
3360new file mode 100644
3361index 0000000..0ac7213
3362--- /dev/null
3363+++ b/google_guest_agent/command/command_windows.go
3364@@ -0,0 +1,104 @@
3365+// Copyright 2023 Google Inc. All Rights Reserved.
3366+//
3367+// Licensed under the Apache License, Version 2.0 (the "License");
3368+// you may not use this file except in compliance with the License.
3369+// You may obtain a copy of the License at
3370+//
3371+// http://www.apache.org/licenses/LICENSE-2.0
3372+//
3373+// Unless required by applicable law or agreed to in writing, software
3374+// distributed under the License is distributed on an "AS IS" BASIS,
3375+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3376+// See the License for the specific language governing permissions and
3377+// limitations under the License.
3378+
3379+package command
3380+
import (
	"context"
	"fmt"
	"net"
	"os/user"
	"time"

	"github.com/GoogleCloudPlatform/guest-logging-go/logger"
	"github.com/Microsoft/go-winio"
)
3390+
// DefaultPipePath is the default named pipe path for windows.
const DefaultPipePath = `\\.\pipe\google-guest-agent-commands`

// Well-known Windows security identifiers used when building the pipe's
// security descriptor.
const (
	nullSID         = "S-1-0-0" // Null SID (no principal).
	worldSID        = "S-1-1-0" // Everyone.
	creatorOwnerSID = "S-1-3-0" // Creator Owner.
	creatorGroupSID = "S-1-3-1" // Creator Group.
)
3399+
3400+func genSecurityDescriptor(filemode int, grp string) string {
3401+ // This function translates the intention of a unix file mode and owner group into an appropriate SDDL security descriptor for a windows named pipe.
3402+ owner := creatorOwnerSID
3403+ group := creatorGroupSID
3404+
3405+ wPerm := filemode % 010
3406+ filemode /= 010
3407+ gPerm := filemode % 010
3408+ filemode /= 010
3409+ uPerm := filemode % 010
3410+
3411+ // Having only read or only write access to a bidirectional pipe is pointless so we treat access for user/group as yes or no based on whether the permission grants RW access
3412+ if uPerm < 06 {
3413+ owner = nullSID
3414+ }
3415+ if gPerm < 06 {
3416+ group = nullSID
3417+ }
3418+ // If permissions grant world RW, make world the owner
3419+ if wPerm > 05 {
3420+ owner = worldSID
3421+ group = worldSID
3422+ }
3423+
3424+ // Group is handled as supplemental DACL, but ignore it if user specified no group rw permission
3425+ var dacl string
3426+ if gPerm > 05 {
3427+ g, err := user.LookupGroup(grp)
3428+ if err != nil {
3429+ logger.Errorf("Could not lookup group %s SID, this group will not be included in the command server security descriptor: %v", grp, err)
3430+ } else {
3431+ // Allow access;Protected DACL;Allow all general access;Empty object guid;Empty inherit object guid;group sid from lookup
3432+ dacl = fmt.Sprintf("D:(A;P;GA;;;%s)", g.Gid)
3433+ }
3434+ }
3435+
3436+ sddl := "O:%sG:%s%s"
3437+ return fmt.Sprintf(sddl, owner, group, dacl)
3438+}
3439+
3440+func listen(ctx context.Context, path string, filemode int, group string) (net.Listener, error) {
3441+ // Winio library does not provide any method to listen on context. Failing to
3442+ // specify a pipeconfig (or using the zero value) results in flaky ACCESS_DENIED
3443+ // errors when re-opening the same pipe (~1/10).
3444+ // https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#remarks
3445+ // Even with a pipeconfig, this flakes ~1/200 runs, hence the retry until the
3446+ // context is expired or listen is successful.
3447+ var l net.Listener
3448+ var lastError error
3449+ for {
3450+ if ctx.Err() != nil {
3451+ return nil, fmt.Errorf("context expired: %v before successful listen (last error: %v)", ctx.Err(), lastError)
3452+ }
3453+ config := &winio.PipeConfig{
3454+ MessageMode: false,
3455+ InputBufferSize: 1024,
3456+ OutputBufferSize: 1024,
3457+ SecurityDescriptor: genSecurityDescriptor(filemode, group),
3458+ }
3459+ l, lastError = winio.ListenPipe(path, config)
3460+ if lastError == nil {
3461+ return l, lastError
3462+ }
3463+ }
3464+}
3465+
3466+func dialPipe(ctx context.Context, pipe string) (net.Conn, error) {
3467+ return winio.DialPipeContext(ctx, pipe)
3468+}
3469diff --git a/google_guest_agent/command/command_windows_test.go b/google_guest_agent/command/command_windows_test.go
3470new file mode 100644
3471index 0000000..b4789e2
3472--- /dev/null
3473+++ b/google_guest_agent/command/command_windows_test.go
3474@@ -0,0 +1,73 @@
3475+//go:build windows
3476+
3477+// Copyright 2023 Google Inc. All Rights Reserved.
3478+//
3479+// Licensed under the Apache License, Version 2.0 (the "License");
3480+// you may not use this file except in compliance with the License.
3481+// You may obtain a copy of the License at
3482+//
3483+// http://www.apache.org/licenses/LICENSE-2.0
3484+//
3485+// Unless required by applicable law or agreed to in writing, software
3486+// distributed under the License is distributed on an "AS IS" BASIS,
3487+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3488+// See the License for the specific language governing permissions and
3489+// limitations under the License.
3490+package command
3491+
3492+import (
3493+ "os/user"
3494+ "testing"
3495+)
3496+
3497+func TestGenSecurityDescriptor(t *testing.T) {
3498+ guest, err := user.LookupGroup("Guests")
3499+ if err != nil {
3500+ t.Fatal(err)
3501+ }
3502+ testcases := []struct {
3503+ name string
3504+ filemode int
3505+ group string
3506+ output string
3507+ }{
3508+ {
3509+ name: "world writeable",
3510+ filemode: 0777,
3511+ group: nullSID,
3512+ output: "O:" + worldSID + "G:" + worldSID,
3513+ },
3514+ {
3515+ name: "user+group writable",
3516+ filemode: 0770,
3517+ group: "",
3518+ output: "O:" + creatorOwnerSID + "G:" + creatorGroupSID,
3519+ },
3520+ {
3521+ name: "user writable",
3522+ filemode: 0700,
3523+ group: nullSID,
3524+ output: "O:" + creatorOwnerSID + "G:" + nullSID,
3525+ },
3526+ {
3527+ name: "no write permissions",
3528+ filemode: 000,
3529+ group: nullSID,
3530+ output: "O:" + nullSID + "G:" + nullSID,
3531+ },
3532+ {
3533+ name: "custom named group",
3534+ filemode: 0770,
3535+ group: "Guests",
3536+ output: "O:" + creatorOwnerSID + "G:" + creatorGroupSID + "D:(A;P;GA;;;" + guest.Gid + ")",
3537+ },
3538+ }
3539+ for _, tc := range testcases {
3540+ t.Run(tc.name, func(t *testing.T) {
3541+ sd := genSecurityDescriptor(tc.filemode, tc.group)
3542+ if sd != tc.output {
3543+ t.Errorf("unexpected output from genSecurityDescriptor(%d, %s), got %s want %s", tc.filemode, tc.group, sd, tc.output)
3544+ }
3545+ })
3546+ }
3547+}
3548diff --git a/google_guest_agent/diagnostics.go b/google_guest_agent/diagnostics.go
3549index b2d239b..62cde39 100644
3550--- a/google_guest_agent/diagnostics.go
3551+++ b/google_guest_agent/diagnostics.go
3552@@ -19,6 +19,7 @@ import (
3553 "encoding/json"
3554 "reflect"
3555 "runtime"
3556+ "slices"
3557 "sync/atomic"
3558
3559 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
3560@@ -94,7 +95,7 @@ func (d *diagnosticsMgr) Set(ctx context.Context) error {
3561 }
3562
3563 strEntry := newMetadata.Instance.Attributes.Diagnostics
3564- if utils.ContainsString(strEntry, diagnosticsEntries) {
3565+ if slices.Contains(diagnosticsEntries, strEntry) {
3566 return nil
3567 }
3568 diagnosticsEntries = append(diagnosticsEntries, strEntry)
3569diff --git a/google_guest_agent/events/events.go b/google_guest_agent/events/events.go
3570index 2bf57e3..24d10dd 100644
3571--- a/google_guest_agent/events/events.go
3572+++ b/google_guest_agent/events/events.go
3573@@ -21,13 +21,14 @@ import (
3574 "sync"
3575
3576 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/metadata"
3577- "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/sshtrustedca"
3578 "github.com/GoogleCloudPlatform/guest-logging-go/logger"
3579 )
3580
3581 var (
3582- // availableWatchers mapps all kown available event watchers.
3583- availableWatchers = make(map[string]Watcher)
3584+ defaultWatchers = []Watcher{
3585+ metadata.New(),
3586+ }
3587+ instance *Manager
3588 )
3589
3590 // Watcher defines the interface between the events manager and the actual
3591@@ -48,14 +49,65 @@ type Watcher interface {
3592 // Manager defines the interface between events management layer and the
3593 // core guest agent implementation.
3594 type Manager struct {
3595- watchers map[string]Watcher
3596+ // watcherEvents maps the registered watchers and their events.
3597+ watcherEvents []*WatcherEventType
3598+
3599+ // watchersMap is a convenient manager's mapping of registered watcher instances.
3600+ watchersMap map[string]bool
3601+
3602+ // watchersMutex protects the watchers map.
3603+ watchersMutex sync.Mutex
3604+
3605+ // removingWatcherEvents is a map of watchers being removed.
3606+ removingWatcherEvents map[string]bool
3607+
3608+ // running is a flag indicating if the Run() was previously called.
3609+ running bool
3610+
3611+ // runningMutex protects the running flag.
3612+ runningMutex sync.RWMutex
3613+
3614+ // subscribers maps the subscribed callbacks.
3615 subscribers map[string][]*eventSubscriber
3616+
3617+ // subscribersMutex protects subscribers member/map of the manager object.
3618+ subscribersMutex sync.Mutex
3619+
3620+ // queue queue struct manages the running watchers, when it gets to len()
3621+ // down to zero means all watchers are done and we can signal the other
3622+ // control go routines to leave(given we don't have any more job left to
3623+ // process).
3624+ queue *watcherQueue
3625 }
3626
3627-// Config offers a mechanism for the consumers to configure the Manager behavior.
3628-type Config struct {
3629- // Watchers lists the enabled watchers, of not provided all available watchers will be enabled.
3630- Watchers []string
3631+// watcherQueue wraps the watchers <-> callbacks communication as well as the
3632+// communication/coordination of the multiple control go routine i.e. the one
3633+// responsible to calling callbacks after a event is produced by the watcher etc.
3634+type watcherQueue struct {
3635+ // queueMutex protects the access to watchersMap
3636+ queueMutex sync.RWMutex
3637+
3638+ // watchersMap maps the currently running watchers.
3639+ watchersMap map[string]bool
3640+
3641+ // finishContextHandler is a channel used to communicate with the context handling
3642+ // go routine that it should finish/end its job (usually after all watchers are done).
3643+ finishContextHandler chan bool
3644+
3645+ // finishCallbackHandler is a channel used to communicate with the callback handling
3646+ // go routine that it should finish/end its job (usually after all watchers are done).
3647+ finishCallbackHandler chan bool
3648+
3649+ // watcherDone is a channel used to communicate that a given watcher is finished/done.
3650+ watcherDone chan string
3651+
3652+ // dataBus is the channel used to communicate between watchers (event producer) and the
3653+ // callback handler (event consumer managing go routine).
3654+ dataBus chan eventBusData
3655+
3656+ // leaving is a flag that indicates no more job should be processed as we are done
3657+ // with all watchers and callbacks.
3658+ leaving bool
3659 }
3660
3661 // EventData wraps the data communicated from a Watcher to a Subscriber.
3662@@ -72,6 +124,20 @@ type WatcherEventType struct {
3663 watcher Watcher
3664 // evType idenfities the event type this object refences to.
3665 evType string
3666+ // removed is a channel used to communicate with the running watcher go routine
3667+ // that it shouldn't renew even if the watcher requested a renew (in response of a
3668+ // RemoveWatcher() call.
3669+ removed chan bool
3670+}
3671+
3672+type eventSubscriber struct {
3673+ data interface{}
3674+ cb *EventCb
3675+}
3676+
3677+type eventBusData struct {
3678+ evType string
3679+ data *EventData
3680 }
3681
3682 // EventCb defines the callback interface between watchers and subscribers. The arguments are:
3683@@ -84,224 +150,328 @@ type WatcherEventType struct {
3684 // to be unregistered/unsubscribed.
3685 type EventCb func(ctx context.Context, evType string, data interface{}, evData *EventData) bool
3686
3687-type eventSubscriber struct {
3688- data interface{}
3689- cb EventCb
3690+// length returns how many watchers are currently running.
3691+func (ep *watcherQueue) length() int {
3692+ ep.queueMutex.RLock()
3693+ defer ep.queueMutex.RUnlock()
3694+ return len(ep.watchersMap)
3695 }
3696
3697-type eventBusData struct {
3698- evType string
3699- data *EventData
3700+// add adds a new watcher to the queue.
3701+func (ep *watcherQueue) add(evType string) {
3702+ ep.queueMutex.Lock()
3703+ defer ep.queueMutex.Unlock()
3704+ ep.watchersMap[evType] = true
3705 }
3706
3707-func init() {
3708- err := initWatchers([]Watcher{
3709- metadata.New(),
3710- sshtrustedca.New(sshtrustedca.DefaultPipePath),
3711- })
3712- if err != nil {
3713- logger.Errorf("Failed to initialize watchers: %+v", err)
3714- }
3715+// del removes a watcher from the queue.
3716+func (ep *watcherQueue) del(evType string) int {
3717+ ep.queueMutex.Lock()
3718+ defer ep.queueMutex.Unlock()
3719+ delete(ep.watchersMap, evType)
3720+ return len(ep.watchersMap)
3721 }
3722
3723-// init initializes the known available event watchers.
3724-func initWatchers(watchers []Watcher) error {
3725- for _, curr := range watchers {
3726- // Error if we are accidentaly not properly setting the id.
3727- if curr.ID() == "" {
3728- return fmt.Errorf("invalid event watcher id, skipping")
3729+// AddDefaultWatchers add the default watchers:
3730+// - metadata
3731+func (mngr *Manager) AddDefaultWatchers(ctx context.Context) error {
3732+ for _, curr := range defaultWatchers {
3733+ if err := mngr.AddWatcher(ctx, curr); err != nil {
3734+ return err
3735 }
3736- availableWatchers[curr.ID()] = curr
3737 }
3738 return nil
3739 }
3740
3741-// New allocates and initializes a events Manager based on provided cfg.
3742-func New(cfg *Config) (*Manager, error) {
3743- res := &Manager{
3744- watchers: availableWatchers,
3745- subscribers: make(map[string][]*eventSubscriber),
3746+// newManager allocates and initializes a events Manager.
3747+func newManager() *Manager {
3748+ return &Manager{
3749+ watchersMap: make(map[string]bool),
3750+ removingWatcherEvents: make(map[string]bool),
3751+ subscribers: make(map[string][]*eventSubscriber),
3752+ queue: &watcherQueue{
3753+ watchersMap: make(map[string]bool),
3754+ dataBus: make(chan eventBusData),
3755+ finishCallbackHandler: make(chan bool),
3756+ finishContextHandler: make(chan bool),
3757+ watcherDone: make(chan string),
3758+ },
3759 }
3760+}
3761
3762- // Align manager's config based on consumers provided watchers If it is
3763- // passing in wanted/expected watchers, otherwise use all available ones.
3764- if cfg != nil && len(cfg.Watchers) > 0 {
3765- res.watchers = make(map[string]Watcher)
3766- for _, curr := range cfg.Watchers {
3767- // Report back if we don't know the provided watcher id.
3768- if _, found := availableWatchers[curr]; !found {
3769- return nil, fmt.Errorf("invalid/unknown watcher id: %s", curr)
3770- }
3771- res.watchers[curr] = availableWatchers[curr]
3772- }
3773- }
3774+func init() {
3775+ instance = newManager()
3776+}
3777
3778- return res, nil
3779+// Get allocates a new manager if one doesn't exists or returns the one previously allocated.
3780+func Get() *Manager {
3781+ if instance == nil {
3782+ panic("The event's manager instance should had being initialized.")
3783+ }
3784+ return instance
3785 }
3786
3787 // Subscribe registers an event consumer/subscriber callback to a given event type, data
3788 // is a context pointer provided by the caller to be passed down when calling cb when
3789 // a new event happens.
3790 func (mngr *Manager) Subscribe(evType string, data interface{}, cb EventCb) {
3791+ mngr.subscribersMutex.Lock()
3792+ defer mngr.subscribersMutex.Unlock()
3793 mngr.subscribers[evType] = append(mngr.subscribers[evType],
3794 &eventSubscriber{
3795 data: data,
3796- cb: cb,
3797+ cb: &cb,
3798 },
3799 )
3800 }
3801
3802-func (mngr *Manager) eventTypes() []*WatcherEventType {
3803- var res []*WatcherEventType
3804- for _, watcher := range mngr.watchers {
3805- for _, evType := range watcher.Events() {
3806- res = append(res, &WatcherEventType{watcher, evType})
3807+func (mngr *Manager) unsubscribe(evType string, cb *EventCb) {
3808+ var keepMe []*eventSubscriber
3809+ for _, curr := range mngr.subscribers[evType] {
3810+ if curr.cb != cb {
3811+ keepMe = append(keepMe, curr)
3812 }
3813 }
3814- return res
3815+
3816+ mngr.subscribers[evType] = keepMe
3817+
3818+ if len(keepMe) == 0 {
3819+ logger.Debugf("No more subscribers left for evType: %s", evType)
3820+ delete(mngr.subscribers, evType)
3821+ }
3822 }
3823
3824-type watcherQueue struct {
3825- mutex sync.Mutex
3826- watchersMap map[string]bool
// Unsubscribe removes the subscription of a given callback for a given event type.
//
// NOTE(review): Subscribe stores &cb, the address of its own parameter, and
// unsubscribe matches subscribers by comparing those pointers. Here &cb is the
// address of this call's own parameter, a new variable, so it appears it can
// never equal a stored pointer and this call removes nothing — confirm. Go
// func values are not comparable, so a fix likely needs Subscribe to return a
// handle that callers pass back here.
func (mngr *Manager) Unsubscribe(evType string, cb EventCb) {
	mngr.subscribersMutex.Lock()
	defer mngr.subscribersMutex.Unlock()
	mngr.unsubscribe(evType, &cb)
}
3833
3834-func (ep *watcherQueue) add(evType string) {
3835- ep.mutex.Lock()
3836- defer ep.mutex.Unlock()
3837- ep.watchersMap[evType] = true
3838+// RemoveWatcher removes a watcher from the event manager. Each running watcher has its own
3839+// context (derived from the one provided in the AddWatcher() call) and will have it canceled
3840+// after calling this method.
3841+func (mngr *Manager) RemoveWatcher(ctx context.Context, watcher Watcher) error {
3842+ mngr.watchersMutex.Lock()
3843+ defer mngr.watchersMutex.Unlock()
3844+
3845+ id := watcher.ID()
3846+ logger.Debugf("Got a request to remove watcher: %s", id)
3847+ if _, found := mngr.watchersMap[id]; !found {
3848+ return fmt.Errorf("unknown Watcher(%s)", id)
3849+ }
3850+
3851+ for _, curr := range mngr.watcherEvents {
3852+ if _, found := mngr.removingWatcherEvents[curr.evType]; found {
3853+ logger.Debugf("Watcher(%s) is being removed, skipping removal request: %s", id, curr.evType)
3854+ continue
3855+ }
3856+
3857+ if curr.watcher.ID() == id {
3858+ mngr.removingWatcherEvents[curr.evType] = true
3859+ logger.Debugf("Removing watcher: %s, event type: %s", id, curr.evType)
3860+ curr.removed <- true
3861+ }
3862+ }
3863+
3864+ return nil
3865 }
3866
3867-func (ep *watcherQueue) del(evType string) int {
3868- ep.mutex.Lock()
3869- defer ep.mutex.Unlock()
3870- delete(ep.watchersMap, evType)
3871- return len(ep.watchersMap)
3872+// AddWatcher adds/enables a new watcher. The watcher will be fired up right away if the
3873+// event manager is already running, otherwise it's scheduled to run when Run() is called.
3874+func (mngr *Manager) AddWatcher(ctx context.Context, watcher Watcher) error {
3875+ mngr.watchersMutex.Lock()
3876+ defer mngr.watchersMutex.Unlock()
3877+ id := watcher.ID()
3878+ if _, found := mngr.watchersMap[id]; found {
3879+ return fmt.Errorf("watcher(%s) was previously added", id)
3880+ }
3881+
3882+ // Add the watchers and its events to internal mappings.
3883+ evTypes := make(map[string]*WatcherEventType)
3884+ mngr.watchersMap[id] = true
3885+
3886+ for _, curr := range watcher.Events() {
3887+ evType := &WatcherEventType{
3888+ watcher: watcher,
3889+ evType: curr,
3890+ removed: make(chan bool),
3891+ }
3892+
3893+ evTypes[curr] = evType
3894+ mngr.watcherEvents = append(mngr.watcherEvents, evType)
3895+ }
3896+
3897+ mngr.runningMutex.RLock()
3898+ defer mngr.runningMutex.RUnlock()
3899+ // If we are not running don't bother "running" the watcher, Run() will do it later.
3900+ if !mngr.running {
3901+ return nil
3902+ }
3903+
3904+ // If we are already running the "run/launch" the watcher.
3905+ for _, curr := range watcher.Events() {
3906+ logger.Debugf("Adding watcher for event: %s", curr)
3907+ mngr.queue.add(curr)
3908+ go func(watcher Watcher, evType string, removed chan bool) {
3909+ mngr.runWatcher(ctx, watcher, evType, removed)
3910+ }(watcher, curr, evTypes[curr].removed)
3911+ }
3912+
3913+ return nil
3914+}
3915+
// runWatcher drives watcher for a single event type. It stops when the watcher
// no longer asks to be renewed, when an abort request has been received on
// removed, or when the manager's queue is marked as leaving. Each result of
// watcher.Run (data and error) is forwarded on mngr.queue.dataBus. When the
// loop ends, evType is reported on mngr.queue.watcherDone.
//
// The watcher runs with a context derived from ctx that is cancelled once a
// value is received on removed.
//
// NOTE(review): abort is written by the helper goroutine and read by this loop
// without synchronization, and mngr.queue.leaving is read here while Run's
// context goroutine writes it; both look like data races under -race — confirm.
func (mngr *Manager) runWatcher(ctx context.Context, watcher Watcher, evType string, removed chan bool) {
	nCtx, cancel := context.WithCancel(ctx)
	abort := false
	id := watcher.ID()

	// Wait for an abort request (from RemoveWatcher, or the self-send below)
	// and cancel the watcher's context.
	go func() {
		abort = <-removed
		logger.Debugf("Got a request to abort watcher(%s) for event: %s", id, evType)
		cancel()
	}()

	for renew := true; renew; {
		var evData interface{}
		var err error

		renew, evData, err = watcher.Run(nCtx, evType)

		logger.Debugf("Watcher(%s) returned event: %q, should renew?: %t", id, evType, renew)

		// Results produced after an abort or during shutdown are dropped.
		if abort || mngr.queue.leaving {
			logger.Debugf("Watcher(%s), either are aborting(%t) or leaving(%t), breaking renew cycle",
				id, abort, mngr.queue.leaving)
			break
		}

		mngr.queue.dataBus <- eventBusData{
			evType: evType,
			data: &EventData{
				Data:  evData,
				Error: err,
			},
		}
	}

	logger.Debugf("watcher finishing: %s", evType)
	// Release the helper goroutine above when no abort was received.
	// NOTE(review): removed is unbuffered and the helper goroutine is its only
	// visible receiver; if RemoveWatcher sends on it concurrently with this
	// send, one of the two senders appears to block forever (RemoveWatcher
	// while holding watchersMutex) — verify.
	if !abort {
		removed <- true
	}

	mngr.queue.watcherDone <- evType
}
3957
3958 // Run runs the event manager, it will block until all watchers have given up/failed.
3959-func (mngr *Manager) Run(ctx context.Context) {
3960+// The event manager is meant to be started right after the early initialization code
3961+// and live until the application ends, the event manager can not be restarted - the Run()
3962+// method will return an error if one tries to run it twice.
3963+func (mngr *Manager) Run(ctx context.Context) error {
3964 var wg sync.WaitGroup
3965- var leaving bool
3966-
3967- syncBus := make(chan eventBusData)
3968- defer close(syncBus)
3969
3970- cancelContext := make(chan bool)
3971- cancelCallback := make(chan bool)
3972+ mngr.runningMutex.Lock()
3973+ if mngr.running {
3974+ mngr.runningMutex.Unlock()
3975+ return fmt.Errorf("tried calling event manager's Run() twice")
3976+ }
3977+ mngr.running = true
3978+ mngr.runningMutex.Unlock()
3979
3980- defer close(cancelContext)
3981- defer close(cancelCallback)
3982+ queue := mngr.queue
3983
3984 // Manages the context's done signal, pass it down to the other go routines to
3985 // finish its job and leave. Additionally, if the remaining go routines are leaving
3986- // we get it handled via syncBus channel and drop this go routine as well.
3987+ // we get it handled via dataBus channel and drop this go routine as well.
3988 wg.Add(1)
3989- go func(done <-chan struct{}, cancelContext <-chan bool, cancelCallback chan<- bool) {
3990+ go func(done <-chan struct{}, finishContextHandler <-chan bool, finishCallbackHandler chan<- bool) {
3991 defer wg.Done()
3992
3993 for {
3994 select {
3995 case <-done:
3996 logger.Debugf("Got context's Done() signal, leaving.")
3997- leaving = true
3998- cancelCallback <- true
3999+ queue.leaving = true
4000+ finishCallbackHandler <- true
4001 return
4002- case <-cancelContext:
4003- leaving = true
4004+ case <-finishContextHandler:
4005+ logger.Debugf("Got context handler finish signal, leaving.")
4006+ queue.leaving = true
4007 return
4008 }
4009 }
4010- }(ctx.Done(), cancelContext, cancelCallback)
4011+ }(ctx.Done(), queue.finishContextHandler, queue.finishCallbackHandler)
4012
4013 // Manages the event processing avoiding blocking the watcher's go routines.
4014- // This will listen to syncBus and call the events handlers/callbacks.
4015+ // This will listen to dataBus and call the events handlers/callbacks.
4016 wg.Add(1)
4017- go func(bus <-chan eventBusData, cancelCallback <-chan bool) {
4018+ go func(bus <-chan eventBusData, finishCallbackHandler <-chan bool) {
4019 defer wg.Done()
4020
4021 for {
4022 select {
4023- case <-cancelCallback:
4024+ case <-finishCallbackHandler:
4025 return
4026 case busData := <-bus:
4027- subscribers, found := mngr.subscribers[busData.evType]
4028- if !found || len(subscribers) == 0 {
4029+ subscribers := mngr.subscribers[busData.evType]
4030+ if subscribers == nil {
4031 logger.Debugf("No subscriber found for event: %s, returning.", busData.evType)
4032 continue
4033 }
4034
4035- keepMe := make([]*eventSubscriber, 0)
4036- for _, curr := range mngr.subscribers[busData.evType] {
4037+ deleteMe := make([]*eventSubscriber, 0)
4038+ for _, curr := range subscribers {
4039 logger.Debugf("Running registered callback for event: %s", busData.evType)
4040- renew := curr.cb(ctx, busData.evType, curr.data, busData.data)
4041- if renew {
4042- keepMe = append(keepMe, curr)
4043+ renew := (*curr.cb)(ctx, busData.evType, curr.data, busData.data)
4044+ if !renew {
4045+ deleteMe = append(deleteMe, curr)
4046 }
4047 logger.Debugf("Returning from event %q subscribed callback, should renew?: %t", busData.evType, renew)
4048 }
4049
4050- mngr.subscribers[busData.evType] = keepMe
4051-
4052- // No more subscribers for this event type, delete it from the subscribers map.
4053- if len(keepMe) == 0 {
4054- logger.Debugf("No more subscribers left for evType: %s", busData.evType)
4055- delete(mngr.subscribers, busData.evType)
4056+ mngr.subscribersMutex.Lock()
4057+ for _, curr := range deleteMe {
4058+ mngr.unsubscribe(busData.evType, curr.cb)
4059 }
4060+ leave := mngr.subscribers[busData.evType] == nil
4061+ mngr.subscribersMutex.Unlock()
4062
4063 // No more subscribers at all, we have nothing more left to do here.
4064- if len(mngr.subscribers) == 0 {
4065+ if leave {
4066 logger.Debugf("No subscribers left, leaving")
4067 break
4068 }
4069 }
4070 }
4071- }(syncBus, cancelCallback)
4072-
4073- // This control struct manages the registered watchers, when it gets to len()
4074- // down to zero means all watchers are done and we can signal the other 2 control
4075- // go routines to leave(given we don't have any more job left to process).
4076- control := &watcherQueue{
4077- watchersMap: make(map[string]bool),
4078- }
4079+ }(queue.dataBus, queue.finishCallbackHandler)
4080
4081 // Creates a goroutine for each registered watcher's event and keep handling its
4082 // execution until they give up/finishes their job by returning renew = false.
4083- for _, curr := range mngr.eventTypes() {
4084- control.add(curr.evType)
4085- wg.Add(1)
4086-
4087- go func(bus chan<- eventBusData, watcher Watcher, evType string, cancelContext chan<- bool, cancelCallback chan<- bool) {
4088- var evData interface{}
4089- var err error
4090-
4091- defer wg.Done()
4092-
4093- for renew := true; renew; {
4094- renew, evData, err = watcher.Run(ctx, evType)
4095-
4096- logger.Debugf("Watcher(%s) returned event: %q, should renew?: %t", watcher.ID(), evType, renew)
4097-
4098- if leaving {
4099- break
4100- }
4101+ for _, curr := range mngr.watcherEvents {
4102+ queue.add(curr.evType)
4103+ go func(watcher Watcher, evType string, removed chan bool) {
4104+ mngr.runWatcher(ctx, watcher, evType, removed)
4105+ }(curr.watcher, curr.evType, curr.removed)
4106+ }
4107
4108- bus <- eventBusData{
4109- evType: evType,
4110- data: &EventData{
4111- Data: evData,
4112- Error: err,
4113- },
4114- }
4115- }
4116+ // Controls the completion of the watcher go routines, their removal from the queue
4117+ // and signals to context & callback control go routines about watchers completion.
4118+ wg.Add(1)
4119+ go func() {
4120+ defer wg.Done()
4121
4122- if !leaving && control.del(evType) == 0 {
4123+ for len := queue.length(); len > 0; {
4124+ doneStr := <-queue.watcherDone
4125+ len = queue.del(doneStr)
4126+ delete(mngr.removingWatcherEvents, doneStr)
4127+ if !queue.leaving && len == 0 {
4128 logger.Debugf("All watchers are finished, signaling to leave.")
4129- cancelContext <- true
4130- cancelCallback <- true
4131+ queue.finishContextHandler <- true
4132+ queue.finishCallbackHandler <- true
4133 }
4134- }(syncBus, curr.watcher, curr.evType, cancelContext, cancelCallback)
4135- }
4136+ }
4137+ }()
4138
4139 wg.Wait()
4140+ return nil
4141 }
4142diff --git a/google_guest_agent/events/events_test.go b/google_guest_agent/events/events_test.go
4143index 6fcb2c3..a77bfb9 100644
4144--- a/google_guest_agent/events/events_test.go
4145+++ b/google_guest_agent/events/events_test.go
4146@@ -17,48 +17,26 @@ package events
4147 import (
4148 "context"
4149 "fmt"
4150+ "sync"
4151 "testing"
4152 "time"
4153
4154 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/metadata"
4155 )
4156
4157-func TestConstructor(t *testing.T) {
4158- tests := []struct {
4159- config *Config
4160- success bool
4161- }{
4162- {config: nil, success: true},
4163- {config: &Config{Watchers: []string{metadata.WatcherID}}, success: true},
4164- {config: &Config{Watchers: []string{"foobar"}}, success: false},
4165- }
4166+func TestAddWatcher(t *testing.T) {
4167+ eventManager := newManager()
4168+ metadataWatcher := metadata.New()
4169+ ctx := context.Background()
4170
4171- for i, tt := range tests {
4172- t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
4173- _, err := New(tt.config)
4174- if err != nil && tt.success {
4175- t.Errorf("expected success, got error: %+v", err)
4176- }
4177- })
4178- }
4179-}
4180-
4181-func TestInitWatcers(t *testing.T) {
4182- tests := []struct {
4183- watchers []Watcher
4184- success bool
4185- }{
4186- {watchers: []Watcher{metadata.New()}, success: true},
4187- {watchers: []Watcher{&testWatcher{}}, success: false},
4188+ err := eventManager.AddWatcher(ctx, metadataWatcher)
4189+ if err != nil {
4190+ t.Errorf("expected success, got error: %+v", err)
4191 }
4192
4193- for i, tt := range tests {
4194- t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
4195- err := initWatchers(tt.watchers)
4196- if err != nil && tt.success {
4197- t.Errorf("expected success, got error: %+v", err)
4198- }
4199- })
4200+ err = eventManager.AddWatcher(ctx, metadataWatcher)
4201+ if err == nil {
4202+ t.Errorf("expected error, had success, event manager shouldn't add same watcher twice")
4203 }
4204 }
4205
4206@@ -91,20 +69,16 @@ func TestRun(t *testing.T) {
4207 watcherID := "test-watcher"
4208 maxCount := 10
4209
4210- err := initWatchers([]Watcher{
4211- &testWatcher{
4212- watcherID: watcherID,
4213- maxCount: maxCount,
4214- },
4215- })
4216+ ctx := context.Background()
4217+ eventManager := newManager()
4218
4219- if err != nil {
4220- t.Fatalf("Failed to init/register watcher: %+v", err)
4221- }
4222+ err := eventManager.AddWatcher(ctx, &testWatcher{
4223+ watcherID: watcherID,
4224+ maxCount: maxCount,
4225+ })
4226
4227- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4228 if err != nil {
4229- t.Fatalf("Failed to init event manager: %+v", err)
4230+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4231 }
4232
4233 counter := 0
4234@@ -114,7 +88,9 @@ func TestRun(t *testing.T) {
4235 return true
4236 })
4237
4238- eventManager.Run(context.Background())
4239+ if err := eventManager.Run(ctx); err != nil {
4240+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4241+ }
4242
4243 if counter != maxCount {
4244 t.Errorf("Failed to increment callback counter, expected: %d, got: %d", maxCount, counter)
4245@@ -126,20 +102,16 @@ func TestUnsubscribe(t *testing.T) {
4246 maxCount := 10
4247 unsubscribeAt := 2
4248
4249- err := initWatchers([]Watcher{
4250- &testWatcher{
4251- watcherID: watcherID,
4252- maxCount: maxCount,
4253- },
4254- })
4255+ ctx := context.Background()
4256+ eventManager := newManager()
4257
4258- if err != nil {
4259- t.Fatalf("Failed to init/register watcher: %+v", err)
4260- }
4261+ err := eventManager.AddWatcher(ctx, &testWatcher{
4262+ watcherID: watcherID,
4263+ maxCount: maxCount,
4264+ })
4265
4266- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4267 if err != nil {
4268- t.Fatalf("Failed to init event manager: %+v", err)
4269+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4270 }
4271
4272 counter := 0
4273@@ -151,7 +123,9 @@ func TestUnsubscribe(t *testing.T) {
4274 return true
4275 })
4276
4277- eventManager.Run(context.Background())
4278+ if err := eventManager.Run(ctx); err != nil {
4279+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4280+ }
4281
4282 if counter != unsubscribeAt {
4283 t.Errorf("Failed to unsubscribe callback, expected: %d, got: %d", unsubscribeAt, counter)
4284@@ -162,20 +136,16 @@ func TestCancelBeforeCallbacks(t *testing.T) {
4285 watcherID := "test-watcher"
4286 timeout := (1 * time.Second) / 100
4287
4288- err := initWatchers([]Watcher{
4289- &testCancel{
4290- watcherID: watcherID,
4291- timeout: timeout,
4292- },
4293- })
4294+ ctx, cancel := context.WithCancel(context.Background())
4295+ eventManager := newManager()
4296
4297- if err != nil {
4298- t.Fatalf("Failed to init/register watcher: %+v", err)
4299- }
4300+ err := eventManager.AddWatcher(ctx, &testCancel{
4301+ watcherID: watcherID,
4302+ timeout: timeout,
4303+ })
4304
4305- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4306 if err != nil {
4307- t.Fatalf("Failed to init event manager: %+v", err)
4308+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4309 }
4310
4311 eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4312@@ -183,13 +153,14 @@ func TestCancelBeforeCallbacks(t *testing.T) {
4313 return true
4314 })
4315
4316- ctx, cancel := context.WithCancel(context.Background())
4317 go func() {
4318 time.Sleep(timeout / 2)
4319 cancel()
4320 }()
4321
4322- eventManager.Run(ctx)
4323+ if err := eventManager.Run(ctx); err != nil {
4324+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4325+ }
4326 }
4327
4328 type testCancel struct {
4329@@ -214,33 +185,30 @@ func TestCancelAfterCallbacks(t *testing.T) {
4330 watcherID := "test-watcher"
4331 timeout := (1 * time.Second) / 100
4332
4333- err := initWatchers([]Watcher{
4334- &testCancel{
4335- watcherID: watcherID,
4336- timeout: timeout,
4337- },
4338- })
4339+ ctx, cancel := context.WithCancel(context.Background())
4340+ eventManager := newManager()
4341
4342- if err != nil {
4343- t.Fatalf("Failed to init/register watcher: %+v", err)
4344- }
4345+ err := eventManager.AddWatcher(ctx, &testCancel{
4346+ watcherID: watcherID,
4347+ timeout: timeout,
4348+ })
4349
4350- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4351 if err != nil {
4352- t.Fatalf("Failed to init event manager: %+v", err)
4353+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4354 }
4355
4356 eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4357 return true
4358 })
4359
4360- ctx, cancel := context.WithCancel(context.Background())
4361 go func() {
4362 time.Sleep(timeout * 10)
4363 cancel()
4364 }()
4365
4366- eventManager.Run(ctx)
4367+ if err := eventManager.Run(ctx); err != nil {
4368+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4369+ }
4370 }
4371
4372 type testCancelWatcher struct {
4373@@ -285,20 +253,16 @@ func TestCancelCallbacksAndWatchers(t *testing.T) {
4374 t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
4375 cancelSubscriberAfter := curr.cancelSubscriberAfter
4376
4377- err := initWatchers([]Watcher{
4378- &testCancelWatcher{
4379- watcherID: watcherID,
4380- after: curr.cancelWatcherAfter,
4381- },
4382- })
4383+ ctx := context.Background()
4384+ eventManager := newManager()
4385
4386- if err != nil {
4387- t.Fatalf("Failed to init/register watcher: %+v", err)
4388- }
4389+ err := eventManager.AddWatcher(ctx, &testCancelWatcher{
4390+ watcherID: watcherID,
4391+ after: curr.cancelWatcherAfter,
4392+ })
4393
4394- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4395 if err != nil {
4396- t.Fatalf("Failed to init event manager: %+v", err)
4397+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4398 }
4399
4400 eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4401@@ -310,7 +274,9 @@ func TestCancelCallbacksAndWatchers(t *testing.T) {
4402 return true
4403 })
4404
4405- eventManager.Run(context.Background())
4406+ if err := eventManager.Run(ctx); err != nil {
4407+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4408+ }
4409 })
4410 }
4411 }
4412@@ -320,20 +286,16 @@ func TestMultipleEvents(t *testing.T) {
4413 firstEvent := "multiple-events,first-event"
4414 secondEvent := "multiple-events,second-event"
4415
4416- err := initWatchers([]Watcher{
4417- &testMultipleEvents{
4418- watcherID: watcherID,
4419- eventIDS: []string{firstEvent, secondEvent},
4420- },
4421- })
4422+ ctx := context.Background()
4423+ eventManager := newManager()
4424
4425- if err != nil {
4426- t.Fatalf("Failed to init/register watcher: %+v", err)
4427- }
4428+ err := eventManager.AddWatcher(ctx, &testMultipleEvents{
4429+ watcherID: watcherID,
4430+ eventIDS: []string{firstEvent, secondEvent},
4431+ })
4432
4433- eventManager, err := New(&Config{Watchers: []string{watcherID}})
4434 if err != nil {
4435- t.Fatalf("Failed to init event manager: %+v", err)
4436+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4437 }
4438
4439 var hitFirstEvent bool
4440@@ -348,7 +310,9 @@ func TestMultipleEvents(t *testing.T) {
4441 return false
4442 })
4443
4444- eventManager.Run(context.Background())
4445+ if err := eventManager.Run(ctx); err != nil {
4446+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4447+ }
4448
4449 if !hitFirstEvent || !hitSecondEvent {
4450 t.Errorf("Failed to call back events, first event hit? (%t), second event hit? (%t)", hitFirstEvent, hitSecondEvent)
4451@@ -371,3 +335,304 @@ func (tt *testMultipleEvents) Events() []string {
4452 func (tt *testMultipleEvents) Run(ctx context.Context, evType string) (bool, interface{}, error) {
4453 return false, nil, nil
4454 }
4455+
4456+func TestAddWatcherAfterRun(t *testing.T) {
4457+ firstWatcher := &genericWatcher{
4458+ watcherID: "first-watcher",
4459+ shouldRenew: true,
4460+ }
4461+
4462+ secondWatcher := &genericWatcher{
4463+ watcherID: "second-watcher",
4464+ }
4465+
4466+ ctx := context.Background()
4467+ eventManager := newManager()
4468+
4469+ err := eventManager.AddWatcher(ctx, firstWatcher)
4470+
4471+ if err != nil {
4472+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4473+ }
4474+
4475+ eventManager.Subscribe(firstWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4476+ if err := eventManager.AddWatcher(ctx, secondWatcher); err != nil {
4477+ t.Errorf("Failed to add a second watcher: %+v, expected success", err)
4478+ }
4479+ firstWatcher.shouldRenew = false
4480+ return false
4481+ })
4482+
4483+ var hitSecondEvent bool
4484+ eventManager.Subscribe(secondWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4485+ hitSecondEvent = true
4486+ return false
4487+ })
4488+
4489+ if err := eventManager.Run(ctx); err != nil {
4490+ t.Errorf("Failed to run event managed, expected success, got error: %+v", err)
4491+ }
4492+
4493+ if !hitSecondEvent {
4494+ t.Errorf("Failed registering second watcher, expected hitSecondEvent: false, got: %t", hitSecondEvent)
4495+ }
4496+}
4497+
4498+type genericWatcher struct {
4499+ watcherID string
4500+ shouldRenew bool
4501+ wait time.Duration
4502+}
4503+
4504+func (gw *genericWatcher) eventID() string {
4505+ return gw.watcherID + ",test-event"
4506+}
4507+
4508+func (gw *genericWatcher) ID() string {
4509+ return gw.watcherID
4510+}
4511+
4512+func (gw *genericWatcher) Events() []string {
4513+ return []string{gw.eventID()}
4514+}
4515+
4516+func (gw *genericWatcher) Run(ctx context.Context, evType string) (bool, interface{}, error) {
4517+ if gw.wait > 0 {
4518+ time.Sleep(gw.wait)
4519+ }
4520+ return gw.shouldRenew, nil, nil
4521+}
4522+
4523+func TestAddDefaultWatchers(t *testing.T) {
4524+ firstWatcher := &genericWatcher{
4525+ watcherID: "first-watcher",
4526+ shouldRenew: false,
4527+ }
4528+
4529+ defaultWatchers = []Watcher{
4530+ firstWatcher,
4531+ }
4532+
4533+ ctx := context.Background()
4534+ eventManager := newManager()
4535+
4536+ err := eventManager.AddDefaultWatchers(ctx)
4537+
4538+ if err != nil {
4539+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4540+ }
4541+
4542+ if len(eventManager.watchersMap) == 0 {
4543+ t.Fatalf("Failed to add default watchers, expected: %d, got: %d", len(defaultWatchers),
4544+ len(eventManager.watchersMap))
4545+ }
4546+
4547+ if len(eventManager.watcherEvents) == 0 {
4548+ t.Fatalf("Failed to add default watchers, expected: %d, got: %d", len(defaultWatchers),
4549+ len(eventManager.watcherEvents))
4550+ }
4551+}
4552+
4553+func TestCallingRunTwice(t *testing.T) {
4554+ firstWatcher := &genericWatcher{
4555+ watcherID: "first-watcher",
4556+ shouldRenew: false,
4557+ }
4558+
4559+ defaultWatchers = []Watcher{
4560+ firstWatcher,
4561+ }
4562+
4563+ timeout := (1 * time.Second) / 100
4564+ ctx, cancel := context.WithCancel(context.Background())
4565+ eventManager := newManager()
4566+
4567+ err := eventManager.AddDefaultWatchers(ctx)
4568+
4569+ if err != nil {
4570+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4571+ }
4572+
4573+ var wg sync.WaitGroup
4574+ wg.Add(1)
4575+ go func() {
4576+ defer wg.Done()
4577+ time.Sleep(timeout)
4578+ cancel()
4579+ }()
4580+
4581+ errors := []error{}
4582+ wg.Add(1)
4583+ go func() {
4584+ defer wg.Done()
4585+ if err := eventManager.Run(ctx); err != nil {
4586+ errors = append(errors, err)
4587+ }
4588+ }()
4589+
4590+ wg.Add(1)
4591+ go func() {
4592+ defer wg.Done()
4593+ if err := eventManager.Run(ctx); err != nil {
4594+ errors = append(errors, err)
4595+ }
4596+ }()
4597+
4598+ wg.Wait()
4599+
4600+ if len(errors) == 0 {
4601+ t.Errorf("Executing Run() twice should fail, we got not failure")
4602+ }
4603+
4604+ if len(errors) > 1 {
4605+ t.Errorf("Executing Run() twice should produce a single error, got: %+v", errors)
4606+ }
4607+}
4608+
4609+type testRemoveWatcher struct {
4610+ watcherID string
4611+ timeout time.Duration
4612+}
4613+
4614+func (tc *testRemoveWatcher) ID() string {
4615+ return tc.watcherID
4616+}
4617+
4618+func (tc *testRemoveWatcher) Events() []string {
4619+ return []string{tc.watcherID + ",test-event"}
4620+}
4621+
4622+func (tc *testRemoveWatcher) Run(ctx context.Context, evType string) (bool, interface{}, error) {
4623+ select {
4624+ case <-ctx.Done():
4625+ return false, nil, nil
4626+ case <-time.After(tc.timeout):
4627+ return true, nil, nil
4628+ }
4629+}
4630+
4631+func TestRemoveWatcherBeforeCallbacks(t *testing.T) {
4632+ watcherID := "test-watcher"
4633+ timeout := (1 * time.Second) / 100
4634+
4635+ ctx := context.Background()
4636+ eventManager := newManager()
4637+
4638+ watcher := &testRemoveWatcher{
4639+ watcherID: watcherID,
4640+ timeout: timeout,
4641+ }
4642+
4643+ err := eventManager.AddWatcher(ctx, watcher)
4644+
4645+ if err != nil {
4646+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4647+ }
4648+
4649+ eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4650+ t.Errorf("Expected to have canceled before calling callback")
4651+ return false
4652+ })
4653+
4654+ go func() {
4655+ time.Sleep(timeout / 2)
4656+ if err := eventManager.RemoveWatcher(ctx, watcher); err != nil {
4657+ t.Errorf("Failed to remove watcher: %+v", err)
4658+ }
4659+ }()
4660+
4661+ if err := eventManager.Run(ctx); err != nil {
4662+ t.Errorf("Failed running event manager, expected success, got error: %+v", err)
4663+ }
4664+}
4665+
4666+func TestRemoveWatcherFromCallback(t *testing.T) {
4667+ watcher := &genericWatcher{
4668+ watcherID: "first-watcher",
4669+ shouldRenew: true,
4670+ }
4671+
4672+ ctx := context.Background()
4673+ eventManager := newManager()
4674+
4675+ err := eventManager.AddWatcher(ctx, watcher)
4676+
4677+ if err != nil {
4678+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4679+ }
4680+
4681+ eventManager.Subscribe(watcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4682+ if err := eventManager.RemoveWatcher(ctx, watcher); err != nil {
4683+ t.Fatalf("Failed to remove watcher, it should have succeeded: %+v", err)
4684+ }
4685+ return true
4686+ })
4687+
4688+ if err := eventManager.Run(ctx); err != nil {
4689+ t.Errorf("Failed running event manager, expected success, got error: %+v", err)
4690+ }
4691+}
4692+
4693+func TestCrossWatcherRemovalFromCallback(t *testing.T) {
4694+ firstWatcher := &genericWatcher{
4695+ watcherID: "first-watcher",
4696+ shouldRenew: true,
4697+ }
4698+
4699+ secondWatcher := &genericWatcher{
4700+ watcherID: "second-watcher",
4701+ shouldRenew: true,
4702+ }
4703+
4704+ thirdWatcher := &genericWatcher{
4705+ watcherID: "third-watcher",
4706+ shouldRenew: true,
4707+ wait: (1 * time.Second) / 3,
4708+ }
4709+
4710+ ctx := context.Background()
4711+ eventManager := newManager()
4712+
4713+ watchers := []Watcher{
4714+ firstWatcher,
4715+ secondWatcher,
4716+ thirdWatcher,
4717+ }
4718+
4719+ for _, curr := range watchers {
4720+ err := eventManager.AddWatcher(ctx, curr)
4721+
4722+ if err != nil {
4723+ t.Fatalf("Failed to add watcher to event manager: %+v", err)
4724+ }
4725+ }
4726+
4727+ removed := false
4728+ eventManager.Subscribe(thirdWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool {
4729+ if !removed {
4730+ if err := eventManager.RemoveWatcher(ctx, firstWatcher); err != nil {
4731+ t.Errorf("Failed to remove firstWatcher, it should have succeeded: %+v", err)
4732+ }
4733+ if err := eventManager.RemoveWatcher(ctx, secondWatcher); err != nil {
4734+ t.Errorf("Failed to remove secondWatcher, it should have succeeded: %+v", err)
4735+ }
4736+ removed = true
4737+ return true
4738+ }
4739+
4740+ queueLen := eventManager.queue.length()
4741+ if queueLen != 1 {
4742+ t.Errorf("Failed to remove watcher, expected remaining watchers: 1, got: %d", queueLen)
4743+ }
4744+
4745+ if err := eventManager.RemoveWatcher(ctx, thirdWatcher); err != nil {
4746+ t.Errorf("Failed to remove thirdWatcher, it should have succeeded: %+v", err)
4747+ }
4748+
4749+ return false
4750+ })
4751+
4752+ if err := eventManager.Run(ctx); err != nil {
4753+ t.Errorf("Failed running event manager, expected success, got error: %+v", err)
4754+ }
4755+}
4756diff --git a/google_guest_agent/events/metadata/metadata.go b/google_guest_agent/events/metadata/metadata.go
4757index 9b93cb3..80a9b15 100644
4758--- a/google_guest_agent/events/metadata/metadata.go
4759+++ b/google_guest_agent/events/metadata/metadata.go
4760@@ -19,7 +19,6 @@ import (
4761 "context"
4762 "net"
4763 "net/url"
4764- "time"
4765
4766 "github.com/GoogleCloudPlatform/guest-agent/metadata"
4767 "github.com/GoogleCloudPlatform/guest-logging-go/logger"
4768@@ -32,11 +31,6 @@ const (
4769 LongpollEvent = "metadata-watcher,longpoll"
4770 )
4771
4772-var (
4773- // arbitrarily defined wait duration(keeps behavioral backward compatibility).
4774- retryWaitDuration = 5 * time.Second
4775-)
4776-
4777 // Watcher is the metadata event watcher implementation.
4778 type Watcher struct {
4779 client metadata.MDSClientInterface
4780diff --git a/google_guest_agent/events/metadata/metadata_test.go b/google_guest_agent/events/metadata/metadata_test.go
4781index 1b8aa5a..291f656 100644
4782--- a/google_guest_agent/events/metadata/metadata_test.go
4783+++ b/google_guest_agent/events/metadata/metadata_test.go
4784@@ -39,6 +39,10 @@ func (mds *mdsClient) GetKey(ctx context.Context, key string, headers map[string
4785 return "", fmt.Errorf("GetKey() not yet implemented")
4786 }
4787
4788+func (mds *mdsClient) GetKeyRecursive(ctx context.Context, key string) (string, error) {
4789+ return "", fmt.Errorf("GetKeyRecursive() not yet implemented")
4790+}
4791+
4792 func (mds *mdsClient) Watch(ctx context.Context) (*metadata.Descriptor, error) {
4793 if !mds.disableUnknownFailure {
4794 return nil, errUnknown
4795diff --git a/google_guest_agent/fakes/fake_mds.go b/google_guest_agent/fakes/fake_mds.go
4796index b6e25b1..0e8213a 100644
4797--- a/google_guest_agent/fakes/fake_mds.go
4798+++ b/google_guest_agent/fakes/fake_mds.go
4799@@ -37,6 +37,11 @@ func NewFakeMDSClient() *MDSClient {
4800 return &MDSClient{}
4801 }
4802
4803+// GetKeyRecursive implements fake GetKeyRecursive MDS method.
4804+func (s MDSClient) GetKeyRecursive(ctx context.Context, key string) (string, error) {
4805+ return "", fmt.Errorf("GetKeyRecursive() not yet implemented")
4806+}
4807+
4808 // GetKey implements fake GetKey MDS method.
4809 func (s MDSClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) {
4810 valid := `
4811diff --git a/google_guest_agent/instance_setup.go b/google_guest_agent/instance_setup.go
4812index 38c834a..6327821 100644
4813--- a/google_guest_agent/instance_setup.go
4814+++ b/google_guest_agent/instance_setup.go
4815@@ -25,8 +25,10 @@ import (
4816 "strings"
4817 "time"
4818
4819+ "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/agentcrypto"
4820 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
4821 "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run"
4822+ "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/scheduler"
4823 "github.com/GoogleCloudPlatform/guest-logging-go/logger"
4824 "github.com/go-ini/ini"
4825 )
4826@@ -148,13 +150,14 @@ func agentInit(ctx context.Context) {
4827 return
4828 }
4829
4830- // The below actions require metadata to be set, so if it
4831- // hasn't yet been set, wait on it here. In instances without
4832- // network access, this will become an indefinite wait.
4833- // TODO: split agentInit into needs-network and no-network functions.
4834- for newMetadata == nil {
4835- logger.Debugf("populate first time metadata...")
4836- newMetadata, _ = mdsClient.Get(ctx)
4837+ if newMetadata == nil {
4838+ var err error
4839+ logger.Debugf("populate metadata for the first time...")
4840+ newMetadata, err = mdsClient.Get(ctx)
4841+ if err != nil {
4842+ logger.Errorf("Failed to reach MDS(all retries exhausted): %+v", err)
4843+ os.Exit(1)
4844+ }
4845 }
4846
4847 // Disable overcommit accounting; e2 instances only.
4848@@ -204,6 +207,14 @@ func agentInit(ctx context.Context) {
4849 }
4850 }
4851 }
4852+ // Schedules jobs that need to be started before notifying systemd Agent process has started.
4853+ // We want to generate MDS credentials as early as possible so that any process in the Guest can
4854+ // use them. Processes may depend on the Guest Agent at startup to ensure that the credentials are
4855+ // available for use. By generating the credentials before notifying the systemd, we ensure that
4856+ // they are generated for any process that depends on the Guest Agent.
4857+ if config.MDS.MTLSBootstrappingEnabled {
4858+ scheduler.ScheduleJobs(ctx, []scheduler.Job{agentcrypto.New()}, true)
4859+ }
4860 }
4861
4862 func generateSSHKeys(ctx context.Context) error {
4863diff --git a/google_guest_agent/instance_setup_integ_test.go b/google_guest_agent/instance_setup_integ_test.go
4864deleted file mode 100644
4865index 6c343bf..0000000
4866--- a/google_guest_agent/instance_setup_integ_test.go
4867+++ /dev/null
4868@@ -1,208 +0,0 @@
4869-// Copyright 2021 Google LLC
4870-//
4871-// Licensed under the Apache License, Version 2.0 (the "License");
4872-// you may not use this file except in compliance with the License.
4873-// You may obtain a copy of the License at
4874-//
4875-// http://www.apache.org/licenses/LICENSE-2.0
4876-//
4877-// Unless required by applicable law or agreed to in writing, software
4878-// distributed under the License is distributed on an "AS IS" BASIS,
4879-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4880-// See the License for the specific language governing permissions and
4881-// limitations under the License.
4882-
4883-//go:build integration
4884-// +build integration
4885-
4886-package main
4887-
4888-import (
4889- "context"
4890- "os"
4891- "path/filepath"
4892- "strings"
4893- "testing"
4894-
4895- "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
4896-)
4897-
4898-const (
4899- botoCfg = "/etc/boto.cfg"
4900-)
4901-
4902-func getConfig(t *testing.T) (*cfg.Sections, string) {
4903- t.Helper()
4904-
4905- if err := cfg.Load(nil); err != nil {
4906- t.Fatalf("Failed to load configuration: %+v", err)
4907- }
4908-
4909- config, err := cfg.Get()
4910- if err != nil {
4911- t.Fatalf("Failed to get config: %+v", err)
4912- }
4913-
4914- if config == nil {
4915- t.Fatal("cfg.Get() returned a nil config")
4916- }
4917-
4918- tempDir := filepath.Join(t.TempDir(), "test_instance_setup")
4919- err := os.Mkdir(tempDir, 0755)
4920- if err != nil {
4921- t.Fatalf("Failed to create working dir: %+v", err)
4922- }
4923-
4924- // Configure a non-standard instance ID dir for us to play with.
4925- config.Instance.InstanceIDDir = tempDir
4926- config.InstanceSetup.HostKeyDir = tempDir
4927-
4928- return config, tempDir
4929-}
4930-
4931-// TestInstanceSetupSSHKeys validates SSH keys are generated on first boot and not changed afterward.
4932-func TestInstanceSetupSSHKeys(t *testing.T) {
4933- config, tempDir := getConfig(t)
4934- ctx := context.Background()
4935- agentInit(ctx)
4936-
4937- if _, err := os.Stat(tempDir + "/google_instance_id"); err != nil {
4938- t.Fatal("instance ID File was not created by agentInit")
4939- }
4940-
4941- dir, err := os.Open(tempDir)
4942- if err != nil {
4943- t.Fatal("failed to open working dir")
4944- }
4945- defer dir.Close()
4946-
4947- files, err := dir.Readdirnames(0)
4948- if err != nil {
4949- t.Fatal("failed to read files")
4950- }
4951-
4952- var keys []string
4953- for _, file := range files {
4954- if strings.HasPrefix(file, "ssh_host_") {
4955- keys = append(keys, file)
4956- }
4957- }
4958-
4959- if len(keys) == 0 {
4960- t.Fatal("instance setup didn't create SSH host keys")
4961- }
4962-
4963- // Remove one key file and run again to confirm SSH keys have not
4964- // changed because the instance ID file has not changed.
4965- if err := os.Remove(tempDir + "/" + keys[0]); err != nil {
4966- t.Fatal("failed to remove key file")
4967- }
4968-
4969- agentInit(ctx)
4970-
4971- if _, err := dir.Seek(0, 0); err != nil {
4972- t.Fatal("failed to rewind dir for second check")
4973- }
4974- files2, err := dir.Readdirnames(0)
4975- if err != nil {
4976- t.Fatal("failed to read files")
4977- }
4978-
4979- var keys2 []string
4980- for _, file := range files2 {
4981- if strings.HasPrefix(file, "ssh_host_") {
4982- keys2 = append(keys2, file)
4983- }
4984- if file == keys[0] {
4985- t.Fatalf("agentInit recreated key %s", file)
4986- }
4987- }
4988-
4989- if len(keys) == len(keys2) {
4990- t.Fatal("agentInit recreated SSH host keys")
4991- }
4992-}
4993-
4994-// TestInstanceSetupSSHKeysDisabled validates the config option to disable host
4995-// key generation is respected.
4996-func TestInstanceSetupSSHKeysDisabled(t *testing.T) {
4997- config, tempDir := getConfig(t)
4998-
4999- // Disable SSH host key generation.
5000- config.InstanceSetup.SetHostKeys = false
The diff has been truncated for viewing.

Subscribers

People subscribed via source and target branches