Merge ~kajiya/+git/google-guest-agent:kajiya/upstream into ~ubuntu-core-dev/+git/google-guest-agent:upstream
- Git
- lp:~kajiya/+git/google-guest-agent
- kajiya/upstream
- Merge into upstream
Proposed by
Chloé Smith
Status: | Merged | ||||
---|---|---|---|---|---|
Merged at revision: | 382d1f5333b0eb73bae0fe74efc5534c295219f6 | ||||
Proposed branch: | ~kajiya/+git/google-guest-agent:kajiya/upstream | ||||
Merge into: | ~ubuntu-core-dev/+git/google-guest-agent:upstream | ||||
Diff against target: |
10201 lines (+6466/-1274) 70 files modified
.gitignore (+11/-1) OWNERS (+2/-2) THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE (+2/-1) THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE (+2/-2) THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE (+22/-0) THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE (+27/-0) THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE (+1/-1) dev/null (+0/-93) gce_workload_cert_refresh/main.go (+169/-152) gce_workload_cert_refresh/main_test.go (+495/-0) go.mod (+5/-2) go.sum (+9/-0) google_authorized_keys/main.go (+14/-49) google_authorized_keys/main_test.go (+103/-44) google_guest_agent/addresses.go (+21/-245) google_guest_agent/agentcrypto/mtls_mds.go (+1/-1) google_guest_agent/agentcrypto/mtls_mds_linux.go (+35/-14) google_guest_agent/agentcrypto/mtls_mds_linux_test.go (+21/-4) google_guest_agent/agentcrypto/mtls_mds_windows.go (+25/-0) google_guest_agent/cfg/cfg.go (+30/-8) google_guest_agent/cfg/cfg_test.go (+3/-1) google_guest_agent/command/Readme.md (+24/-0) google_guest_agent/command/command.go (+146/-0) google_guest_agent/command/command_linux.go (+140/-0) google_guest_agent/command/command_monitor.go (+228/-0) google_guest_agent/command/command_test.go (+209/-0) google_guest_agent/command/command_windows.go (+104/-0) google_guest_agent/command/command_windows_test.go (+73/-0) google_guest_agent/diagnostics.go (+2/-1) google_guest_agent/events/events.go (+308/-138) google_guest_agent/events/events_test.go (+372/-107) google_guest_agent/events/metadata/metadata.go (+0/-6) google_guest_agent/events/metadata/metadata_test.go (+4/-0) google_guest_agent/fakes/fake_mds.go (+5/-0) google_guest_agent/instance_setup.go (+18/-7) google_guest_agent/main.go (+22/-20) google_guest_agent/network/manager/common.go (+90/-0) google_guest_agent/network/manager/dhclient_linux.go (+292/-0) google_guest_agent/network/manager/dhclient_linux_test.go (+511/-0) google_guest_agent/network/manager/manager.go (+362/-0) 
google_guest_agent/network/manager/manager_test.go (+353/-0) google_guest_agent/network/manager/systemd_networkd_linux.go (+295/-0) google_guest_agent/network/manager/systemd_networkd_linux_test.go (+549/-0) google_guest_agent/non_windows_accounts.go (+3/-1) google_guest_agent/oslogin.go (+91/-0) google_guest_agent/ps/ps.go (+37/-0) google_guest_agent/ps/ps_linux.go (+121/-0) google_guest_agent/ps/ps_linux_test.go (+157/-0) google_guest_agent/ps/ps_windows.go (+21/-0) google_guest_agent/run/run.go (+51/-4) google_guest_agent/scheduler/logger.go (+2/-2) google_guest_agent/scheduler/scheduler.go (+1/-0) google_guest_agent/scheduler/scheduler_test.go (+8/-8) google_guest_agent/snapshot_listener.go (+7/-5) google_guest_agent/sshca/sshca.go (+9/-11) google_guest_agent/windows_accounts.go (+2/-1) google_guest_agent/windows_accounts_test.go (+49/-42) google_metadata_script_runner/main.go (+63/-78) google_metadata_script_runner/main_test.go (+78/-11) metadata/metadata.go (+16/-0) metadata/metadata_test.go (+30/-0) packaging/genlicense.sh (+24/-0) retry/retry.go (+107/-0) retry/retry_test.go (+193/-0) utils/file.go (+80/-0) utils/file_test.go (+87/-0) utils/serial_port_logger.go (+35/-0) utils/ssh.go (+19/-108) utils/ssh_test.go (+32/-104) utils/test.go (+38/-0) |
||||
Related bugs: |
|
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Utkarsh Gupta | Approve | ||
Review via email: mp+460883@code.launchpad.net |
Commit message
New upstream version 20240213.00
Description of the change
To post a comment you must log in.
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | diff --git a/.gitignore b/.gitignore |
2 | index 2ceee21..85eb02e 100644 |
3 | --- a/.gitignore |
4 | +++ b/.gitignore |
5 | @@ -1,5 +1,15 @@ |
6 | -# ignore all built binaries |
7 | +# Ignore all built binaries. |
8 | **/gce_workload_cert_refresh |
9 | +**/gce_workload_cert_refresh.exe |
10 | **/google_authorized_keys |
11 | +**/google_authorized_keys.exe |
12 | **/google_guest_agent |
13 | +**/google_guest_agent.exe |
14 | **/google_metadata_script_runner |
15 | +**/google_metadata_script_runner.exe |
16 | + |
17 | +# Don't ignore new content to directories. |
18 | +!**/gce_workload_cert_refresh/ |
19 | +!**/google_authorized_keys/ |
20 | +!**/google_guest_agent/ |
21 | +!**/google_metadata_script_runner/ |
22 | \ No newline at end of file |
23 | diff --git a/OWNERS b/OWNERS |
24 | index edde10e..59fca61 100644 |
25 | --- a/OWNERS |
26 | +++ b/OWNERS |
27 | @@ -3,14 +3,14 @@ |
28 | |
29 | approvers: |
30 | - a-crate |
31 | + - ajorg |
32 | - bkatyl |
33 | - chaitanyakulkarni28 |
34 | - dorileo |
35 | - drewhli |
36 | - elicriffield |
37 | + - gaughen |
38 | - jjerger |
39 | - karnvadaliya |
40 | - koln67 |
41 | - - quintonamore |
42 | - - vorakl |
43 | - zmarano |
44 | diff --git a/THIRD_PARTY_LICENSES/cloud.google.com/go/LICENSE b/THIRD_PARTY_LICENSES/cloud.google.com/go/iam/LICENSE |
45 | similarity index 100% |
46 | rename from THIRD_PARTY_LICENSES/cloud.google.com/go/LICENSE |
47 | rename to THIRD_PARTY_LICENSES/cloud.google.com/go/iam/LICENSE |
48 | diff --git a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE b/THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE |
49 | similarity index 100% |
50 | rename from THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE |
51 | rename to THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE |
52 | index 65fb971..d645695 100644 |
53 | --- a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/LICENSE |
54 | +++ b/THIRD_PARTY_LICENSES/cloud.google.com/go/internal/LICENSE |
55 | @@ -1,3 +1,4 @@ |
56 | + |
57 | Apache License |
58 | Version 2.0, January 2004 |
59 | http://www.apache.org/licenses/ |
60 | @@ -186,7 +187,7 @@ |
61 | same "printed page" as the copyright notice for easier |
62 | identification within third-party archives. |
63 | |
64 | - Copyright 2020 Google Inc. |
65 | + Copyright [yyyy] [name of copyright owner] |
66 | |
67 | Licensed under the Apache License, Version 2.0 (the "License"); |
68 | you may not use this file except in compliance with the License. |
69 | diff --git a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE b/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE |
70 | index 65fb971..f49a4e1 100644 |
71 | --- a/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE |
72 | +++ b/THIRD_PARTY_LICENSES/github.com/GoogleCloudPlatform/guest-agent/LICENSE |
73 | @@ -186,7 +186,7 @@ |
74 | same "printed page" as the copyright notice for easier |
75 | identification within third-party archives. |
76 | |
77 | - Copyright 2020 Google Inc. |
78 | + Copyright [yyyy] [name of copyright owner] |
79 | |
80 | Licensed under the Apache License, Version 2.0 (the "License"); |
81 | you may not use this file except in compliance with the License. |
82 | @@ -198,4 +198,4 @@ |
83 | distributed under the License is distributed on an "AS IS" BASIS, |
84 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
85 | See the License for the specific language governing permissions and |
86 | - limitations under the License. |
87 | + limitations under the License. |
88 | \ No newline at end of file |
89 | diff --git a/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE b/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE |
90 | new file mode 100644 |
91 | index 0000000..b8b569d |
92 | --- /dev/null |
93 | +++ b/THIRD_PARTY_LICENSES/github.com/Microsoft/go-winio/LICENSE |
94 | @@ -0,0 +1,22 @@ |
95 | +The MIT License (MIT) |
96 | + |
97 | +Copyright (c) 2015 Microsoft |
98 | + |
99 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
100 | +of this software and associated documentation files (the "Software"), to deal |
101 | +in the Software without restriction, including without limitation the rights |
102 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
103 | +copies of the Software, and to permit persons to whom the Software is |
104 | +furnished to do so, subject to the following conditions: |
105 | + |
106 | +The above copyright notice and this permission notice shall be included in all |
107 | +copies or substantial portions of the Software. |
108 | + |
109 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
110 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
111 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
112 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
113 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
114 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
115 | +SOFTWARE. |
116 | + |
117 | diff --git a/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE b/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE |
118 | new file mode 100644 |
119 | index 0000000..e4a47e1 |
120 | --- /dev/null |
121 | +++ b/THIRD_PARTY_LICENSES/golang.org/x/xerrors/LICENSE |
122 | @@ -0,0 +1,27 @@ |
123 | +Copyright (c) 2019 The Go Authors. All rights reserved. |
124 | + |
125 | +Redistribution and use in source and binary forms, with or without |
126 | +modification, are permitted provided that the following conditions are |
127 | +met: |
128 | + |
129 | + * Redistributions of source code must retain the above copyright |
130 | +notice, this list of conditions and the following disclaimer. |
131 | + * Redistributions in binary form must reproduce the above |
132 | +copyright notice, this list of conditions and the following disclaimer |
133 | +in the documentation and/or other materials provided with the |
134 | +distribution. |
135 | + * Neither the name of Google Inc. nor the names of its |
136 | +contributors may be used to endorse or promote products derived from |
137 | +this software without specific prior written permission. |
138 | + |
139 | +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
140 | +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
141 | +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
142 | +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
143 | +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
144 | +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
145 | +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
146 | +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
147 | +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
148 | +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
149 | +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
150 | diff --git a/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE b/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE |
151 | deleted file mode 100644 |
152 | index d645695..0000000 |
153 | --- a/THIRD_PARTY_LICENSES/google.golang.org/genproto/LICENSE |
154 | +++ /dev/null |
155 | @@ -1,202 +0,0 @@ |
156 | - |
157 | - Apache License |
158 | - Version 2.0, January 2004 |
159 | - http://www.apache.org/licenses/ |
160 | - |
161 | - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
162 | - |
163 | - 1. Definitions. |
164 | - |
165 | - "License" shall mean the terms and conditions for use, reproduction, |
166 | - and distribution as defined by Sections 1 through 9 of this document. |
167 | - |
168 | - "Licensor" shall mean the copyright owner or entity authorized by |
169 | - the copyright owner that is granting the License. |
170 | - |
171 | - "Legal Entity" shall mean the union of the acting entity and all |
172 | - other entities that control, are controlled by, or are under common |
173 | - control with that entity. For the purposes of this definition, |
174 | - "control" means (i) the power, direct or indirect, to cause the |
175 | - direction or management of such entity, whether by contract or |
176 | - otherwise, or (ii) ownership of fifty percent (50%) or more of the |
177 | - outstanding shares, or (iii) beneficial ownership of such entity. |
178 | - |
179 | - "You" (or "Your") shall mean an individual or Legal Entity |
180 | - exercising permissions granted by this License. |
181 | - |
182 | - "Source" form shall mean the preferred form for making modifications, |
183 | - including but not limited to software source code, documentation |
184 | - source, and configuration files. |
185 | - |
186 | - "Object" form shall mean any form resulting from mechanical |
187 | - transformation or translation of a Source form, including but |
188 | - not limited to compiled object code, generated documentation, |
189 | - and conversions to other media types. |
190 | - |
191 | - "Work" shall mean the work of authorship, whether in Source or |
192 | - Object form, made available under the License, as indicated by a |
193 | - copyright notice that is included in or attached to the work |
194 | - (an example is provided in the Appendix below). |
195 | - |
196 | - "Derivative Works" shall mean any work, whether in Source or Object |
197 | - form, that is based on (or derived from) the Work and for which the |
198 | - editorial revisions, annotations, elaborations, or other modifications |
199 | - represent, as a whole, an original work of authorship. For the purposes |
200 | - of this License, Derivative Works shall not include works that remain |
201 | - separable from, or merely link (or bind by name) to the interfaces of, |
202 | - the Work and Derivative Works thereof. |
203 | - |
204 | - "Contribution" shall mean any work of authorship, including |
205 | - the original version of the Work and any modifications or additions |
206 | - to that Work or Derivative Works thereof, that is intentionally |
207 | - submitted to Licensor for inclusion in the Work by the copyright owner |
208 | - or by an individual or Legal Entity authorized to submit on behalf of |
209 | - the copyright owner. For the purposes of this definition, "submitted" |
210 | - means any form of electronic, verbal, or written communication sent |
211 | - to the Licensor or its representatives, including but not limited to |
212 | - communication on electronic mailing lists, source code control systems, |
213 | - and issue tracking systems that are managed by, or on behalf of, the |
214 | - Licensor for the purpose of discussing and improving the Work, but |
215 | - excluding communication that is conspicuously marked or otherwise |
216 | - designated in writing by the copyright owner as "Not a Contribution." |
217 | - |
218 | - "Contributor" shall mean Licensor and any individual or Legal Entity |
219 | - on behalf of whom a Contribution has been received by Licensor and |
220 | - subsequently incorporated within the Work. |
221 | - |
222 | - 2. Grant of Copyright License. Subject to the terms and conditions of |
223 | - this License, each Contributor hereby grants to You a perpetual, |
224 | - worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
225 | - copyright license to reproduce, prepare Derivative Works of, |
226 | - publicly display, publicly perform, sublicense, and distribute the |
227 | - Work and such Derivative Works in Source or Object form. |
228 | - |
229 | - 3. Grant of Patent License. Subject to the terms and conditions of |
230 | - this License, each Contributor hereby grants to You a perpetual, |
231 | - worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
232 | - (except as stated in this section) patent license to make, have made, |
233 | - use, offer to sell, sell, import, and otherwise transfer the Work, |
234 | - where such license applies only to those patent claims licensable |
235 | - by such Contributor that are necessarily infringed by their |
236 | - Contribution(s) alone or by combination of their Contribution(s) |
237 | - with the Work to which such Contribution(s) was submitted. If You |
238 | - institute patent litigation against any entity (including a |
239 | - cross-claim or counterclaim in a lawsuit) alleging that the Work |
240 | - or a Contribution incorporated within the Work constitutes direct |
241 | - or contributory patent infringement, then any patent licenses |
242 | - granted to You under this License for that Work shall terminate |
243 | - as of the date such litigation is filed. |
244 | - |
245 | - 4. Redistribution. You may reproduce and distribute copies of the |
246 | - Work or Derivative Works thereof in any medium, with or without |
247 | - modifications, and in Source or Object form, provided that You |
248 | - meet the following conditions: |
249 | - |
250 | - (a) You must give any other recipients of the Work or |
251 | - Derivative Works a copy of this License; and |
252 | - |
253 | - (b) You must cause any modified files to carry prominent notices |
254 | - stating that You changed the files; and |
255 | - |
256 | - (c) You must retain, in the Source form of any Derivative Works |
257 | - that You distribute, all copyright, patent, trademark, and |
258 | - attribution notices from the Source form of the Work, |
259 | - excluding those notices that do not pertain to any part of |
260 | - the Derivative Works; and |
261 | - |
262 | - (d) If the Work includes a "NOTICE" text file as part of its |
263 | - distribution, then any Derivative Works that You distribute must |
264 | - include a readable copy of the attribution notices contained |
265 | - within such NOTICE file, excluding those notices that do not |
266 | - pertain to any part of the Derivative Works, in at least one |
267 | - of the following places: within a NOTICE text file distributed |
268 | - as part of the Derivative Works; within the Source form or |
269 | - documentation, if provided along with the Derivative Works; or, |
270 | - within a display generated by the Derivative Works, if and |
271 | - wherever such third-party notices normally appear. The contents |
272 | - of the NOTICE file are for informational purposes only and |
273 | - do not modify the License. You may add Your own attribution |
274 | - notices within Derivative Works that You distribute, alongside |
275 | - or as an addendum to the NOTICE text from the Work, provided |
276 | - that such additional attribution notices cannot be construed |
277 | - as modifying the License. |
278 | - |
279 | - You may add Your own copyright statement to Your modifications and |
280 | - may provide additional or different license terms and conditions |
281 | - for use, reproduction, or distribution of Your modifications, or |
282 | - for any such Derivative Works as a whole, provided Your use, |
283 | - reproduction, and distribution of the Work otherwise complies with |
284 | - the conditions stated in this License. |
285 | - |
286 | - 5. Submission of Contributions. Unless You explicitly state otherwise, |
287 | - any Contribution intentionally submitted for inclusion in the Work |
288 | - by You to the Licensor shall be under the terms and conditions of |
289 | - this License, without any additional terms or conditions. |
290 | - Notwithstanding the above, nothing herein shall supersede or modify |
291 | - the terms of any separate license agreement you may have executed |
292 | - with Licensor regarding such Contributions. |
293 | - |
294 | - 6. Trademarks. This License does not grant permission to use the trade |
295 | - names, trademarks, service marks, or product names of the Licensor, |
296 | - except as required for reasonable and customary use in describing the |
297 | - origin of the Work and reproducing the content of the NOTICE file. |
298 | - |
299 | - 7. Disclaimer of Warranty. Unless required by applicable law or |
300 | - agreed to in writing, Licensor provides the Work (and each |
301 | - Contributor provides its Contributions) on an "AS IS" BASIS, |
302 | - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
303 | - implied, including, without limitation, any warranties or conditions |
304 | - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
305 | - PARTICULAR PURPOSE. You are solely responsible for determining the |
306 | - appropriateness of using or redistributing the Work and assume any |
307 | - risks associated with Your exercise of permissions under this License. |
308 | - |
309 | - 8. Limitation of Liability. In no event and under no legal theory, |
310 | - whether in tort (including negligence), contract, or otherwise, |
311 | - unless required by applicable law (such as deliberate and grossly |
312 | - negligent acts) or agreed to in writing, shall any Contributor be |
313 | - liable to You for damages, including any direct, indirect, special, |
314 | - incidental, or consequential damages of any character arising as a |
315 | - result of this License or out of the use or inability to use the |
316 | - Work (including but not limited to damages for loss of goodwill, |
317 | - work stoppage, computer failure or malfunction, or any and all |
318 | - other commercial damages or losses), even if such Contributor |
319 | - has been advised of the possibility of such damages. |
320 | - |
321 | - 9. Accepting Warranty or Additional Liability. While redistributing |
322 | - the Work or Derivative Works thereof, You may choose to offer, |
323 | - and charge a fee for, acceptance of support, warranty, indemnity, |
324 | - or other liability obligations and/or rights consistent with this |
325 | - License. However, in accepting such obligations, You may act only |
326 | - on Your own behalf and on Your sole responsibility, not on behalf |
327 | - of any other Contributor, and only if You agree to indemnify, |
328 | - defend, and hold each Contributor harmless for any liability |
329 | - incurred by, or claims asserted against, such Contributor by reason |
330 | - of your accepting any such warranty or additional liability. |
331 | - |
332 | - END OF TERMS AND CONDITIONS |
333 | - |
334 | - APPENDIX: How to apply the Apache License to your work. |
335 | - |
336 | - To apply the Apache License to your work, attach the following |
337 | - boilerplate notice, with the fields enclosed by brackets "[]" |
338 | - replaced with your own identifying information. (Don't include |
339 | - the brackets!) The text should be enclosed in the appropriate |
340 | - comment syntax for the file format. We also recommend that a |
341 | - file or class name and description of purpose be included on the |
342 | - same "printed page" as the copyright notice for easier |
343 | - identification within third-party archives. |
344 | - |
345 | - Copyright [yyyy] [name of copyright owner] |
346 | - |
347 | - Licensed under the Apache License, Version 2.0 (the "License"); |
348 | - you may not use this file except in compliance with the License. |
349 | - You may obtain a copy of the License at |
350 | - |
351 | - http://www.apache.org/licenses/LICENSE-2.0 |
352 | - |
353 | - Unless required by applicable law or agreed to in writing, software |
354 | - distributed under the License is distributed on an "AS IS" BASIS, |
355 | - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
356 | - See the License for the specific language governing permissions and |
357 | - limitations under the License. |
358 | diff --git a/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE b/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE |
359 | index 6ac6b11..bcecd3d 100644 |
360 | --- a/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE |
361 | +++ b/THIRD_PARTY_LICENSES/software.sslmate.com/src/go-pkcs12/LICENSE |
362 | @@ -25,4 +25,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
363 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
364 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
365 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
366 | -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
367 | \ No newline at end of file |
368 | +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
369 | diff --git a/gce_workload_cert_refresh/main.go b/gce_workload_cert_refresh/main.go |
370 | index adb83aa..c3c0b7d 100644 |
371 | --- a/gce_workload_cert_refresh/main.go |
372 | +++ b/gce_workload_cert_refresh/main.go |
373 | @@ -16,11 +16,15 @@ |
374 | package main |
375 | |
376 | import ( |
377 | + "bytes" |
378 | "context" |
379 | "encoding/json" |
380 | "fmt" |
381 | "io" |
382 | "os" |
383 | + "path" |
384 | + "path/filepath" |
385 | + "strings" |
386 | "time" |
387 | |
388 | "github.com/GoogleCloudPlatform/guest-agent/metadata" |
389 | @@ -28,15 +32,28 @@ import ( |
390 | ) |
391 | |
392 | const ( |
393 | - contentDirPrefix = "/run/secrets/workload-spiffe-contents" |
394 | + // trustAnchorsKey endpoint contains a set of trusted certificates for peer X.509 certificate chain validation. |
395 | + trustAnchorsKey = "instance/gce-workload-certificates/trust-anchors" |
396 | + // workloadIdentitiesKey endpoint contains identities managed by the GCE control plane. This contains the X.509 certificate and the private key for the VM's trust domain. |
397 | + workloadIdentitiesKey = "instance/gce-workload-certificates/workload-identities" |
398 | + // configStatusKey contains status and any errors in the config values provided via the VM metadata. |
399 | + configStatusKey = "instance/gce-workload-certificates/config-status" |
400 | + // enableWorkloadCertsKey is set to true as custom metadata to enable automatic provisioning of credentials. |
401 | + enableWorkloadCertsKey = "instance/attributes/enable-workload-certificate" |
402 | + // contentDirPrefix is used as prefix to create certificate directories on refresh as contentDirPrefix-<time>. |
403 | + contentDirPrefix = "/run/secrets/workload-spiffe-contents" |
404 | + // tempSymlinkPrefix is used as prefix to create temporary symlinks on refresh as tempSymlinkPrefix-<time> to content directories. |
405 | tempSymlinkPrefix = "/run/secrets/workload-spiffe-symlink" |
406 | - symlink = "/run/secrets/workload-spiffe-credentials" |
407 | - programName = "gce_workload_certs_refresh" |
408 | + // symlink points to the directory with current GCE workload certificates and is always expected to be present. |
409 | + symlink = "/run/secrets/workload-spiffe-credentials" |
410 | ) |
411 | |
412 | var ( |
413 | // mdsClient is the client used to query Metadata server. |
414 | - mdsClient *metadata.Client |
415 | + mdsClient metadata.MDSClientInterface |
416 | + programName = path.Base(os.Args[0]) |
417 | + // timeNow returns current time, defining as variable allows the time to be stubbed during testing. |
417 | + // timeNow returns the current time; defining it as a variable allows the time to be stubbed during testing. |
419 | ) |
420 | |
421 | func init() { |
422 | @@ -48,138 +65,86 @@ func logFormat(e logger.LogEntry) string { |
423 | return fmt.Sprintf("%s: %s", now, e.Message) |
424 | } |
425 | |
426 | +// isEnabled returns true only if enable-workload-certificate metadata attribute is present and set to true. |
427 | +func isEnabled(ctx context.Context) bool { |
428 | + resp, err := getMetadata(ctx, enableWorkloadCertsKey) |
429 | + if err != nil { |
430 | + logger.Debugf("Failed to get %q from MDS with error: %v", enableWorkloadCertsKey, err) |
431 | + return false |
432 | + } |
433 | + |
434 | + return bytes.EqualFold(resp, []byte("true")) |
435 | +} |
436 | + |
437 | func getMetadata(ctx context.Context, key string) ([]byte, error) { |
438 | // GCE Workload Certificate endpoints return 412 Precondition failed if the VM was |
439 | // never configured with valid config values at least once. Without valid config |
440 | // values GCE cannot provision the workload certificates. |
441 | resp, err := mdsClient.GetKey(ctx, key, nil) |
442 | if err != nil { |
443 | - return []byte{}, fmt.Errorf("failed to GET %q from MDS with error: %w", key, err) |
444 | + return nil, fmt.Errorf("failed to GET %q from MDS with error: %w", key, err) |
445 | } |
446 | return []byte(resp), nil |
447 | } |
448 | |
449 | /* |
450 | metadata key instance/gce-workload-certificates/workload-identities |
451 | - |
452 | - { |
453 | - "status": "OK", |
454 | - "workloadCredentials": { |
455 | - "PROJECT_ID.svc.id.goog": { |
456 | - "metadata": { |
457 | - "workload_creds_dir_path": "/var/run/secrets/workload-spiffe-credentials" |
458 | - }, |
459 | - "certificatePem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----", |
460 | - "privateKeyPem": "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----" |
461 | - } |
462 | - } |
463 | - } |
464 | -*/ |
465 | - |
466 | -// WorkloadIdentities represents Workload Identities in metadata. |
467 | -type WorkloadIdentities struct { |
468 | - Status string |
469 | - WorkloadCredentials map[string]WorkloadCredential |
470 | -} |
471 | - |
472 | -// UnmarshalJSON is a custom JSON unmarshaller for WorkloadIdentities. |
473 | -func (wi *WorkloadIdentities) UnmarshalJSON(b []byte) error { |
474 | - tmp := map[string]json.RawMessage{} |
475 | - err := json.Unmarshal(b, &tmp) |
476 | - if err != nil { |
477 | - return err |
478 | - } |
479 | - |
480 | - if err := json.Unmarshal(tmp["status"], &wi.Status); err != nil { |
481 | - return err |
482 | - } |
483 | - |
484 | - wi.WorkloadCredentials = map[string]WorkloadCredential{} |
485 | - wcs := map[string]json.RawMessage{} |
486 | - if err := json.Unmarshal(tmp["workloadCredentials"], &wcs); err != nil { |
487 | - return err |
488 | - } |
489 | - |
490 | - for domain, value := range wcs { |
491 | - wc := WorkloadCredential{} |
492 | - err := json.Unmarshal(value, &wc) |
493 | - if err != nil { |
494 | - return err |
495 | +MANAGED_WORKLOAD_IDENTITY_SPIFFE is of the format: |
496 | +spiffe://POOL_ID.global.PROJECT_NUMBER.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID |
497 | + |
498 | +{ |
499 | + "status": "OK", // Status of the response, |
500 | + "workloadCredentials": { // Credentials for the VM's trust domains |
501 | + "MANAGED_WORKLOAD_IDENTITY_SPIFFE": { |
502 | + "certificatePem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----", |
503 | + "privateKeyPem": "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----" |
504 | } |
505 | - wi.WorkloadCredentials[domain] = wc |
506 | } |
507 | - |
508 | - return nil |
509 | } |
510 | +*/ |
511 | |
512 | // WorkloadCredential represents Workload Credentials in metadata. |
513 | type WorkloadCredential struct { |
514 | - Metadata Metadata |
515 | - CertificatePem string |
516 | - PrivateKeyPem string |
517 | + CertificatePem string `json:"certificatePem"` |
518 | + PrivateKeyPem string `json:"privateKeyPem"` |
519 | } |
520 | |
521 | -/* |
522 | -metadata key instance/gce-workload-certificates/root-certs |
523 | - |
524 | - { |
525 | - "status": "OK", |
526 | - "rootCertificates": { |
527 | - "PROJECT.svc.id.goog": { |
528 | - "metadata": { |
529 | - "workload_creds_dir_path": "/var/run/secrets/workload-spiffe-credentials" |
530 | - }, |
531 | - "rootCertificatesPem": "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----" |
532 | - } |
533 | - } |
534 | - } |
535 | -*/ |
536 | - |
537 | -// WorkloadTrustedRootCerts represents Workload Trusted Root Certs in metadata. |
538 | -type WorkloadTrustedRootCerts struct { |
539 | - Status string |
540 | - RootCertificates map[string]RootCertificate |
541 | +// WorkloadIdentities represents Workload Identities in metadata. |
542 | +type WorkloadIdentities struct { |
543 | + Status string `json:"status"` |
544 | + WorkloadCredentials map[string]WorkloadCredential `json:"workloadCredentials"` |
545 | } |
546 | |
547 | -// UnmarshalJSON is a custom JSON unmarshaller for WorkloadTrustedRootCerts |
548 | -func (wtrc *WorkloadTrustedRootCerts) UnmarshalJSON(b []byte) error { |
549 | - tmp := map[string]json.RawMessage{} |
550 | - err := json.Unmarshal(b, &tmp) |
551 | - if err != nil { |
552 | - return err |
553 | - } |
554 | - |
555 | - if err := json.Unmarshal(tmp["status"], &wtrc.Status); err != nil { |
556 | - return err |
557 | - } |
558 | - |
559 | - wtrc.RootCertificates = map[string]RootCertificate{} |
560 | - rcs := map[string]json.RawMessage{} |
561 | - if err := json.Unmarshal(tmp["rootCertificates"], &rcs); err != nil { |
562 | - return err |
563 | - } |
564 | - |
565 | - for domain, value := range rcs { |
566 | - rc := RootCertificate{} |
567 | - err := json.Unmarshal(value, &rc) |
568 | - if err != nil { |
569 | - return err |
570 | - } |
571 | - wtrc.RootCertificates[domain] = rc |
572 | - } |
573 | +/* |
574 | +metadata key instance/gce-workload-certificates/trust-anchors |
575 | + |
576 | +{ |
577 | + "status": "<status string>" // Status of the response, |
578 | + "trustAnchors": { // Trust bundle for the VM's trust domains |
579 | + "PEER_SPIFFE_TRUST_DOMAIN_1": { |
580 | + "trustAnchorsPem" : "<Trust bundle containing the X.509 roots certificates>", |
581 | + }, |
582 | + "PEER_SPIFFE_TRUST_DOMAIN_2": { |
583 | + "trustAnchorsPem" : "<Trust bundle containing the X.509 roots certificates>", |
584 | + } |
585 | + } |
586 | +} |
587 | +*/ |
588 | |
589 | - return nil |
590 | +// TrustAnchor represents one or more certificates in an arbitrary order in the metadata. |
591 | +type TrustAnchor struct { |
592 | + TrustAnchorsPem string `json:"trustAnchorsPem"` |
593 | } |
594 | |
595 | -// RootCertificate represents a Root Certificate in metadata |
596 | -type RootCertificate struct { |
597 | - Metadata Metadata |
598 | - RootCertificatesPem string |
599 | +// WorkloadTrustedAnchors represents Workload Trusted Root Certs in metadata. |
600 | +type WorkloadTrustedAnchors struct { |
601 | + Status string `json:"status"` |
602 | + TrustAnchors map[string]TrustAnchor `json:"trustAnchors"` |
603 | } |
604 | |
605 | -// Metadata represents Metadata in metadata |
606 | -type Metadata struct { |
607 | - WorkloadCredsDirPath string |
608 | +// outputOpts is a struct for output directory name and symlink templates. |
609 | +type outputOpts struct { |
610 | + contentDirPrefix, tempSymlinkPrefix, symlink string |
611 | } |
612 | |
613 | func main() { |
614 | @@ -193,37 +158,102 @@ func main() { |
615 | } |
616 | |
617 | opts.Writers = []io.Writer{os.Stderr} |
618 | - logger.Init(ctx, opts) |
619 | + |
620 | + if err := logger.Init(ctx, opts); err != nil { |
621 | + fmt.Printf("Error initializing logger: %v", err) |
622 | + os.Exit(1) |
623 | + } |
624 | + |
625 | defer logger.Infof("Done") |
626 | |
627 | - // TODO: prune old dirs |
628 | - if err := refreshCreds(ctx); err != nil { |
629 | + if !isEnabled(ctx) { |
630 | + logger.Debugf("GCE Workload Certificate refresh is not enabled, skipping cert generation.") |
631 | + return |
632 | + } |
633 | + |
634 | + out := outputOpts{contentDirPrefix, tempSymlinkPrefix, symlink} |
635 | + if err := refreshCreds(ctx, out); err != nil { |
636 | logger.Fatalf("Error refreshCreds: %v", err.Error()) |
637 | } |
638 | |
639 | } |
640 | |
641 | -func refreshCreds(ctx context.Context) error { |
642 | - project, err := getMetadata(ctx, "project/project-id") |
643 | +// findDomain finds the anchor matching with the domain from spiffeID. |
644 | +// spiffeID is of the form - |
645 | +// spiffe://POOL_ID.global.PROJECT_NUMBER.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID |
646 | +// where domain is POOL_ID.global.PROJECT_NUMBER.workload.id.goog |
647 | +// anchors is a map of various domains and their corresponding trust PEMs. |
648 | +// However, if the anchor map contains a single entry, it returns that entry without any check. |
649 | +func findDomain(anchors map[string]TrustAnchor, spiffeID string) (string, error) { |
650 | + c := len(anchors) |
651 | + for k := range anchors { |
652 | + if c == 1 { |
653 | + return k, nil |
654 | + } |
655 | + if strings.Contains(spiffeID, k) { |
656 | + return k, nil |
657 | + } |
658 | + } |
659 | + |
660 | + return "", fmt.Errorf("no matching trust anchor found") |
661 | +} |
662 | + |
663 | +// writeTrustAnchors parses the input data, finds the domain from spiffeID and writes ca_certificate.pem |
664 | +// in the destDir for that domain. |
665 | +func writeTrustAnchors(wtrcsMd []byte, destDir, spiffeID string) error { |
666 | + wtrcs := WorkloadTrustedAnchors{} |
667 | + if err := json.Unmarshal(wtrcsMd, &wtrcs); err != nil { |
668 | + return fmt.Errorf("error unmarshaling workload trusted root certs: %v", err) |
669 | + } |
670 | + |
671 | + // Currently there's only one trust anchor but there could be multiple trust anchors in future. |
672 | + // In either case we want the trust anchor with domain matching with the one in SPIFFE ID. |
673 | + domain, err := findDomain(wtrcs.TrustAnchors, spiffeID) |
674 | if err != nil { |
675 | - return fmt.Errorf("error getting project ID: %v", err) |
676 | + return err |
677 | } |
678 | |
679 | + return os.WriteFile(fmt.Sprintf("%s/ca_certificates.pem", destDir), []byte(wtrcs.TrustAnchors[domain].TrustAnchorsPem), 0644) |
680 | +} |
681 | + |
682 | +// writeWorkloadIdentities parses the input data, writes the certificates.pem, private_key.pem files in the |
683 | +// destDir, and returns the SPIFFE ID for which it wrote the certificates. |
684 | +func writeWorkloadIdentities(destDir string, wisMd []byte) (string, error) { |
685 | + var spiffeID string |
686 | + wis := WorkloadIdentities{} |
687 | + if err := json.Unmarshal(wisMd, &wis); err != nil { |
688 | + return "", fmt.Errorf("error unmarshaling workload identities response: %w", err) |
689 | + } |
690 | + |
691 | + // It's guaranteed to have a single entry in workload credentials map. |
692 | + for k := range wis.WorkloadCredentials { |
693 | + spiffeID = k |
694 | + break |
695 | + } |
696 | + |
697 | + if err := os.WriteFile(filepath.Join(destDir, "certificates.pem"), []byte(wis.WorkloadCredentials[spiffeID].CertificatePem), 0644); err != nil { |
698 | + return "", fmt.Errorf("error writing certificates.pem: %w", err) |
699 | + } |
700 | + |
701 | + if err := os.WriteFile(filepath.Join(destDir, "private_key.pem"), []byte(wis.WorkloadCredentials[spiffeID].PrivateKeyPem), 0644); err != nil { |
702 | + return "", fmt.Errorf("error writing private_key.pem: %w", err) |
703 | + } |
704 | + return spiffeID, nil |
705 | +} |
706 | + |
707 | +func refreshCreds(ctx context.Context, opts outputOpts) error { |
708 | + now := timeNow() |
709 | + contentDir := fmt.Sprintf("%s-%s", opts.contentDirPrefix, now) |
710 | + tempSymlink := fmt.Sprintf("%s-%s", opts.tempSymlinkPrefix, now) |
711 | + |
712 | // Get status first so it can be written even when other endpoints are empty. |
713 | - certConfigStatus, err := getMetadata(ctx, "instance/gce-workload-certificates/config-status") |
714 | + certConfigStatus, err := getMetadata(ctx, configStatusKey) |
715 | if err != nil { |
716 | // Return success when certs are not configured to avoid unnecessary systemd failed units. |
717 | logger.Infof("Error getting config status, workload certificates may not be configured: %v", err) |
718 | return nil |
719 | } |
720 | |
721 | - domain := fmt.Sprintf("%s.svc.id.goog", project) |
722 | - logger.Infof("Rotating workload credentials for trust domain %s", domain) |
723 | - |
724 | - now := time.Now().Format(time.RFC3339) |
725 | - contentDir := fmt.Sprintf("%s-%s", contentDirPrefix, now) |
726 | - tempSymlink := fmt.Sprintf("%s-%s", tempSymlinkPrefix, now) |
727 | - |
728 | logger.Infof("Creating timestamp contents dir %s", contentDir) |
729 | |
730 | if err := os.MkdirAll(contentDir, 0755); err != nil { |
731 | @@ -231,72 +261,59 @@ func refreshCreds(ctx context.Context) error { |
732 | } |
733 | |
734 | // Write config_status first even if remaining endpoints are empty. |
735 | - if err := os.WriteFile(fmt.Sprintf("%s/config_status", contentDir), certConfigStatus, 0644); err != nil { |
736 | + if err := os.WriteFile(filepath.Join(contentDir, "config_status"), certConfigStatus, 0644); err != nil { |
737 | return fmt.Errorf("error writing config_status: %v", err) |
738 | } |
739 | |
740 | // Handles the edge case where the config values provided for the first time may be invalid. This ensures |
741 | - // that the symlink directory alwasys exists and contains the config_status to surface config errors to the VM. |
742 | - if _, err := os.Stat(symlink); os.IsNotExist(err) { |
743 | + // that the symlink directory always exists and contains the config_status to surface config errors to the VM. |
744 | + if _, err := os.Stat(opts.symlink); os.IsNotExist(err) { |
745 | logger.Infof("Creating new symlink %s", symlink) |
746 | |
747 | - if err := os.Symlink(contentDir, symlink); err != nil { |
748 | + if err := os.Symlink(contentDir, opts.symlink); err != nil { |
749 | return fmt.Errorf("error creating symlink: %v", err) |
750 | } |
751 | } |
752 | |
753 | // Now get the rest of the content. |
754 | - wisMd, err := getMetadata(ctx, "instance/gce-workload-certificates/workload-identities") |
755 | + wisMd, err := getMetadata(ctx, workloadIdentitiesKey) |
756 | if err != nil { |
757 | return fmt.Errorf("error getting workload-identities: %v", err) |
758 | } |
759 | |
760 | - wtrcsMd, err := getMetadata(ctx, "instance/gce-workload-certificates/root-certs") |
761 | + spiffeID, err := writeWorkloadIdentities(contentDir, wisMd) |
762 | if err != nil { |
763 | - return fmt.Errorf("error getting workload-trusted-root-certs: %v", err) |
764 | + return fmt.Errorf("failed to write workload identities with error: %w", err) |
765 | } |
766 | |
767 | - wis := WorkloadIdentities{} |
768 | - if err := json.Unmarshal(wisMd, &wis); err != nil { |
769 | - return fmt.Errorf("error unmarshaling workload identities response: %v", err) |
770 | - } |
771 | - |
772 | - wtrcs := WorkloadTrustedRootCerts{} |
773 | - if err := json.Unmarshal(wtrcsMd, &wtrcs); err != nil { |
774 | - return fmt.Errorf("error unmarshaling workload trusted root certs: %v", err) |
775 | - } |
776 | - |
777 | - if err := os.WriteFile(fmt.Sprintf("%s/certificates.pem", contentDir), []byte(wis.WorkloadCredentials[domain].CertificatePem), 0644); err != nil { |
778 | - return fmt.Errorf("error writing certificates.pem: %v", err) |
779 | - } |
780 | - |
781 | - if err := os.WriteFile(fmt.Sprintf("%s/private_key.pem", contentDir), []byte(wis.WorkloadCredentials[domain].PrivateKeyPem), 0644); err != nil { |
782 | - return fmt.Errorf("error writing private_key.pem: %v", err) |
783 | + wtrcsMd, err := getMetadata(ctx, trustAnchorsKey) |
784 | + if err != nil { |
785 | + return fmt.Errorf("error getting workload-trust-anchors: %v", err) |
786 | } |
787 | |
788 | - if err := os.WriteFile(fmt.Sprintf("%s/ca_certificates.pem", contentDir), []byte(wtrcs.RootCertificates[domain].RootCertificatesPem), 0644); err != nil { |
789 | - return fmt.Errorf("error writing ca_certificates.pem: %v", err) |
790 | + if err := writeTrustAnchors(wtrcsMd, contentDir, spiffeID); err != nil { |
791 | + return fmt.Errorf("failed to write trust anchors: %w", err) |
792 | } |
793 | |
794 | if err := os.Symlink(contentDir, tempSymlink); err != nil { |
795 | return fmt.Errorf("error creating temporary link: %v", err) |
796 | } |
797 | |
798 | - oldTarget, err := os.Readlink(symlink) |
799 | + oldTarget, err := os.Readlink(opts.symlink) |
800 | if err != nil { |
801 | logger.Infof("Error reading existing symlink: %v\n", err) |
802 | oldTarget = "" |
803 | } |
804 | |
805 | // Only rotate on success of all steps above. |
806 | - logger.Infof("Rotating symlink %s", symlink) |
807 | + logger.Infof("Rotating symlink %s", opts.symlink) |
808 | |
809 | - if err := os.Rename(tempSymlink, symlink); err != nil { |
810 | + if err := os.Rename(tempSymlink, opts.symlink); err != nil { |
811 | return fmt.Errorf("error rotating target link: %v", err) |
812 | } |
813 | |
814 | // Clean up previous contents dir. |
815 | - newTarget, err := os.Readlink(symlink) |
816 | + newTarget, err := os.Readlink(opts.symlink) |
817 | if err != nil { |
818 | return fmt.Errorf("error reading new symlink: %v, unable to remove old symlink target", err) |
819 | } |
820 | diff --git a/gce_workload_cert_refresh/main_test.go b/gce_workload_cert_refresh/main_test.go |
821 | new file mode 100644 |
822 | index 0000000..5edb1c8 |
823 | --- /dev/null |
824 | +++ b/gce_workload_cert_refresh/main_test.go |
825 | @@ -0,0 +1,495 @@ |
826 | +// Copyright 2023 Google LLC |
827 | + |
828 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
829 | +// you may not use this file except in compliance with the License. |
830 | +// You may obtain a copy of the License at |
831 | + |
832 | +// https://www.apache.org/licenses/LICENSE-2.0 |
833 | + |
834 | +// Unless required by applicable law or agreed to in writing, software |
835 | +// distributed under the License is distributed on an "AS IS" BASIS, |
836 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
837 | +// See the License for the specific language governing permissions and |
838 | +// limitations under the License. |
839 | + |
840 | +package main |
841 | + |
842 | +import ( |
843 | + "context" |
844 | + "encoding/json" |
845 | + "fmt" |
846 | + "os" |
847 | + "path/filepath" |
848 | + "testing" |
849 | + |
850 | + "github.com/GoogleCloudPlatform/guest-agent/metadata" |
851 | + "github.com/google/go-cmp/cmp" |
852 | +) |
853 | + |
854 | +const ( |
855 | + workloadRespTpl = ` |
856 | + { |
857 | + "status": "OK", |
858 | + "workloadCredentials": { |
859 | + "%s": { |
860 | + "certificatePem": "%s", |
861 | + "privateKeyPem": "%s" |
862 | + } |
863 | + } |
864 | + } |
865 | + ` |
866 | + trustAnchorRespTpl = ` |
867 | + { |
868 | + "status": "Ok", |
869 | + "trustAnchors": { |
870 | + "%s": { |
871 | + "trustAnchorsPem": "%s" |
872 | + }, |
873 | + "%s": { |
874 | + "trustAnchorsPem": "%s" |
875 | + } |
876 | + } |
877 | + } |
878 | + ` |
879 | + testConfigStatusResp = ` |
880 | + { |
881 | + "status": "Ok", |
882 | + } |
883 | + ` |
884 | +) |
885 | + |
886 | +func TestWorkloadIdentitiesUnmarshal(t *testing.T) { |
887 | + certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----" |
888 | + pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----" |
889 | + spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
890 | + |
891 | + resp := fmt.Sprintf(workloadRespTpl, spiffe, certPem, pvtPem) |
892 | + want := WorkloadIdentities{ |
893 | + Status: "OK", |
894 | + WorkloadCredentials: map[string]WorkloadCredential{ |
895 | + spiffe: { |
896 | + CertificatePem: certPem, |
897 | + PrivateKeyPem: pvtPem, |
898 | + }, |
899 | + }, |
900 | + } |
901 | + |
902 | + got := WorkloadIdentities{} |
903 | + if err := json.Unmarshal([]byte(resp), &got); err != nil { |
904 | + t.Errorf("WorkloadIdentities.UnmarshalJSON(%s) failed unexpectedly with error: %v", resp, err) |
905 | + } |
906 | + |
907 | + if diff := cmp.Diff(want, got); diff != "" { |
908 | + t.Errorf("Workload identities diff (-want +got):\n%s", diff) |
909 | + } |
910 | +} |
911 | + |
912 | +func TestTrustAnchorsUnmarshal(t *testing.T) { |
913 | + domain1 := "12345.global.67890.workload.id.goog" |
914 | + pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----" |
915 | + domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2" |
916 | + pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----" |
917 | + |
918 | + resp := fmt.Sprintf(trustAnchorRespTpl, domain1, pem1, domain2, pem2) |
919 | + want := WorkloadTrustedAnchors{ |
920 | + Status: "Ok", |
921 | + TrustAnchors: map[string]TrustAnchor{ |
922 | + domain1: { |
923 | + TrustAnchorsPem: pem1, |
924 | + }, |
925 | + domain2: { |
926 | + TrustAnchorsPem: pem2, |
927 | + }, |
928 | + }, |
929 | + } |
930 | + |
931 | + got := WorkloadTrustedAnchors{} |
932 | + if err := json.Unmarshal([]byte(resp), &got); err != nil { |
933 | + t.Errorf("WorkloadTrustedRootCerts.UnmarshalJSON(%s) failed unexpectedly with error: %v", resp, err) |
934 | + } |
935 | + |
936 | + if diff := cmp.Diff(want, got); diff != "" { |
937 | + t.Errorf("Workload trusted anchors diff (-want +got):\n%s", diff) |
938 | + } |
939 | +} |
940 | + |
941 | +func TestWriteTrustAnchors(t *testing.T) { |
942 | + spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
943 | + domain1 := "12345.global.67890.workload.id.goog" |
944 | + pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----" |
945 | + domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2" |
946 | + pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----" |
947 | + |
948 | + resp := fmt.Sprintf(trustAnchorRespTpl, domain1, pem1, domain2, pem2) |
949 | + dir := t.TempDir() |
950 | + if err := writeTrustAnchors([]byte(resp), dir, spiffe); err != nil { |
951 | + t.Errorf("writeTrustAnchors(%s,%s,%s) failed unexpectedly with error %v", resp, dir, spiffe, err) |
952 | + } |
953 | + |
954 | + got, err := os.ReadFile(filepath.Join(dir, "ca_certificates.pem")) |
955 | + if err != nil { |
956 | + t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "ca_certificates.pem"), err) |
957 | + } |
958 | + if string(got) != pem1 { |
959 | + t.Errorf("writeTrustAnchors(%s,%s,%s) wrote %q, expected to write %q", resp, dir, spiffe, string(got), pem1) |
960 | + } |
961 | +} |
962 | + |
963 | +func TestWriteWorkloadIdentities(t *testing.T) { |
964 | + certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----" |
965 | + pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----" |
966 | + spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
967 | + |
968 | + resp := fmt.Sprintf(workloadRespTpl, spiffe, certPem, pvtPem) |
969 | + dir := t.TempDir() |
970 | + |
971 | + gotID, err := writeWorkloadIdentities(dir, []byte(resp)) |
972 | + if err != nil { |
973 | + t.Errorf("writeWorkloadIdentities(%s,%s) failed unexpectedly with error %v", dir, resp, err) |
974 | + } |
975 | + if gotID != spiffe { |
976 | + t.Errorf("writeWorkloadIdentities(%s,%s) = %s, want %s", dir, resp, gotID, spiffe) |
977 | + } |
978 | + |
979 | + gotCertPem, err := os.ReadFile(filepath.Join(dir, "certificates.pem")) |
980 | + if err != nil { |
981 | + t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "certificates.pem"), err) |
982 | + } |
983 | + if string(gotCertPem) != certPem { |
984 | + t.Errorf("writeWorkloadIdentities(%s,%s) wrote %q, expected to write %q", dir, resp, string(gotCertPem), certPem) |
985 | + } |
986 | + |
987 | + gotPvtPem, err := os.ReadFile(filepath.Join(dir, "private_key.pem")) |
988 | + if err != nil { |
989 | + t.Errorf("failed to read file at %s with error: %v", filepath.Join(dir, "private_key.pem"), err) |
990 | + } |
991 | + if string(gotPvtPem) != pvtPem { |
992 | + t.Errorf("writeWorkloadIdentities(%s,%s) wrote %q, expected to write %q", dir, resp, string(gotPvtPem), pvtPem) |
993 | + } |
994 | +} |
995 | + |
996 | +func TestFindDomainError(t *testing.T) { |
997 | + anchors := map[string]TrustAnchor{ |
998 | + "67890.global.12345.workload.id.goog": {}, |
999 | + "55555.global.67890.workload.id.goog": {}, |
1000 | + } |
1001 | + spiffeID := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
1002 | + |
1003 | + if _, err := findDomain(anchors, spiffeID); err == nil { |
1004 | + t.Errorf("findDomain(%+v, %s) succeded for unknown anchors, want error", anchors, spiffeID) |
1005 | + } |
1006 | +} |
1007 | + |
1008 | +func TestFindDomain(t *testing.T) { |
1009 | + tests := []struct { |
1010 | + desc string |
1011 | + anchors map[string]TrustAnchor |
1012 | + spiffeID string |
1013 | + want string |
1014 | + }{ |
1015 | + { |
1016 | + desc: "single_trust_anchor", |
1017 | + anchors: map[string]TrustAnchor{"12345.global.67890.workload.id.goog": {}}, |
1018 | + spiffeID: "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID", |
1019 | + want: "12345.global.67890.workload.id.goog", |
1020 | + }, |
1021 | + { |
1022 | + desc: "multiple_trust_anchor", |
1023 | + anchors: map[string]TrustAnchor{ |
1024 | + "67890.global.12345.workload.id.goog": {}, |
1025 | + "12345.global.67890.workload.id.goog": {}, |
1026 | + }, |
1027 | + spiffeID: "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID", |
1028 | + want: "12345.global.67890.workload.id.goog", |
1029 | + }, |
1030 | + } |
1031 | + |
1032 | + for _, test := range tests { |
1033 | + t.Run(test.desc, func(t *testing.T) { |
1034 | + got, err := findDomain(test.anchors, test.spiffeID) |
1035 | + if err != nil { |
1036 | + t.Errorf("findDomain(%+v, %s) failed unexpectedly with error: %v", test.anchors, test.spiffeID, err) |
1037 | + } |
1038 | + if got != test.want { |
1039 | + t.Errorf("findDomain(%+v, %s) = %s, want %s", test.anchors, test.spiffeID, got, test.want) |
1040 | + } |
1041 | + }) |
1042 | + } |
1043 | +} |
1044 | + |
1045 | +func TestIsEnabled(t *testing.T) { |
1046 | + ctx := context.Background() |
1047 | + |
1048 | + tests := []struct { |
1049 | + desc string |
1050 | + enabled string |
1051 | + want bool |
1052 | + err string |
1053 | + }{ |
1054 | + { |
1055 | + desc: "attr_correctly_added", |
1056 | + enabled: "true", |
1057 | + want: true, |
1058 | + }, |
1059 | + { |
1060 | + desc: "attr_incorrectly_added", |
1061 | + enabled: "blaah", |
1062 | + want: false, |
1063 | + }, |
1064 | + { |
1065 | + desc: "attr_not_added", |
1066 | + want: false, |
1067 | + err: enableWorkloadCertsKey, |
1068 | + }, |
1069 | + } |
1070 | + |
1071 | + for _, test := range tests { |
1072 | + t.Run(test.desc, func(t *testing.T) { |
1073 | + mdsClient = &mdsTestClient{enabled: test.enabled, throwErrOn: test.err} |
1074 | + if got := isEnabled(ctx); got != test.want { |
1075 | + t.Errorf("isEnabled(ctx) = %t, want %t", got, test.want) |
1076 | + } |
1077 | + }) |
1078 | + } |
1079 | +} |
1080 | + |
1081 | +// mdsTestClient is a fake client to stub MDS response in unit tests. |
1082 | +type mdsTestClient struct { |
1083 | + // Is credential generation enabled. |
1084 | + enabled string |
1085 | + // Workload template. |
1086 | + spiffe, certPem, pvtPem string |
1087 | + // Trust Anchor template. |
1088 | + domain1, pem1, domain2, pem2 string |
1089 | + // Throw error on MDS request for "key". |
1090 | + throwErrOn string |
1091 | +} |
1092 | + |
1093 | +func (mds *mdsTestClient) Get(ctx context.Context) (*metadata.Descriptor, error) { |
1094 | + return nil, fmt.Errorf("Get() not yet implemented") |
1095 | +} |
1096 | + |
1097 | +func (mds *mdsTestClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) { |
1098 | + if mds.throwErrOn == key { |
1099 | + return "", fmt.Errorf("this is fake error for testing") |
1100 | + } |
1101 | + |
1102 | + switch key { |
1103 | + case enableWorkloadCertsKey: |
1104 | + return mds.enabled, nil |
1105 | + case configStatusKey: |
1106 | + return testConfigStatusResp, nil |
1107 | + case workloadIdentitiesKey: |
1108 | + return fmt.Sprintf(workloadRespTpl, mds.spiffe, mds.certPem, mds.pvtPem), nil |
1109 | + case trustAnchorsKey: |
1110 | + return fmt.Sprintf(trustAnchorRespTpl, mds.domain1, mds.pem1, mds.domain2, mds.pem2), nil |
1111 | + default: |
1112 | + return "", fmt.Errorf("unknown key %q", key) |
1113 | + } |
1114 | +} |
1115 | + |
1116 | +func (mds *mdsTestClient) GetKeyRecursive(ctx context.Context, key string) (string, error) { |
1117 | + return "", fmt.Errorf("GetKeyRecursive() not yet implemented") |
1118 | +} |
1119 | + |
1120 | +func (mds *mdsTestClient) Watch(ctx context.Context) (*metadata.Descriptor, error) { |
1121 | + return nil, fmt.Errorf("Watch() not yet implemented") |
1122 | +} |
1123 | + |
1124 | +func (mds *mdsTestClient) WriteGuestAttributes(ctx context.Context, key string, value string) error { |
1125 | + return fmt.Errorf("WriteGuestattributes() not yet implemented") |
1126 | +} |
1127 | + |
1128 | +func TestRefreshCreds(t *testing.T) { |
1129 | + ctx := context.Background() |
1130 | + tmp := t.TempDir() |
1131 | + |
1132 | + // Templates to use in iterations. |
1133 | + spiffeTpl := "spiffe://12345.global.67890.workload.id.goog.%d/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
1134 | + domain1Tpl := "12345.global.67890.workload.id.goog.%d" |
1135 | + pem1Tpl := "-----BEGIN CERTIFICATE-----datahere1.%d-----END CERTIFICATE-----" |
1136 | + domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2_IGNORE" |
1137 | + pem2Tpl := "-----BEGIN CERTIFICATE-----datahere2.%d-----END CERTIFICATE-----" |
1138 | + certPemTpl := "-----BEGIN CERTIFICATE-----datahere.%d-----END CERTIFICATE-----" |
1139 | + pvtPemTpl := "-----BEGIN PRIVATE KEY-----datahere.%d-----END PRIVATE KEY-----" |
1140 | + |
1141 | + contentPrefix := filepath.Join(tmp, "workload-spiffe-contents") |
1142 | + tmpSymlinkPrefix := filepath.Join(tmp, "workload-spiffe-symlink") |
1143 | + link := filepath.Join(tmp, "workload-spiffe-credentials") |
1144 | + out := outputOpts{contentPrefix, tmpSymlinkPrefix, link} |
1145 | + |
1146 | + // Run refresh creds thrice to test updates. |
1147 | + // Link (workload-spiffe-credentials) should always refer to the updated content |
1148 | + // and previous directories should be removed. |
1149 | + for i := 1; i <= 3; i++ { |
1150 | + timeNow = func() string { return fmt.Sprintf("%d", i) } |
1151 | + spiffe := fmt.Sprintf(spiffeTpl, i) |
1152 | + domain1 := fmt.Sprintf(domain1Tpl, i) |
1153 | + pem1 := fmt.Sprintf(pem1Tpl, i) |
1154 | + pem2 := fmt.Sprintf(pem2Tpl, i) |
1155 | + certPem := fmt.Sprintf(certPemTpl, i) |
1156 | + pvtPem := fmt.Sprintf(pvtPemTpl, i) |
1157 | + |
1158 | + mdsClient = &mdsTestClient{ |
1159 | + spiffe: spiffe, |
1160 | + certPem: certPem, |
1161 | + pvtPem: pvtPem, |
1162 | + domain1: domain1, |
1163 | + pem1: pem1, |
1164 | + domain2: domain2, |
1165 | + pem2: pem2, |
1166 | + } |
1167 | + |
1168 | + if err := refreshCreds(ctx, out); err != nil { |
1169 | + t.Errorf("refreshCreds(ctx, %+v) failed unexpectedly with error: %v", out, err) |
1170 | + } |
1171 | + |
1172 | + // Verify all files are created with the content as expected. |
1173 | + tests := []struct { |
1174 | + path string |
1175 | + content string |
1176 | + }{ |
1177 | + { |
1178 | + path: filepath.Join(link, "ca_certificates.pem"), |
1179 | + content: pem1, |
1180 | + }, |
1181 | + { |
1182 | + path: filepath.Join(link, "certificates.pem"), |
1183 | + content: certPem, |
1184 | + }, |
1185 | + { |
1186 | + path: filepath.Join(link, "private_key.pem"), |
1187 | + content: pvtPem, |
1188 | + }, |
1189 | + { |
1190 | + path: filepath.Join(link, "config_status"), |
1191 | + content: testConfigStatusResp, |
1192 | + }, |
1193 | + } |
1194 | + |
1195 | + for _, test := range tests { |
1196 | + t.Run(test.path, func(t *testing.T) { |
1197 | + got, err := os.ReadFile(test.path) |
1198 | + if err != nil { |
1199 | + t.Errorf("failed to read expected file %q and content %q with error: %v", test.path, test.content, err) |
1200 | + } |
1201 | + if string(got) != test.content { |
1202 | + t.Errorf("refreshCreds(ctx, %+v) wrote %q, want content %q", out, string(got), test.content) |
1203 | + } |
1204 | + }) |
1205 | + } |
1206 | + |
1207 | + // Verify the symlink was created and references the right destination directory. |
1208 | + want := fmt.Sprintf("%s-%d", contentPrefix, i) |
1209 | + got, err := os.Readlink(link) |
1210 | + if err != nil { |
1211 | + t.Errorf("os.Readlink(%s) failed unexpectedly with error %v", link, err) |
1212 | + } |
1213 | + if got != want { |
1214 | + t.Errorf("os.Readlink(%s) = %s, want %s", link, got, want) |
1215 | + } |
1216 | + |
1217 | + // If it's not the first run, make sure prev creds are deleted. |
1218 | + if i > 1 { |
1219 | + prevDir := fmt.Sprintf("%s-%d", contentPrefix, i-1) |
1220 | + if _, err := os.Stat(prevDir); err == nil { |
1221 | + t.Errorf("os.Stat(%s) succeeded on prev content directory, want error", prevDir) |
1222 | + } |
1223 | + } |
1224 | + } |
1225 | +} |
1226 | + |
1227 | +func TestRefreshCredsError(t *testing.T) { |
1228 | + ctx := context.Background() |
1229 | + tmp := t.TempDir() |
1230 | + |
1231 | + // Templates to use in iterations. |
1232 | + spiffe := "spiffe://12345.global.67890.workload.id.goog/ns/NAMESPACE_ID/sa/MANAGED_IDENTITY_ID" |
1233 | + domain1 := "12345.global.67890.workload.id.goog" |
1234 | + pem1 := "-----BEGIN CERTIFICATE-----datahere1-----END CERTIFICATE-----" |
1235 | + domain2 := "PEER_SPIFFE_TRUST_DOMAIN_2_IGNORE" |
1236 | + pem2 := "-----BEGIN CERTIFICATE-----datahere2-----END CERTIFICATE-----" |
1237 | + certPem := "-----BEGIN CERTIFICATE-----datahere-----END CERTIFICATE-----" |
1238 | + pvtPem := "-----BEGIN PRIVATE KEY-----datahere-----END PRIVATE KEY-----" |
1239 | + |
1240 | + contentPrefix := filepath.Join(tmp, "workload-spiffe-contents") |
1241 | + tmpSymlinkPrefix := filepath.Join(tmp, "workload-spiffe-symlink") |
1242 | + link := filepath.Join(tmp, "workload-spiffe-credentials") |
1243 | + out := outputOpts{contentPrefix, tmpSymlinkPrefix, link} |
1244 | + |
1245 | + client := &mdsTestClient{ |
1246 | + spiffe: spiffe, |
1247 | + certPem: certPem, |
1248 | + pvtPem: pvtPem, |
1249 | + domain1: domain1, |
1250 | + pem1: pem1, |
1251 | + domain2: domain2, |
1252 | + pem2: pem2, |
1253 | + } |
1254 | + |
1255 | + mdsClient = client |
1256 | + |
1257 | + // Run refresh creds twice. First run would succeed and second would fail. Verify all |
1258 | + // creds generated on the first run are present as is after failed second run. |
1259 | + for i := 1; i <= 2; i++ { |
1260 | + timeNow = func() string { return fmt.Sprintf("%d", i) } |
1261 | + |
1262 | + if i == 1 { |
1263 | + // First run should succeed. |
1264 | + if err := refreshCreds(ctx, out); err != nil { |
1265 | + t.Errorf("refreshCreds(ctx, %+v) failed unexpectedly with error: %v", out, err) |
1266 | + } |
1267 | + } else if i == 2 { |
1268 | + // Second run should fail. Fail in getting last metadata entry. |
1269 | + client.throwErrOn = trustAnchorsKey |
1270 | + if err := refreshCreds(ctx, out); err == nil { |
1271 | + t.Errorf("refreshCreds(ctx, %+v) succeeded for fake metadata error, should've failed", out) |
1272 | + } |
1273 | + } |
1274 | + |
1275 | + // Verify all files are created and are still present with the content as expected. |
1276 | + tests := []struct { |
1277 | + path string |
1278 | + content string |
1279 | + }{ |
1280 | + { |
1281 | + path: filepath.Join(link, "ca_certificates.pem"), |
1282 | + content: pem1, |
1283 | + }, |
1284 | + { |
1285 | + path: filepath.Join(link, "certificates.pem"), |
1286 | + content: certPem, |
1287 | + }, |
1288 | + { |
1289 | + path: filepath.Join(link, "private_key.pem"), |
1290 | + content: pvtPem, |
1291 | + }, |
1292 | + { |
1293 | + path: filepath.Join(link, "config_status"), |
1294 | + content: testConfigStatusResp, |
1295 | + }, |
1296 | + } |
1297 | + |
1298 | + for _, test := range tests { |
1299 | + t.Run(test.path, func(t *testing.T) { |
1300 | + got, err := os.ReadFile(test.path) |
1301 | + if err != nil { |
1302 | + t.Errorf("failed to read expected file %q and content %q with error: %v", test.path, test.content, err) |
1303 | + } |
1304 | + if string(got) != test.content { |
1305 | + t.Errorf("refreshCreds(ctx, %+v) wrote %q, want content %q", out, string(got), test.content) |
1306 | + } |
1307 | + }) |
1308 | + } |
1309 | + |
1310 | + // Verify the symlink was created and references the same destination directory. |
1311 | + want := fmt.Sprintf("%s-%d", contentPrefix, 1) |
1312 | + got, err := os.Readlink(link) |
1313 | + if err != nil { |
1314 | + t.Errorf("os.Readlink(%s) failed unexpectedly with error %v", link, err) |
1315 | + } |
1316 | + if got != want { |
1317 | + t.Errorf("os.Readlink(%s) = %s, want %s", link, got, want) |
1318 | + } |
1319 | + } |
1320 | +} |
1321 | diff --git a/go.mod b/go.mod |
1322 | index d5f8c28..1e4ea08 100644 |
1323 | --- a/go.mod |
1324 | +++ b/go.mod |
1325 | @@ -10,12 +10,14 @@ require ( |
1326 | github.com/go-ini/ini v1.66.6 |
1327 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da |
1328 | github.com/golang/protobuf v1.5.3 |
1329 | + github.com/google/go-cmp v0.5.9 |
1330 | github.com/google/go-tpm v0.9.0 |
1331 | github.com/google/go-tpm-tools v0.4.0 |
1332 | github.com/google/tink/go v1.7.0 |
1333 | github.com/kardianos/service v1.2.1 |
1334 | github.com/robfig/cron/v3 v3.0.1 |
1335 | github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 |
1336 | + golang.org/x/crypto v0.11.0 |
1337 | golang.org/x/sys v0.11.0 |
1338 | google.golang.org/grpc v1.57.0 |
1339 | google.golang.org/protobuf v1.31.0 |
1340 | @@ -29,7 +31,7 @@ require ( |
1341 | cloud.google.com/go/iam v1.1.1 // indirect |
1342 | cloud.google.com/go/logging v1.7.0 // indirect |
1343 | cloud.google.com/go/longrunning v0.5.1 // indirect |
1344 | - github.com/google/go-cmp v0.5.9 // indirect |
1345 | + github.com/Microsoft/go-winio v0.6.1 // indirect |
1346 | github.com/google/go-sev-guest v0.7.0 // indirect |
1347 | github.com/google/logger v1.1.1 // indirect |
1348 | github.com/google/s2a-go v0.1.4 // indirect |
1349 | @@ -39,11 +41,12 @@ require ( |
1350 | github.com/pborman/uuid v1.2.1 // indirect |
1351 | github.com/pkg/errors v0.9.1 // indirect |
1352 | go.opencensus.io v0.24.0 // indirect |
1353 | - golang.org/x/crypto v0.11.0 // indirect |
1354 | + golang.org/x/mod v0.8.0 // indirect |
1355 | golang.org/x/net v0.12.0 // indirect |
1356 | golang.org/x/oauth2 v0.10.0 // indirect |
1357 | golang.org/x/sync v0.3.0 // indirect |
1358 | golang.org/x/text v0.11.0 // indirect |
1359 | + golang.org/x/tools v0.6.0 // indirect |
1360 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect |
1361 | google.golang.org/api v0.134.0 // indirect |
1362 | google.golang.org/appengine v1.6.7 // indirect |
1363 | diff --git a/go.sum b/go.sum |
1364 | index 55c71c0..9c5c73b 100644 |
1365 | --- a/go.sum |
1366 | +++ b/go.sum |
1367 | @@ -17,6 +17,8 @@ cloud.google.com/go/storage v1.31.0/go.mod h1:81ams1PrhW16L4kF7qg+4mTq7SRs5HsbDT |
1368 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= |
1369 | github.com/GoogleCloudPlatform/guest-logging-go v0.0.0-20230710215706-450679fd88a9 h1:b3geIwOPAShYtR4F0XFt+2NJXTHVTfbxUFmrpiZXHdQ= |
1370 | github.com/GoogleCloudPlatform/guest-logging-go v0.0.0-20230710215706-450679fd88a9/go.mod h1:6ZqSUIZRAPR5dNMWJ+FwIarFFQ9t5qalaKQs20o6h+I= |
1371 | +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= |
1372 | +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= |
1373 | github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= |
1374 | github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= |
1375 | github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= |
1376 | @@ -94,8 +96,10 @@ github.com/googleapis/enterprise-certificate-proxy v0.2.5/go.mod h1:RxW0N9901Cko |
1377 | github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= |
1378 | github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= |
1379 | github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= |
1380 | +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= |
1381 | github.com/kardianos/service v1.2.1 h1:AYndMsehS+ywIS6RB9KOlcXzteWUzxgMgBymJD7+BYk= |
1382 | github.com/kardianos/service v1.2.1/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= |
1383 | +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= |
1384 | github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= |
1385 | github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= |
1386 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= |
1387 | @@ -134,6 +138,8 @@ golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTk |
1388 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= |
1389 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= |
1390 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= |
1391 | +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= |
1392 | +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= |
1393 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
1394 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
1395 | golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
1396 | @@ -176,6 +182,7 @@ golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= |
1397 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |
1398 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= |
1399 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= |
1400 | +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= |
1401 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= |
1402 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= |
1403 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= |
1404 | @@ -191,6 +198,8 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 |
1405 | golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= |
1406 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= |
1407 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= |
1408 | +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= |
1409 | +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= |
1410 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= |
1411 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= |
1412 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= |
1413 | diff --git a/google_authorized_keys/main.go b/google_authorized_keys/main.go |
1414 | index 4961397..634ab5a 100644 |
1415 | --- a/google_authorized_keys/main.go |
1416 | +++ b/google_authorized_keys/main.go |
1417 | @@ -20,24 +20,27 @@ import ( |
1418 | "encoding/json" |
1419 | "fmt" |
1420 | "io" |
1421 | - "net/http" |
1422 | "os" |
1423 | + "path" |
1424 | "runtime" |
1425 | "strconv" |
1426 | "strings" |
1427 | "time" |
1428 | |
1429 | + "github.com/GoogleCloudPlatform/guest-agent/metadata" |
1430 | "github.com/GoogleCloudPlatform/guest-agent/utils" |
1431 | "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
1432 | ) |
1433 | |
1434 | var ( |
1435 | - programName = "GoogleAuthorizedKeysCommand" |
1436 | - metadataURL = "http://169.254.169.254/computeMetadata/v1/" |
1437 | - metadataRecursive = "/?recursive=true&alt=json" |
1438 | - defaultTimeout = 2 * time.Second |
1439 | + client metadata.MDSClientInterface |
1440 | + programName = path.Base(os.Args[0]) |
1441 | ) |
1442 | |
1443 | +func init() { |
1444 | + client = metadata.New() |
1445 | +} |
1446 | + |
1447 | func logFormat(e logger.LogEntry) string { |
1448 | now := time.Now().Format("2006/01/02 15:04:05") |
1449 | return fmt.Sprintf("%s %s: %s", now, programName, e.Message) |
1450 | @@ -49,45 +52,6 @@ func logFormatWindows(e logger.LogEntry) string { |
1451 | return fmt.Sprintf("%s %s: %s", now, programName, e.Message) |
1452 | } |
1453 | |
1454 | -func getMetadata(key string, recurse bool) ([]byte, error) { |
1455 | - client := &http.Client{ |
1456 | - Timeout: defaultTimeout, |
1457 | - } |
1458 | - |
1459 | - url := metadataURL + key |
1460 | - if recurse { |
1461 | - url += metadataRecursive |
1462 | - } |
1463 | - req, err := http.NewRequest("GET", url, nil) |
1464 | - if err != nil { |
1465 | - return nil, err |
1466 | - } |
1467 | - req.Header.Add("Metadata-Flavor", "Google") |
1468 | - |
1469 | - var res *http.Response |
1470 | - // Retry forever, increase sleep between retries (up to 5 times) in order |
1471 | - // to wait for slow network initialization. |
1472 | - var rt time.Duration |
1473 | - for i := 1; ; i++ { |
1474 | - res, err = client.Do(req) |
1475 | - if err == nil { |
1476 | - break |
1477 | - } |
1478 | - if i < 6 { |
1479 | - rt = time.Duration(3*i) * time.Second |
1480 | - } |
1481 | - logger.Errorf("error connecting to metadata server, retrying in %s, error: %v", rt, err) |
1482 | - time.Sleep(rt) |
1483 | - } |
1484 | - defer res.Body.Close() |
1485 | - |
1486 | - md, err := io.ReadAll(res.Body) |
1487 | - if err != nil { |
1488 | - return nil, err |
1489 | - } |
1490 | - return md, nil |
1491 | -} |
1492 | - |
1493 | func parseSSHKeys(username string, keys []string) []string { |
1494 | var keyList []string |
1495 | for _, key := range keys { |
1496 | @@ -143,7 +107,7 @@ type attributes struct { |
1497 | SSHKeys []string |
1498 | } |
1499 | |
1500 | -func getMetadataAttributes(metadataKey string) (*attributes, error) { |
1501 | +func getMetadataAttributes(ctx context.Context, metadataKey string) (*attributes, error) { |
1502 | var a attributes |
1503 | type jsonAttributes struct { |
1504 | EnableWindowsSSH string `json:"enable-windows-ssh"` |
1505 | @@ -151,11 +115,12 @@ func getMetadataAttributes(metadataKey string) (*attributes, error) { |
1506 | SSHKeys string `json:"ssh-keys"` |
1507 | } |
1508 | var ja jsonAttributes |
1509 | - metadata, err := getMetadata(metadataKey, true) |
1510 | + metadata, err := client.GetKeyRecursive(ctx, metadataKey) |
1511 | if err != nil { |
1512 | return nil, err |
1513 | } |
1514 | - if err := json.Unmarshal(metadata, &ja); err != nil { |
1515 | + |
1516 | + if err := json.Unmarshal([]byte(metadata), &ja); err != nil { |
1517 | return nil, err |
1518 | } |
1519 | |
1520 | @@ -191,12 +156,12 @@ func main() { |
1521 | } |
1522 | logger.Init(ctx, opts) |
1523 | |
1524 | - instanceAttributes, err := getMetadataAttributes("instance/attributes/") |
1525 | + instanceAttributes, err := getMetadataAttributes(ctx, "instance/attributes/") |
1526 | if err != nil { |
1527 | logger.Errorf("Cannot read instance metadata attributes: %v", err) |
1528 | os.Exit(1) |
1529 | } |
1530 | - projectAttributes, err := getMetadataAttributes("project/attributes/") |
1531 | + projectAttributes, err := getMetadataAttributes(ctx, "project/attributes/") |
1532 | if err != nil { |
1533 | logger.Errorf("Cannot read project metadata attributes: %v", err) |
1534 | os.Exit(1) |
1535 | diff --git a/google_authorized_keys/main_test.go b/google_authorized_keys/main_test.go |
1536 | index 49315f1..ea739a1 100644 |
1537 | --- a/google_authorized_keys/main_test.go |
1538 | +++ b/google_authorized_keys/main_test.go |
1539 | @@ -15,14 +15,15 @@ |
1540 | package main |
1541 | |
1542 | import ( |
1543 | + "context" |
1544 | "fmt" |
1545 | - "net/http" |
1546 | - "net/http/httptest" |
1547 | "reflect" |
1548 | "strconv" |
1549 | "strings" |
1550 | "testing" |
1551 | - "time" |
1552 | + |
1553 | + "github.com/GoogleCloudPlatform/guest-agent/metadata" |
1554 | + "github.com/GoogleCloudPlatform/guest-agent/utils" |
1555 | ) |
1556 | |
1557 | func stringSliceEqual(a, b []string) bool { |
1558 | @@ -50,16 +51,20 @@ var truebool *bool = &t |
1559 | var falsebool *bool = &f |
1560 | |
1561 | func TestParseSSHKeys(t *testing.T) { |
1562 | + pubKeyA := utils.MakeRandRSAPubKey(t) |
1563 | + pubKeyB := utils.MakeRandRSAPubKey(t) |
1564 | + pubKey := utils.MakeRandRSAPubKey(t) |
1565 | + |
1566 | keys := []string{ |
1567 | "# Here is some random data in the file.", |
1568 | - "usera:ssh-rsa AAAA1234USERA", |
1569 | - "userb:ssh-rsa AAAA1234USERB", |
1570 | - `usera:ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, |
1571 | - `usera:ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2020-04-23T12:34:56+0000"}`, |
1572 | + fmt.Sprintf("usera:ssh-rsa %s", pubKeyA), |
1573 | + fmt.Sprintf("userb:ssh-rsa %s", pubKeyB), |
1574 | + fmt.Sprintf(`usera:ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, pubKey), |
1575 | + fmt.Sprintf(`usera:ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2020-04-23T12:34:56+0000"}`, pubKey), |
1576 | } |
1577 | expected := []string{ |
1578 | - "ssh-rsa AAAA1234USERA", |
1579 | - `ssh-rsa AAAA1234 google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, |
1580 | + fmt.Sprintf("ssh-rsa %s", pubKeyA), |
1581 | + fmt.Sprintf(`ssh-rsa %s google-ssh {"userName":"usera@example.com","expireOn":"2095-04-23T12:34:56+0000"}`, pubKey), |
1582 | } |
1583 | |
1584 | user := "usera" |
1585 | @@ -117,6 +122,8 @@ func TestCheckWinSSHEnabled(t *testing.T) { |
1586 | } |
1587 | |
1588 | func TestGetUserKeysNew(t *testing.T) { |
1589 | + pubKey := utils.MakeRandRSAPubKey(t) |
1590 | + |
1591 | tests := []struct { |
1592 | userName string |
1593 | instanceMetadata attributes |
1594 | @@ -125,102 +132,114 @@ func TestGetUserKeysNew(t *testing.T) { |
1595 | }{ |
1596 | { |
1597 | userName: "name", |
1598 | - instanceMetadata: attributes{BlockProjectSSHKeys: false, |
1599 | - SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, |
1600 | + instanceMetadata: attributes{ |
1601 | + BlockProjectSSHKeys: false, |
1602 | + SSHKeys: []string{ |
1603 | + fmt.Sprintf("name:ssh-rsa %s instance1", pubKey), |
1604 | + fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey), |
1605 | + }, |
1606 | }, |
1607 | projectMetadata: attributes{ |
1608 | - SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"}, |
1609 | + SSHKeys: []string{ |
1610 | + fmt.Sprintf("name:ssh-rsa %s project1", pubKey), |
1611 | + fmt.Sprintf("othername:ssh-rsa %s project2", pubKey), |
1612 | + }, |
1613 | + }, |
1614 | + expectedKeys: []string{ |
1615 | + fmt.Sprintf("ssh-rsa %s instance1", pubKey), |
1616 | + fmt.Sprintf("ssh-rsa %s project1", pubKey), |
1617 | }, |
1618 | - expectedKeys: []string{"ssh-rsa [KEY] instance1", "ssh-rsa [KEY] project1"}, |
1619 | }, |
1620 | { |
1621 | userName: "name", |
1622 | - instanceMetadata: attributes{BlockProjectSSHKeys: true, |
1623 | - SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, |
1624 | + instanceMetadata: attributes{ |
1625 | + BlockProjectSSHKeys: true, |
1626 | + SSHKeys: []string{ |
1627 | + fmt.Sprintf("name:ssh-rsa %s instance1", pubKey), |
1628 | + fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey), |
1629 | + }, |
1630 | }, |
1631 | projectMetadata: attributes{ |
1632 | - SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"}, |
1633 | + SSHKeys: []string{ |
1634 | + fmt.Sprintf("name:ssh-rsa %s project1", pubKey), |
1635 | + fmt.Sprintf("othername:ssh-rsa %s project2", pubKey), |
1636 | + }, |
1637 | }, |
1638 | - expectedKeys: []string{"ssh-rsa [KEY] instance1"}, |
1639 | + expectedKeys: []string{fmt.Sprintf("ssh-rsa %s instance1", pubKey)}, |
1640 | }, |
1641 | { |
1642 | userName: "name", |
1643 | - instanceMetadata: attributes{BlockProjectSSHKeys: false, |
1644 | - SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, |
1645 | + instanceMetadata: attributes{ |
1646 | + BlockProjectSSHKeys: false, |
1647 | + SSHKeys: []string{ |
1648 | + fmt.Sprintf("name:ssh-rsa %s instance1", pubKey), |
1649 | + fmt.Sprintf("othername:ssh-rsa %s instance2", pubKey), |
1650 | + }, |
1651 | }, |
1652 | projectMetadata: attributes{ |
1653 | SSHKeys: nil, |
1654 | }, |
1655 | - expectedKeys: []string{"ssh-rsa [KEY] instance1"}, |
1656 | + expectedKeys: []string{fmt.Sprintf("ssh-rsa %s instance1", pubKey)}, |
1657 | }, |
1658 | { |
1659 | userName: "name", |
1660 | - instanceMetadata: attributes{BlockProjectSSHKeys: false, |
1661 | - SSHKeys: nil, |
1662 | + instanceMetadata: attributes{ |
1663 | + BlockProjectSSHKeys: false, |
1664 | + SSHKeys: nil, |
1665 | }, |
1666 | projectMetadata: attributes{ |
1667 | - SSHKeys: []string{"name:ssh-rsa [KEY] project1", "othername:ssh-rsa [KEY] project2"}, |
1668 | + SSHKeys: []string{ |
1669 | + fmt.Sprintf("name:ssh-rsa %s project1", pubKey), |
1670 | + fmt.Sprintf("othername:ssh-rsa %s project2", pubKey), |
1671 | + }, |
1672 | }, |
1673 | - expectedKeys: []string{"ssh-rsa [KEY] project1"}, |
1674 | + expectedKeys: []string{fmt.Sprintf("ssh-rsa %s project1", pubKey)}, |
1675 | }, |
1676 | } |
1677 | |
1678 | for count, tt := range tests { |
1679 | - if got, want := getUserKeys(tt.userName, &tt.instanceMetadata, &tt.projectMetadata), tt.expectedKeys; !stringSliceEqual(got, want) { |
1680 | - t.Errorf("getUserKeys[%d] incorrect return: got %v, want %v", count, got, want) |
1681 | - } |
1682 | + t.Run(fmt.Sprintf("test-%d", count), func(t *testing.T) { |
1683 | + if got, want := getUserKeys(tt.userName, &tt.instanceMetadata, &tt.projectMetadata), tt.expectedKeys; !stringSliceEqual(got, want) { |
1684 | + t.Errorf("getUserKeys[%d] incorrect return: got %v, want %v", count, got, want) |
1685 | + } |
1686 | + }) |
1687 | } |
1688 | } |
1689 | |
1690 | func TestGetMetadataAttributes(t *testing.T) { |
1691 | tests := []struct { |
1692 | - metadata string |
1693 | att *attributes |
1694 | expectErr bool |
1695 | }{ |
1696 | { |
1697 | - metadata: `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, |
1698 | att: &attributes{EnableWindowsSSH: truebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false}, |
1699 | expectErr: false, |
1700 | }, |
1701 | { |
1702 | - metadata: `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"true","other-metadata":"foo"}`, |
1703 | att: &attributes{EnableWindowsSSH: truebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: true}, |
1704 | expectErr: false, |
1705 | }, |
1706 | { |
1707 | - metadata: `{"ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, |
1708 | att: &attributes{EnableWindowsSSH: nil, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false}, |
1709 | expectErr: false, |
1710 | }, |
1711 | { |
1712 | - metadata: `{"enable-windows-ssh":"false","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","other-metadata":"foo"}`, |
1713 | att: &attributes{EnableWindowsSSH: falsebool, SSHKeys: []string{"name:ssh-rsa [KEY] instance1", "othername:ssh-rsa [KEY] instance2"}, BlockProjectSSHKeys: false}, |
1714 | expectErr: false, |
1715 | }, |
1716 | { |
1717 | - metadata: `BADJSON`, |
1718 | att: nil, |
1719 | expectErr: true, |
1720 | }, |
1721 | } |
1722 | |
1723 | - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
1724 | - // Get test number from request path |
1725 | - tnum, _ := strconv.Atoi(strings.Split(r.URL.Path, "/")[2]) |
1726 | - fmt.Fprintf(w, tests[tnum].metadata) |
1727 | - })) |
1728 | - |
1729 | - defer ts.Close() |
1730 | - |
1731 | - metadataURL = ts.URL |
1732 | - defaultTimeout = 1 * time.Second |
1733 | + client = &mdsClient{} |
1734 | |
1735 | for count, tt := range tests { |
1736 | want := tt.att |
1737 | hasErr := false |
1738 | reqStr := fmt.Sprintf("/attributes/%d", count) |
1739 | - got, err := getMetadataAttributes(reqStr) |
1740 | + got, err := getMetadataAttributes(context.Background(), reqStr) |
1741 | if err != nil { |
1742 | hasErr = true |
1743 | } |
1744 | @@ -230,3 +249,43 @@ func TestGetMetadataAttributes(t *testing.T) { |
1745 | } |
1746 | } |
1747 | } |
1748 | + |
1749 | +type mdsClient struct{} |
1750 | + |
1751 | +func (mds *mdsClient) Get(ctx context.Context) (*metadata.Descriptor, error) { |
1752 | + return nil, fmt.Errorf("Get() not yet implemented") |
1753 | +} |
1754 | + |
1755 | +func (mds *mdsClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) { |
1756 | + return "", fmt.Errorf("GetKey() not yet implemented") |
1757 | +} |
1758 | + |
1759 | +func (mds *mdsClient) GetKeyRecursive(ctx context.Context, key string) (string, error) { |
1760 | + i, err := strconv.Atoi(key[strings.LastIndex(key, "/")+1:]) |
1761 | + if err != nil { |
1762 | + return "", err |
1763 | + } |
1764 | + |
1765 | + switch i { |
1766 | + case 0: |
1767 | + return `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, nil |
1768 | + case 1: |
1769 | + return `{"enable-windows-ssh":"true","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"true","other-metadata":"foo"}`, nil |
1770 | + case 2: |
1771 | + return `{"ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","block-project-ssh-keys":"false","other-metadata":"foo"}`, nil |
1772 | + case 3: |
1773 | + return `{"enable-windows-ssh":"false","ssh-keys":"name:ssh-rsa [KEY] instance1\nothername:ssh-rsa [KEY] instance2","other-metadata":"foo"}`, nil |
1774 | + case 4: |
1775 | + return "BADJSON", nil |
1776 | + default: |
1777 | + return "", fmt.Errorf("unknown key %q", key) |
1778 | + } |
1779 | +} |
1780 | + |
1781 | +func (mds *mdsClient) Watch(ctx context.Context) (*metadata.Descriptor, error) { |
1782 | + return nil, fmt.Errorf("Watch() not yet implemented") |
1783 | +} |
1784 | + |
1785 | +func (mds *mdsClient) WriteGuestAttributes(ctx context.Context, key string, value string) error { |
1786 | + return fmt.Errorf("WriteGuestattributes() not yet implemented") |
1787 | +} |
1788 | diff --git a/google_guest_agent/addresses.go b/google_guest_agent/addresses.go |
1789 | index 7c0c741..6f13797 100644 |
1790 | --- a/google_guest_agent/addresses.go |
1791 | +++ b/google_guest_agent/addresses.go |
1792 | @@ -19,25 +19,21 @@ import ( |
1793 | "errors" |
1794 | "fmt" |
1795 | "net" |
1796 | - "os" |
1797 | "reflect" |
1798 | "runtime" |
1799 | + "slices" |
1800 | "strings" |
1801 | - "time" |
1802 | |
1803 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
1804 | + network "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/network/manager" |
1805 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run" |
1806 | - "github.com/GoogleCloudPlatform/guest-agent/metadata" |
1807 | - "github.com/GoogleCloudPlatform/guest-agent/utils" |
1808 | "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
1809 | ) |
1810 | |
1811 | var ( |
1812 | - addressKey = regKeyBase + `\ForwardedIps` |
1813 | - oldWSFCAddresses string |
1814 | - oldWSFCEnable bool |
1815 | - interfacesEnabled bool |
1816 | - interfaces []net.Interface |
1817 | + addressKey = regKeyBase + `\ForwardedIps` |
1818 | + oldWSFCAddresses string |
1819 | + oldWSFCEnable bool |
1820 | ) |
1821 | |
1822 | type addressMgr struct{} |
1823 | @@ -78,7 +74,9 @@ func getForwardsFromRegistry(mac string) ([]string, error) { |
1824 | oldName := strings.Replace(mac, ":", "", -1) |
1825 | regFwdIPs, err = readRegMultiString(addressKey, oldName) |
1826 | if err == nil { |
1827 | - deleteRegKey(addressKey, oldName) |
1828 | + if err = deleteRegKey(addressKey, oldName); err != nil { |
1829 | + logger.Warningf("Failed to delete key: %q, name: %q from registry", addressKey, oldName) |
1830 | + } |
1831 | } |
1832 | } else if err != nil { |
1833 | return nil, err |
1834 | @@ -88,13 +86,13 @@ func getForwardsFromRegistry(mac string) ([]string, error) { |
1835 | |
1836 | func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []string) { |
1837 | for _, desiredRoute := range desiredRoutes { |
1838 | - if !utils.ContainsString(desiredRoute, configuredRoutes) { |
1839 | + if !slices.Contains(configuredRoutes, desiredRoute) { |
1840 | toAdd = append(toAdd, desiredRoute) |
1841 | } |
1842 | } |
1843 | |
1844 | for _, configuredRoute := range configuredRoutes { |
1845 | - if !utils.ContainsString(configuredRoute, desiredRoutes) { |
1846 | + if !slices.Contains(desiredRoutes, configuredRoute) { |
1847 | toRm = append(toRm, configuredRoute) |
1848 | } |
1849 | } |
1850 | @@ -103,20 +101,6 @@ func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []stri |
1851 | |
1852 | var badMAC []string |
1853 | |
1854 | -func getInterfaceByMAC(mac string) (net.Interface, error) { |
1855 | - hwaddr, err := net.ParseMAC(mac) |
1856 | - if err != nil { |
1857 | - return net.Interface{}, err |
1858 | - } |
1859 | - |
1860 | - for _, iface := range interfaces { |
1861 | - if iface.HardwareAddr.String() == hwaddr.String() { |
1862 | - return iface, nil |
1863 | - } |
1864 | - } |
1865 | - return net.Interface{}, fmt.Errorf("no interface found with MAC %s", mac) |
1866 | -} |
1867 | - |
1868 | // https://www.ietf.org/rfc/rfc1354.txt |
1869 | // Only fields that we currently care about. |
1870 | type ipForwardEntry struct { |
1871 | @@ -221,7 +205,7 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) { |
1872 | for idx := range interfaces { |
1873 | var filteredForwardedIps []string |
1874 | for _, ip := range interfaces[idx].ForwardedIps { |
1875 | - if !utils.ContainsString(ip, wsfcAddrs) { |
1876 | + if !slices.Contains(wsfcAddrs, ip) { |
1877 | filteredForwardedIps = append(filteredForwardedIps, ip) |
1878 | } |
1879 | } |
1880 | @@ -229,7 +213,7 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) { |
1881 | |
1882 | var filteredTargetInstanceIps []string |
1883 | for _, ip := range interfaces[idx].TargetInstanceIps { |
1884 | - if !utils.ContainsString(ip, wsfcAddrs) { |
1885 | + if !slices.Contains(wsfcAddrs, ip) { |
1886 | filteredTargetInstanceIps = append(filteredTargetInstanceIps, ip) |
1887 | } |
1888 | } |
1889 | @@ -290,28 +274,10 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1890 | a.applyWSFCFilter(config) |
1891 | } |
1892 | |
1893 | - var err error |
1894 | - interfaces, err = net.Interfaces() |
1895 | + // Setup network interfaces. |
1896 | + err := network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces) |
1897 | if err != nil { |
1898 | - return fmt.Errorf("error populating interfaces: %v", err) |
1899 | - } |
1900 | - |
1901 | - if config.NetworkInterfaces.Setup { |
1902 | - if runtime.GOOS != "windows" { |
1903 | - logger.Debugf("Configure IPv6") |
1904 | - if err := configureIPv6(ctx); err != nil { |
1905 | - // Continue through IPv6 configuration errors. |
1906 | - logger.Errorf("Error configuring IPv6: %v", err) |
1907 | - } |
1908 | - } |
1909 | - |
1910 | - if runtime.GOOS != "windows" && !interfacesEnabled { |
1911 | - logger.Debugf("Enable network interfaces") |
1912 | - if err := enableNetworkInterfaces(ctx, config); err != nil { |
1913 | - return err |
1914 | - } |
1915 | - interfacesEnabled = true |
1916 | - } |
1917 | + return fmt.Errorf("failed to setup network interfaces: %v", err) |
1918 | } |
1919 | |
1920 | if !config.NetworkInterfaces.IPForwarding { |
1921 | @@ -321,9 +287,9 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1922 | logger.Debugf("Add routes for aliases, forwarded IP and target-instance IPs") |
1923 | // Add routes for IP aliases, forwarded and target-instance IPs. |
1924 | for _, ni := range newMetadata.Instance.NetworkInterfaces { |
1925 | - iface, err := getInterfaceByMAC(ni.Mac) |
1926 | + iface, err := network.GetInterfaceByMAC(ni.Mac) |
1927 | if err != nil { |
1928 | - if !utils.ContainsString(ni.Mac, badMAC) { |
1929 | + if !slices.Contains(badMAC, ni.Mac) { |
1930 | logger.Errorf("Error getting interface: %s", err) |
1931 | badMAC = append(badMAC, ni.Mac) |
1932 | } |
1933 | @@ -356,7 +322,7 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1934 | } |
1935 | for _, ip := range configuredIPs { |
1936 | // Only add to `forwardedIPs` if it is recorded in the registry. |
1937 | - if utils.ContainsString(ip, regFwdIPs) { |
1938 | + if slices.Contains(regFwdIPs, ip) { |
1939 | forwardedIPs = append(forwardedIPs, ip) |
1940 | } |
1941 | } |
1942 | @@ -399,14 +365,14 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1943 | var registryEntries []string |
1944 | for _, ip := range wantIPs { |
1945 | // If the IP is not in toAdd, add to registry list and continue. |
1946 | - if !utils.ContainsString(ip, toAdd) { |
1947 | + if !slices.Contains(toAdd, ip) { |
1948 | registryEntries = append(registryEntries, ip) |
1949 | continue |
1950 | } |
1951 | var err error |
1952 | if runtime.GOOS == "windows" { |
1953 | // Don't addAddress if this is already configured. |
1954 | - if !utils.ContainsString(ip, configuredIPs) { |
1955 | + if !slices.Contains(configuredIPs, ip) { |
1956 | err = addAddress(net.ParseIP(ip), net.IPv4Mask(255, 255, 255, 255), uint32(iface.Index)) |
1957 | } |
1958 | } else { |
1959 | @@ -422,7 +388,7 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1960 | for _, ip := range toRm { |
1961 | var err error |
1962 | if runtime.GOOS == "windows" { |
1963 | - if !utils.ContainsString(ip, configuredIPs) { |
1964 | + if !slices.Contains(configuredIPs, ip) { |
1965 | continue |
1966 | } |
1967 | err = removeAddress(net.ParseIP(ip), uint32(iface.Index)) |
1968 | @@ -445,193 +411,3 @@ func (a *addressMgr) Set(ctx context.Context) error { |
1969 | |
1970 | return nil |
1971 | } |
1972 | - |
1973 | -// Enables or disables IPv6 on network interfaces. |
1974 | -func configureIPv6(ctx context.Context) error { |
1975 | - var newNi, oldNi metadata.NetworkInterfaces |
1976 | - if len(newMetadata.Instance.NetworkInterfaces) == 0 { |
1977 | - return fmt.Errorf("no interfaces found in metadata") |
1978 | - } |
1979 | - newNi = newMetadata.Instance.NetworkInterfaces[0] |
1980 | - if len(oldMetadata.Instance.NetworkInterfaces) > 0 { |
1981 | - oldNi = oldMetadata.Instance.NetworkInterfaces[0] |
1982 | - } |
1983 | - iface, err := getInterfaceByMAC(newNi.Mac) |
1984 | - if err != nil { |
1985 | - return err |
1986 | - } |
1987 | - switch { |
1988 | - case oldNi.DHCPv6Refresh != "" && newNi.DHCPv6Refresh == "", |
1989 | - newNi.DHCPv6Refresh == "" && len(oldMetadata.Instance.NetworkInterfaces) == 0: |
1990 | - // disable |
1991 | - // uses empty old interface slice to indicate this is first-run. |
1992 | - |
1993 | - // Before obtaining or releasing an IPv6 lease, we wait for |
1994 | - // 'tentative' IPs as part of SLAAC. We wait up to 5 seconds |
1995 | - // for this condition to automatically resolve. |
1996 | - tentative := []string{"-6", "-o", "a", "s", "dev", iface.Name, "scope", "link", "tentative"} |
1997 | - for i := 0; i < 5; i++ { |
1998 | - res := run.WithOutput(ctx, "ip", tentative...) |
1999 | - if res.ExitCode == 0 && res.StdOut == "" { |
2000 | - break |
2001 | - } |
2002 | - time.Sleep(1 * time.Second) |
2003 | - } |
2004 | - if err := run.Quiet(ctx, "dhclient", "-r", "-6", "-1", "-v", iface.Name); err != nil { |
2005 | - return err |
2006 | - } |
2007 | - case oldNi.DHCPv6Refresh == "" && newNi.DHCPv6Refresh != "": |
2008 | - // enable |
2009 | - tentative := []string{"-6", "-o", "a", "s", "dev", iface.Name, "scope", "link", "tentative"} |
2010 | - for i := 0; i < 5; i++ { |
2011 | - res := run.WithOutput(ctx, "ip", tentative...) |
2012 | - if res.ExitCode == 0 && res.StdOut == "" { |
2013 | - break |
2014 | - } |
2015 | - time.Sleep(1 * time.Second) |
2016 | - } |
2017 | - val := fmt.Sprintf("net.ipv6.conf.%s.accept_ra_rt_info_max_plen=128", iface.Name) |
2018 | - if err := run.Quiet(ctx, "sysctl", val); err != nil { |
2019 | - return err |
2020 | - } |
2021 | - if err := run.Quiet(ctx, "dhclient", "-1", "-6", "-v", iface.Name); err != nil { |
2022 | - return err |
2023 | - } |
2024 | - } |
2025 | - return nil |
2026 | -} |
2027 | - |
2028 | -// enableNetworkInterfaces runs `dhclient eth1 eth2 ... ethN` |
2029 | -// and `dhclient -6 eth1 eth2 ... ethN`. |
2030 | -// On RHEL7, it also calls disableNM for each interface. |
2031 | -// On SLES, it calls enableSLESInterfaces instead of dhclient. |
2032 | -func enableNetworkInterfaces(ctx context.Context, config *cfg.Sections) error { |
2033 | - if len(newMetadata.Instance.NetworkInterfaces) < 2 { |
2034 | - return nil |
2035 | - } |
2036 | - var googleInterfaces []string |
2037 | - // The primary (first) interface is managed by the OS, we only handle |
2038 | - // secondary interfaces in this code. |
2039 | - for _, ni := range newMetadata.Instance.NetworkInterfaces[1:] { |
2040 | - iface, err := getInterfaceByMAC(ni.Mac) |
2041 | - if err != nil { |
2042 | - if !utils.ContainsString(ni.Mac, badMAC) { |
2043 | - logger.Errorf("Error getting interface: %s", err) |
2044 | - badMAC = append(badMAC, ni.Mac) |
2045 | - } |
2046 | - continue |
2047 | - } |
2048 | - googleInterfaces = append(googleInterfaces, iface.Name) |
2049 | - } |
2050 | - var googleIpv6Interfaces []string |
2051 | - for _, ni := range newMetadata.Instance.NetworkInterfaces[1:] { |
2052 | - if ni.DHCPv6Refresh == "" { |
2053 | - // This interface is not IPv6 enabled |
2054 | - continue |
2055 | - } |
2056 | - iface, err := getInterfaceByMAC(ni.Mac) |
2057 | - if err != nil { |
2058 | - if !utils.ContainsString(ni.Mac, badMAC) { |
2059 | - logger.Errorf("Error getting interface: %s", err) |
2060 | - badMAC = append(badMAC, ni.Mac) |
2061 | - } |
2062 | - continue |
2063 | - } |
2064 | - googleIpv6Interfaces = append(googleIpv6Interfaces, iface.Name) |
2065 | - } |
2066 | - |
2067 | - switch { |
2068 | - case osInfo.OS == "sles": |
2069 | - return enableSLESInterfaces(ctx, googleInterfaces) |
2070 | - case (osInfo.OS == "rhel" || osInfo.OS == "centos") && osInfo.Version.Major >= 7: |
2071 | - for _, iface := range googleInterfaces { |
2072 | - err := disableNM(iface) |
2073 | - if err != nil { |
2074 | - return err |
2075 | - } |
2076 | - } |
2077 | - fallthrough |
2078 | - default: |
2079 | - dhcpCommand := config.NetworkInterfaces.DHCPCommand |
2080 | - if dhcpCommand != "" { |
2081 | - tokens := strings.Split(dhcpCommand, " ") |
2082 | - return run.Quiet(ctx, tokens[0], tokens[1:]...) |
2083 | - } |
2084 | - |
2085 | - // Try IPv4 first as it's higher priority. |
2086 | - if err := run.Quiet(ctx, "dhclient", googleInterfaces...); err != nil { |
2087 | - return err |
2088 | - } |
2089 | - |
2090 | - if len(googleIpv6Interfaces) == 0 { |
2091 | - return nil |
2092 | - } |
2093 | - for _, iface := range googleIpv6Interfaces { |
2094 | - // Enable kernel to accept to route advertisements. |
2095 | - val := fmt.Sprintf("net.ipv6.conf.%s.accept_ra_rt_info_max_plen=128", iface) |
2096 | - if err := run.Quiet(ctx, "sysctl", val); err != nil { |
2097 | - return err |
2098 | - } |
2099 | - } |
2100 | - |
2101 | - var dhclientArgs6 []string |
2102 | - dhclientArgs6 = append([]string{"-6"}, googleIpv6Interfaces...) |
2103 | - return run.Quiet(ctx, "dhclient", dhclientArgs6...) |
2104 | - } |
2105 | -} |
2106 | - |
2107 | -// enableSLESInterfaces writes one ifcfg file for each interface, then |
2108 | -// runs `wicked ifup eth1 eth2 ... ethN` |
2109 | -func enableSLESInterfaces(ctx context.Context, interfaces []string) error { |
2110 | - var err error |
2111 | - var priority = 10100 |
2112 | - for _, iface := range interfaces { |
2113 | - logger.Debugf("write enabling ifcfg-%s config", iface) |
2114 | - |
2115 | - var ifcfg *os.File |
2116 | - ifcfg, err = os.Create("/etc/sysconfig/network/ifcfg-" + iface) |
2117 | - if err != nil { |
2118 | - return err |
2119 | - } |
2120 | - defer closer(ifcfg) |
2121 | - contents := []string{ |
2122 | - googleComment, |
2123 | - "STARTMODE=hotplug", |
2124 | - // NOTE: 'dhcp' is the dhcp4+dhcp6 option. |
2125 | - "BOOTPROTO=dhcp", |
2126 | - fmt.Sprintf("DHCLIENT_ROUTE_PRIORITY=%d", priority), |
2127 | - } |
2128 | - _, err = ifcfg.WriteString(strings.Join(contents, "\n")) |
2129 | - if err != nil { |
2130 | - return err |
2131 | - } |
2132 | - priority += 100 |
2133 | - } |
2134 | - args := append([]string{"ifup", "--timeout", "1"}, interfaces...) |
2135 | - return run.Quiet(ctx, "/usr/sbin/wicked", args...) |
2136 | -} |
2137 | - |
2138 | -// disableNM writes an ifcfg file with DHCP and NetworkManager disabled. |
2139 | -func disableNM(iface string) error { |
2140 | - logger.Debugf("write disabling ifcfg-%s config", iface) |
2141 | - filename := "/etc/sysconfig/network-scripts/ifcfg-" + iface |
2142 | - ifcfg, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) |
2143 | - if err == nil { |
2144 | - defer closer(ifcfg) |
2145 | - contents := []string{ |
2146 | - googleComment, |
2147 | - fmt.Sprintf("DEVICE=%s", iface), |
2148 | - "BOOTPROTO=none", |
2149 | - "DEFROUTE=no", |
2150 | - "IPV6INIT=no", |
2151 | - "NM_CONTROLLED=no", |
2152 | - "NOZEROCONF=yes", |
2153 | - } |
2154 | - _, err = ifcfg.WriteString(strings.Join(contents, "\n")) |
2155 | - return err |
2156 | - } |
2157 | - if os.IsExist(err) { |
2158 | - return nil |
2159 | - } |
2160 | - return err |
2161 | -} |
2162 | diff --git a/google_guest_agent/addresses_integ_test.go b/google_guest_agent/addresses_integ_test.go |
2163 | deleted file mode 100644 |
2164 | index 1486986..0000000 |
2165 | --- a/google_guest_agent/addresses_integ_test.go |
2166 | +++ /dev/null |
2167 | @@ -1,99 +0,0 @@ |
2168 | -// Copyright 2021 Google LLC |
2169 | - |
2170 | -// Licensed under the Apache License, Version 2.0 (the "License"); |
2171 | -// you may not use this file except in compliance with the License. |
2172 | -// You may obtain a copy of the License at |
2173 | - |
2174 | -// https://www.apache.org/licenses/LICENSE-2.0 |
2175 | - |
2176 | -// Unless required by applicable law or agreed to in writing, software |
2177 | -// distributed under the License is distributed on an "AS IS" BASIS, |
2178 | -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
2179 | -// See the License for the specific language governing permissions and |
2180 | -// limitations under the License. |
2181 | - |
2182 | -//go:build integration |
2183 | -// +build integration |
2184 | - |
2185 | -package main |
2186 | - |
2187 | -import ( |
2188 | - "context" |
2189 | - "fmt" |
2190 | - "strings" |
2191 | - "testing" |
2192 | -) |
2193 | - |
2194 | -const testIp = "192.168.0.0" |
2195 | - |
2196 | -func TestAddAndRemoveLocalRoute(t *testing.T) { |
2197 | - metdata, err := getMetadata(context.Context(), false) |
2198 | - if err != nil { |
2199 | - t.Fatalf("failed to get metadata, err %v", err) |
2200 | - } |
2201 | - iface, err := getInterfaceByMAC(metdata.Instance.NetworkInterfaces[0].Mac) |
2202 | - if err != nil { |
2203 | - t.Fatalf("failed to get interface from mac, err %v", err) |
2204 | - } |
2205 | - // test add local route |
2206 | - if err := removeLocalRoute(testIp, iface.Name); err != nil { |
2207 | - t.Fatalf("failed to remove local route, err %v", err) |
2208 | - } |
2209 | - if err := addLocalRoute(testIp, iface.Name); err != nil { |
2210 | - t.Fatalf("add test local route should not failed, err %v", err) |
2211 | - } |
2212 | - |
2213 | - res, err := getLocalRoutes(iface.Name) |
2214 | - if err != nil { |
2215 | - t.Fatalf("get local route should not failed, err %v", err) |
2216 | - } |
2217 | - exist := false |
2218 | - for _, route := range res { |
2219 | - if strings.Contains(route, fmt.Sprintf("local %s/24", testIp)) { |
2220 | - exist = true |
2221 | - } |
2222 | - } |
2223 | - if !exist { |
2224 | - t.Fatalf("route %s is not added", testIp) |
2225 | - } |
2226 | - |
2227 | - // test remove local route |
2228 | - if err := removeLocalRoute(testIp, iface.Name); err != nil { |
2229 | - t.Fatalf("add test local route should not failed") |
2230 | - } |
2231 | - res, err := getLocalRoutes(iface.Name) |
2232 | - if err != nil { |
2233 | - t.Fatalf("ip route list should not failed, err %s", res.err) |
2234 | - } |
2235 | - |
2236 | - for _, route := range res { |
2237 | - if strings.Contains(route, fmt.Sprintf("local %s/24", testIp)) { |
2238 | - t.Fatalf("route %s should be removed but exist", testIp) |
2239 | - } |
2240 | - } |
2241 | -} |
2242 | - |
2243 | -func TestGetLocalRoute(t *testing.T) { |
2244 | - metdata, err := getMetadata(context.Context(), false) |
2245 | - if err != nil { |
2246 | - t.Fatalf("failed to get metadata, err %v", err) |
2247 | - } |
2248 | - iface, err := getInterfaceByMAC(metdata.Instance.NetworkInterfaces[0].Mac) |
2249 | - if err != nil { |
2250 | - t.Fatalf("failed to get interface from mac, err %v", err) |
2251 | - } |
2252 | - |
2253 | - if err := addLocalRoute(testIp, iface.Name); err != nil { |
2254 | - t.Fatalf("add test local route should not failed, err %v", err) |
2255 | - } |
2256 | - routes, err := getLocalRoutes(iface.Name) |
2257 | - if err != nil { |
2258 | - t.Fatalf("get local routes should not failed, err %v", err) |
2259 | - } |
2260 | - if len(routes) != 1 { |
2261 | - t.Fatal("find unexpected local route %s.", routes[0]) |
2262 | - } |
2263 | - if routes[0] != testIp { |
2264 | - t.Fatal("find unexpected local route %s.", routes[0]) |
2265 | - } |
2266 | -} |
2267 | diff --git a/google_guest_agent/agentcrypto/mtls_mds.go b/google_guest_agent/agentcrypto/mtls_mds.go |
2268 | index ad61634..ca3dfe1 100644 |
2269 | --- a/google_guest_agent/agentcrypto/mtls_mds.go |
2270 | +++ b/google_guest_agent/agentcrypto/mtls_mds.go |
2271 | @@ -35,7 +35,7 @@ import ( |
2272 | const ( |
2273 | // UEFI variables are of format {VariableName}-{VendorGUID} |
2274 | // googleGUID is Google's (vendors/variable owners) GUID used to prevent name collision with other vendors. |
2275 | - googleGUID = "8be4df61-93ca-11d2-aa0d-00e098032b8c" |
2276 | + googleGUID = "a2858e46-a37f-456a-8c79-0c1fe48b65ff" |
2277 | // googleRootCACertEFIVarName is predefined string part of the UEFI variable name that holds Root CA cert. |
2278 | googleRootCACertEFIVarName = "InstanceRootCACertificate" |
2279 | // clientCertsKey is the metadata server key at which client identity certificate is exposed. |
2280 | diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux.go b/google_guest_agent/agentcrypto/mtls_mds_linux.go |
2281 | index ee64ff1..3a75ff0 100644 |
2282 | --- a/google_guest_agent/agentcrypto/mtls_mds_linux.go |
2283 | +++ b/google_guest_agent/agentcrypto/mtls_mds_linux.go |
2284 | @@ -17,6 +17,7 @@ package agentcrypto |
2285 | import ( |
2286 | "context" |
2287 | "fmt" |
2288 | + "os" |
2289 | "os/exec" |
2290 | "path/filepath" |
2291 | |
2292 | @@ -35,8 +36,25 @@ const ( |
2293 | clientCredsFileName = "client.key" |
2294 | ) |
2295 | |
2296 | +var ( |
2297 | + // certUpdaters is a map of known CA certificate updaters with the local directory paths for certificates. |
2298 | + certUpdaters = map[string][]string{ |
2299 | + // SUSE, Debian and Ubuntu distributions. |
2300 | + // https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html |
2301 | + // https://github.com/openSUSE/ca-certificates |
2302 | + "update-ca-certificates": {"/usr/local/share/ca-certificates", "/usr/share/pki/trust/anchors"}, |
2303 | + // CentOS, Fedora, RedHat distributions. |
2304 | + // https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST |
2305 | + "update-ca-trust": {"/etc/pki/ca-trust/source/anchors"}, |
2306 | + } |
2307 | +) |
2308 | + |
2309 | // writeRootCACert writes Root CA cert from UEFI variable to output file. |
2310 | func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFile string) error { |
2311 | + // The directory should be executable, but the file does not need to be. |
2312 | + if err := os.MkdirAll(filepath.Dir(outputFile), 0655); err != nil { |
2313 | + return err |
2314 | + } |
2315 | if err := utils.SaferWriteFile(content, outputFile, 0644); err != nil { |
2316 | return err |
2317 | } |
2318 | @@ -51,39 +69,42 @@ func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFi |
2319 | |
2320 | // writeClientCredentials stores client credentials (certificate and private key). |
2321 | func (j *CredsJob) writeClientCredentials(plaintext []byte, outputFile string) error { |
2322 | + // The directory should be executable, but the file does not need to be. |
2323 | + if err := os.MkdirAll(filepath.Dir(outputFile), 0655); err != nil { |
2324 | + return err |
2325 | + } |
2326 | return utils.SaferWriteFile(plaintext, outputFile, 0644) |
2327 | } |
2328 | |
2329 | // getCAStoreUpdater iterates over known system trust store updaters and returns the first found. |
2330 | func getCAStoreUpdater() (string, error) { |
2331 | - knownUpdaters := []string{"update-ca-certificates", "update-ca-trust"} |
2332 | var errs []string |
2333 | |
2334 | - for _, u := range knownUpdaters { |
2335 | + for u := range certUpdaters { |
2336 | _, err := exec.LookPath(u) |
2337 | if err == nil { |
2338 | return u, nil |
2339 | } |
2340 | - errs = append(errs, err.Error()) |
2341 | + errs = append(errs, fmt.Sprintf("lookup for %q failed with error: %v", u, err)) |
2342 | } |
2343 | |
2344 | - return "", fmt.Errorf("no known trust updaters %v were found: %v", knownUpdaters, errs) |
2345 | + return "", fmt.Errorf("no known trust updaters were found: %v", errs) |
2346 | } |
2347 | |
2348 | // certificateDirFromUpdater returns directory of local CA certificates for the given updater tool. |
2349 | func certificateDirFromUpdater(updater string) (string, error) { |
2350 | - switch updater { |
2351 | - // SUSE, Debian and Ubuntu distributions. |
2352 | - // https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html |
2353 | - case "update-ca-certificates": |
2354 | - return "/usr/local/share/ca-certificates/", nil |
2355 | - // CentOS, Fedora, RedHat distributions. |
2356 | - // https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST/ |
2357 | - case "update-ca-trust": |
2358 | - return "/etc/pki/ca-trust/source/anchors/", nil |
2359 | - default: |
2360 | + dirs, ok := certUpdaters[updater] |
2361 | + if !ok { |
2362 | return "", fmt.Errorf("unknown updater %q, no local trusted CA certificate directory found", updater) |
2363 | } |
2364 | + |
2365 | + for _, dir := range dirs { |
2366 | + fi, err := os.Stat(dir) |
2367 | + if err == nil && fi.IsDir() { |
2368 | + return dir, nil |
2369 | + } |
2370 | + } |
2371 | + return "", fmt.Errorf("no of the known directories %v found for updater %q", dirs, updater) |
2372 | } |
2373 | |
2374 | // updateSystemStore updates the local system store with the cert. |
2375 | diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go |
2376 | index 020af48..831e02f 100644 |
2377 | --- a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go |
2378 | +++ b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go |
2379 | @@ -132,17 +132,24 @@ func TestShouldEnableError(t *testing.T) { |
2380 | } |
2381 | |
2382 | func TestCertificateDirFromUpdater(t *testing.T) { |
2383 | + updater1Dir := t.TempDir() |
2384 | + updater2Dir := t.TempDir() |
2385 | + certUpdaters = map[string][]string{ |
2386 | + "updater1": {updater1Dir}, |
2387 | + "updater2": {"/does/not/exist", updater2Dir}, |
2388 | + } |
2389 | + |
2390 | tests := []struct { |
2391 | updater string |
2392 | want string |
2393 | }{ |
2394 | { |
2395 | - updater: "update-ca-certificates", |
2396 | - want: "/usr/local/share/ca-certificates/", |
2397 | + updater: "updater1", |
2398 | + want: updater1Dir, |
2399 | }, |
2400 | { |
2401 | - updater: "update-ca-trust", |
2402 | - want: "/etc/pki/ca-trust/source/anchors/", |
2403 | + updater: "updater2", |
2404 | + want: updater2Dir, |
2405 | }, |
2406 | } |
2407 | |
2408 | @@ -160,8 +167,18 @@ func TestCertificateDirFromUpdater(t *testing.T) { |
2409 | } |
2410 | |
2411 | func TestCertificateDirFromUpdaterError(t *testing.T) { |
2412 | + // Fail for unknown updater. |
2413 | _, err := certificateDirFromUpdater("unknown") |
2414 | if err == nil { |
2415 | t.Errorf("certificateDirFromUpdater(unknown) succeeded for unknown updater, want error") |
2416 | } |
2417 | + |
2418 | + // Fail for missing known cert dir. |
2419 | + certUpdaters = map[string][]string{ |
2420 | + "updater1": {"/no/dir/exist"}, |
2421 | + } |
2422 | + _, err = certificateDirFromUpdater("updater1") |
2423 | + if err == nil { |
2424 | + t.Errorf("certificateDirFromUpdater(unknown) succeeded for missing cert dir, want error") |
2425 | + } |
2426 | } |
2427 | diff --git a/google_guest_agent/agentcrypto/mtls_mds_windows.go b/google_guest_agent/agentcrypto/mtls_mds_windows.go |
2428 | index a11f113..3f70954 100644 |
2429 | --- a/google_guest_agent/agentcrypto/mtls_mds_windows.go |
2430 | +++ b/google_guest_agent/agentcrypto/mtls_mds_windows.go |
2431 | @@ -59,6 +59,12 @@ var ( |
2432 | |
2433 | // writeRootCACert writes Root CA cert from UEFI variable to output file. |
2434 | func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile string) error { |
2435 | + // Try to fetch previous certificate's serial number before it gets overwritten. |
2436 | + num, err := serialNumber(outputFile) |
2437 | + if err != nil { |
2438 | + logger.Debugf("No previous MDS root certificate was found, will skip cleanup: %v", err) |
2439 | + } |
2440 | + |
2441 | if err := utils.SaferWriteFile(cacert, outputFile, 0644); err != nil { |
2442 | return err |
2443 | } |
2444 | @@ -84,6 +90,25 @@ func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile |
2445 | return fmt.Errorf("failed to store root cert ctx in store: %w", err) |
2446 | } |
2447 | |
2448 | + // MDS root cert was not refreshed or there's no previous cert, nothing to do, return. |
2449 | + if num == "" || fmt.Sprintf("%x", x509Cert.SerialNumber) == num { |
2450 | + return nil |
2451 | + } |
2452 | + |
2453 | + // Certificate is refreshed. Best effort to find the certcontext and delete it. |
2454 | + // Don't throw error here, it would skip client credential generation which |
2455 | + // may be about to expire. |
2456 | + oldCtx, err := findCert(root, certificateIssuer, num) |
2457 | + if err != nil { |
2458 | + logger.Warningf("Failed to find previous MDS root certificate with error: %v", err) |
2459 | + return nil |
2460 | + } |
2461 | + |
2462 | + if err := deleteCert(oldCtx, root); err != nil { |
2463 | + logger.Warningf("Failed to delete previous MDS root certificate(%s) with error: %v", num, err) |
2464 | + return nil |
2465 | + } |
2466 | + |
2467 | return nil |
2468 | } |
2469 | |
2470 | diff --git a/google_guest_agent/cfg/cfg.go b/google_guest_agent/cfg/cfg.go |
2471 | index 3a54c47..c59ec14 100644 |
2472 | --- a/google_guest_agent/cfg/cfg.go |
2473 | +++ b/google_guest_agent/cfg/cfg.go |
2474 | @@ -27,6 +27,10 @@ var ( |
2475 | // should always return it. |
2476 | instance *Sections |
2477 | |
2478 | + // configFile is a pointer to a function which takes the current OS name and returns |
2479 | + // an appropriate config file name. Replaceable by unit tests. |
2480 | + configFile = defaultConfigFile |
2481 | + |
2482 | // dataSource is a pointer to a data source loading/defining function, unit tests will |
2483 | // want to change this pointer to whatever makes sense to its implementation. |
2484 | dataSources = defaultDataSources |
2485 | @@ -87,6 +91,9 @@ setup = true |
2486 | [OSLogin] |
2487 | cert_authentication = true |
2488 | |
2489 | +[MDS] |
2490 | +mtls_bootstrapping_enabled = true |
2491 | + |
2492 | [Snapshots] |
2493 | enabled = false |
2494 | snapshot_service_ip = 169.254.169.254 |
2495 | @@ -94,7 +101,10 @@ snapshot_service_port = 8081 |
2496 | timeout_in_seconds = 60 |
2497 | |
2498 | [Unstable] |
2499 | -mds_mtls = false |
2500 | +command_monitor_enabled = false |
2501 | +command_pipe_mode = 0770 |
2502 | +command_pipe_group = |
2503 | +command_request_timeout = 10s |
2504 | ` |
2505 | ) |
2506 | |
2507 | @@ -144,6 +154,9 @@ type Sections struct { |
2508 | // OSLogin defines the OS Login configuration options. |
2509 | OSLogin *OSLogin `ini:"OSLogin,omitempty"` |
2510 | |
2511 | + // MDS defines the MDS configuration options. |
2512 | + MDS *MDS `ini:"MDS,omitempty"` |
2513 | + |
2514 | // Snapshots defines the snapshot listener configuration and behavior i.e. the server address and port. |
2515 | Snapshots *Snapshots `ini:"Snapshots,omitempty"` |
2516 | |
2517 | @@ -237,6 +250,12 @@ type OSLogin struct { |
2518 | CertAuthentication bool `ini:"cert_authentication,omitempty"` |
2519 | } |
2520 | |
2521 | +// MDS contains the configurations for MDS section. |
2522 | +type MDS struct { |
2523 | + // MTLSBootstrappingEnabled enables/disables the mTLS credential refresher. |
2524 | + MTLSBootstrappingEnabled bool `ini:"mtls_bootstrapping_enabled,omitempty"` |
2525 | +} |
2526 | + |
2527 | // NetworkInterfaces contains the configurations of NetworkInterfaces section. |
2528 | type NetworkInterfaces struct { |
2529 | DHCPCommand string `ini:"dhcp_command,omitempty"` |
2530 | @@ -256,7 +275,11 @@ type Snapshots struct { |
2531 | // is guaranteed for configurations defined in the Unstable section. By default all flags defined |
2532 | // in this section are disabled and are intended to isolate under-development features. |
2533 | type Unstable struct { |
2534 | - MDSMTLS bool `ini:"mds_mtls,omitempty"` |
2535 | + CommandMonitorEnabled bool `ini:"command_monitor_enabled,omitempty"` |
2536 | + CommandPipePath string `ini:"command_pipe_path,omitempty"` |
2537 | + CommandRequestTimeout string `ini:"command_request_timeout,omitempty"` |
2538 | + CommandPipeMode string `ini:"command_pipe_mode,omitempty"` |
2539 | + CommandPipeGroup string `ini:"command_pipe_group,omitempty"` |
2540 | } |
2541 | |
2542 | // WSFC contains the configurations of WSFC section. |
2543 | @@ -274,18 +297,17 @@ func defaultConfigFile(osName string) string { |
2544 | } |
2545 | |
2546 | func defaultDataSources(extraDefaults []byte) []interface{} { |
2547 | - var res []interface{} |
2548 | - configFile := defaultConfigFile(runtime.GOOS) |
2549 | + var res = []interface{}{[]byte(defaultConfig)} |
2550 | + config := configFile(runtime.GOOS) |
2551 | |
2552 | if len(extraDefaults) > 0 { |
2553 | res = append(res, extraDefaults) |
2554 | } |
2555 | |
2556 | return append(res, []interface{}{ |
2557 | - []byte(defaultConfig), |
2558 | - configFile, |
2559 | - configFile + ".distro", |
2560 | - configFile + ".template", |
2561 | + config, |
2562 | + config + ".distro", |
2563 | + config + ".template", |
2564 | }...) |
2565 | } |
2566 | |
2567 | diff --git a/google_guest_agent/cfg/cfg_test.go b/google_guest_agent/cfg/cfg_test.go |
2568 | index efea06f..20321c1 100644 |
2569 | --- a/google_guest_agent/cfg/cfg_test.go |
2570 | +++ b/google_guest_agent/cfg/cfg_test.go |
2571 | @@ -14,7 +14,9 @@ |
2572 | |
2573 | package cfg |
2574 | |
2575 | -import "testing" |
2576 | +import ( |
2577 | + "testing" |
2578 | +) |
2579 | |
2580 | func TestLoad(t *testing.T) { |
2581 | if err := Load(nil); err != nil { |
2582 | diff --git a/google_guest_agent/command/Readme.md b/google_guest_agent/command/Readme.md |
2583 | new file mode 100644 |
2584 | index 0000000..4fabd96 |
2585 | --- /dev/null |
2586 | +++ b/google_guest_agent/command/Readme.md |
2587 | @@ -0,0 +1,24 @@ |
2588 | +# Guest Agent Command Monitor |
2589 | +## Overview |
2590 | +The Guest Agent command monitor is a system used for executing commands in the guest agent on behalf of components in the guest os. |
2591 | + |
2592 | +The command system is formed of a **Monitor**, a **Server** and a **Handler** where the **Monitor** handles command registration for guest agent components, the **Server** is the component which listens for events from the guest os, and the **Handler** is the function executed by the agent. |
2593 | + |
2594 | +Each **Handler** is identified by a string ID, provided when sending commands to the server. Requests and responses to and from the server are structured in JSON format. A request must contain the Command field, specifying the handler to be executed. A request may contain arbitrary other fields to be passed to the handler. An example request is below: |
2595 | + |
2596 | +``` |
2597 | +{"Command":"agent.ExampleCommand","ArbitraryArgument":123} |
2598 | +``` |
2599 | + |
2600 | +A response will be valid JSON and has two required fields: Status and StatusMessage. Status is an int which follows unix status code conventions (i.e. zero is success, status codes are arbitrary and meaning is defined by the function called) and StatusMessage is an explanatory string accompanying the Status. Two example responses are below. |
2601 | + |
2602 | +``` |
2603 | +{"Status":0,"StatusMessage":""} |
2604 | + |
2605 | +{"Status":7,"StatusMessage":"Failure message"} |
2606 | +``` |
2607 | + |
2608 | +By default, the Server listens on a unix socket or a named pipe, depending on platform. Permissions for the pipe and the pipe path can be set in the guest-agent [configuration](https://github.com/GoogleCloudPlatform/guest-agent#configuration). The default pipe paths for Windows and Linux systems are `\\.\pipe\google-guest-agent-commands` and `/run/google-guest-agent/commands.sock` respectively. |
2609 | + |
2610 | +## Implementing a command handler |
2611 | +Registering a command handler will expose the handler function to be called by anyone with write permission to the underlying socket. To do so, call `command.Get().RegisterHandler(name, handlerFunc)` to get the current command monitor and register the handlerFunc with it. Note that if the command system is disabled by user configuration, handler registration will succeed but the server will not be available for callers to send commands to. |
2612 | diff --git a/google_guest_agent/command/command.go b/google_guest_agent/command/command.go |
2613 | new file mode 100644 |
2614 | index 0000000..527d20f |
2615 | --- /dev/null |
2616 | +++ b/google_guest_agent/command/command.go |
2617 | @@ -0,0 +1,146 @@ |
2618 | +// Copyright 2023 Google Inc. All Rights Reserved. |
2619 | +// |
2620 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
2621 | +// you may not use this file except in compliance with the License. |
2622 | +// You may obtain a copy of the License at |
2623 | +// |
2624 | +// http://www.apache.org/licenses/LICENSE-2.0 |
2625 | +// |
2626 | +// Unless required by applicable law or agreed to in writing, software |
2627 | +// distributed under the License is distributed on an "AS IS" BASIS, |
2628 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
2629 | +// See the License for the specific language governing permissions and |
2630 | +// limitations under the License. |
2631 | + |
2632 | +// Package command facilitates calling commands within the guest-agent. |
2633 | +package command |
2634 | + |
2635 | +import ( |
2636 | + "context" |
2637 | + "encoding/json" |
2638 | + "fmt" |
2639 | + "io" |
2640 | + |
2641 | + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
2642 | +) |
2643 | + |
2644 | +// Get returns the current command monitor which can be used to register command handlers. |
2645 | +func Get() *Monitor { |
2646 | + return cmdMonitor |
2647 | +} |
2648 | + |
2649 | +// Handler functions are the business logic of commands. They must process json |
2650 | +// encoded as a byte slice which contains a Command field and optional arbitrary |
2651 | +// data, and return json which contains a Status, StatusMessage, and optional |
2652 | +// arbitrary data (again encoded as a byte slice). Returned errors will be |
2653 | +// passed onto the command requester. |
2654 | +type Handler func([]byte) ([]byte, error) |
2655 | + |
2656 | +// Request is the basic request structure. Command determines which handler the |
2657 | +// request is routed to. Callers may set additional arbitrary fields. |
2658 | +type Request struct { |
2659 | + Command string |
2660 | +} |
2661 | + |
2662 | +// Response is the basic response structure. Handlers may set additional |
2663 | +// arbitrary fields. |
2664 | +type Response struct { |
2665 | + // Status code for the request. Meaning is defined by the caller, but |
2666 | + // conventionally zero is success. |
2667 | + Status int |
2668 | + // StatusMessage is an optional message defined by the caller. Should generally |
2669 | + // help a human understand what happened. |
2670 | + StatusMessage string |
2671 | +} |
2672 | + |
2673 | +var ( |
2674 | + // CmdNotFoundError is returned when there is no handler for the request command |
2675 | + CmdNotFoundError = Response{ |
2676 | + Status: 101, |
2677 | + StatusMessage: "Could not find a handler for the requested command", |
2678 | + } |
2679 | + // BadRequestError is returned for invalid or unparseable JSON |
2680 | + BadRequestError = Response{ |
2681 | + Status: 102, |
2682 | + StatusMessage: "Could not parse valid JSON from request", |
2683 | + } |
2684 | + // ConnError is returned for errors from the underlying communication protocol |
2685 | + ConnError = Response{ |
2686 | + Status: 103, |
2687 | + StatusMessage: "Connection error", |
2688 | + } |
2689 | + // TimeoutError is returned when the timeout period elapses before valid JSON is received |
2690 | + TimeoutError = Response{ |
2691 | + Status: 104, |
2692 | + StatusMessage: "Connection timeout before reading valid request", |
2693 | + } |
2694 | + // HandlerError is returned when the handler function returns a non-nil error. The status message will be replaced with the returned error string. |
2695 | + HandlerError = Response{ |
2696 | + Status: 105, |
2697 | + StatusMessage: "The command handler encountered an error processing your request", |
2698 | + } |
2699 | + // InternalErrorCode is the error code for internal command server errors. Returned when failing to marshal a response. |
2700 | + InternalErrorCode = 106 |
2701 | + internalError = []byte(`{"Status":106,"StatusMessage":"The command server encountered an internal error trying to respond to your request"}`) |
2702 | +) |
2703 | + |
2704 | +// RegisterHandler registers f as the handler for cmd. If a command.Server has |
2705 | +// been initialized, it will be signalled to start listening for commands. |
2706 | +func (m *Monitor) RegisterHandler(cmd string, f Handler) error { |
2707 | + m.handlersMu.Lock() |
2708 | + defer m.handlersMu.Unlock() |
2709 | + if _, ok := m.handlers[cmd]; ok { |
2710 | + return fmt.Errorf("cmd %s is already handled", cmd) |
2711 | + } |
2712 | + m.handlers[cmd] = f |
2713 | + return nil |
2714 | +} |
2715 | + |
2716 | +// UnregisterHandler clears the handlers for cmd. If a command.Server has been |
2717 | +// intialized and there are no more handlers registered, the server will be |
2718 | +// signalled to stop listening for commands. |
2719 | +func (m *Monitor) UnregisterHandler(cmd string) error { |
2720 | + m.handlersMu.Lock() |
2721 | + defer m.handlersMu.Unlock() |
2722 | + if _, ok := m.handlers[cmd]; !ok { |
2723 | + return fmt.Errorf("cmd %s is not registered", cmd) |
2724 | + } |
2725 | + delete(m.handlers, cmd) |
2726 | + return nil |
2727 | +} |
2728 | + |
2729 | +// SendCommand sends a command request over the configured pipe. |
2730 | +func SendCommand(ctx context.Context, req []byte) []byte { |
2731 | + pipe := cfg.Get().Unstable.CommandPipePath |
2732 | + if pipe == "" { |
2733 | + pipe = DefaultPipePath |
2734 | + } |
2735 | + return SendCmdPipe(ctx, pipe, req) |
2736 | +} |
2737 | + |
2738 | +// SendCmdPipe sends a command request over a specific pipe. Most callers |
2739 | +// should use SendCommand() instead. |
2740 | +func SendCmdPipe(ctx context.Context, pipe string, req []byte) []byte { |
2741 | + conn, err := dialPipe(ctx, pipe) |
2742 | + if err != nil { |
2743 | + if b, err := json.Marshal(ConnError); err != nil { |
2744 | + return b |
2745 | + } |
2746 | + return internalError |
2747 | + } |
2748 | + i, err := conn.Write(req) |
2749 | + if err != nil || i != len(req) { |
2750 | + if b, err := json.Marshal(ConnError); err != nil { |
2751 | + return b |
2752 | + } |
2753 | + return internalError |
2754 | + } |
2755 | + data, err := io.ReadAll(conn) |
2756 | + if err != nil { |
2757 | + if b, err := json.Marshal(ConnError); err != nil { |
2758 | + return b |
2759 | + } |
2760 | + return internalError |
2761 | + } |
2762 | + return data |
2763 | +} |
2764 | diff --git a/google_guest_agent/command/command_linux.go b/google_guest_agent/command/command_linux.go |
2765 | new file mode 100644 |
2766 | index 0000000..cf7dab1 |
2767 | --- /dev/null |
2768 | +++ b/google_guest_agent/command/command_linux.go |
2769 | @@ -0,0 +1,140 @@ |
2770 | +// Copyright 2023 Google Inc. All Rights Reserved. |
2771 | +// |
2772 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
2773 | +// you may not use this file except in compliance with the License. |
2774 | +// You may obtain a copy of the License at |
2775 | +// |
2776 | +// http://www.apache.org/licenses/LICENSE-2.0 |
2777 | +// |
2778 | +// Unless required by applicable law or agreed to in writing, software |
2779 | +// distributed under the License is distributed on an "AS IS" BASIS, |
2780 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
2781 | +// See the License for the specific language governing permissions and |
2782 | +// limitations under the License. |
2783 | + |
2784 | +package command |
2785 | + |
2786 | +import ( |
2787 | + "context" |
2788 | + "fmt" |
2789 | + "net" |
2790 | + "os" |
2791 | + "os/user" |
2792 | + "path" |
2793 | + "runtime" |
2794 | + "strconv" |
2795 | + "syscall" |
2796 | + |
2797 | + "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
2798 | +) |
2799 | + |
2800 | +// DefaultPipePath is the default unix socket path for linux. |
2801 | +const DefaultPipePath = "/run/google-guest-agent/commands.sock" |
2802 | + |
2803 | +func mkdirpWithPerms(dir string, p os.FileMode, uid, gid int) error { |
2804 | + stat, err := os.Stat(dir) |
2805 | + if err == nil { |
2806 | + statT, ok := stat.Sys().(*syscall.Stat_t) |
2807 | + if !ok { |
2808 | + return fmt.Errorf("could not determine owner of %s", dir) |
2809 | + } |
2810 | + if !stat.IsDir() { |
2811 | + return fmt.Errorf("%s exists and is not a directory", dir) |
2812 | + } |
2813 | + if morePermissive(int(stat.Mode()), int(p)) { |
2814 | + if err := os.Chmod(dir, p); err != nil { |
2815 | + return fmt.Errorf("could not correct %s permissions to %d: %v", dir, p, err) |
2816 | + } |
2817 | + } |
2818 | + if statT.Uid != 0 && statT.Uid != uint32(uid) { |
2819 | + if err := os.Chown(dir, uid, -1); err != nil { |
2820 | + return fmt.Errorf("could not correct %s owner to %d: %v", dir, uid, err) |
2821 | + } |
2822 | + } |
2823 | + if statT.Gid != 0 && statT.Gid != uint32(gid) { |
2824 | + if err := os.Chown(dir, -1, gid); err != nil { |
2825 | + return fmt.Errorf("could not correct %s group to %d: %v", dir, gid, err) |
2826 | + } |
2827 | + } |
2828 | + } else { |
2829 | + parent, _ := path.Split(dir) |
2830 | + if parent != "/" && parent != "" { |
2831 | + if err := mkdirpWithPerms(parent, p, uid, gid); err != nil { |
2832 | + return err |
2833 | + } |
2834 | + } |
2835 | + if err := os.Mkdir(dir, p); err != nil { |
2836 | + return err |
2837 | + } |
2838 | + } |
2839 | + return nil |
2840 | +} |
2841 | + |
// morePermissive reports whether any of the three lowest octal digits of i
// (the other, group and user permission digits of a unix file mode) is
// numerically greater than the corresponding digit of j. Bits above the low
// nine, such as the directory bit of an os.FileMode, are ignored.
func morePermissive(i, j int) bool {
	for k := 0; k < 3; k++ {
		// Compare one octal digit at a time; both operands must be reduced
		// modulo 8 (010), not 10.
		if i%010 > j%010 {
			return true
		}
		i /= 010
		j /= 010
	}
	return false
}
2852 | + |
2853 | +func listen(ctx context.Context, pipe string, filemode int, grp string) (net.Listener, error) { |
2854 | + // If grp is an int, use it as a GID |
2855 | + gid, err := strconv.Atoi(grp) |
2856 | + if err != nil { |
2857 | + // Otherwise lookup GID |
2858 | + group, err := user.LookupGroup(grp) |
2859 | + if err != nil { |
2860 | + logger.Errorf("guest agent command pipe group %s is not a GID nor a valid group, not changing socket ownership", grp) |
2861 | + gid = -1 |
2862 | + } else { |
2863 | + gid, err = strconv.Atoi(group.Gid) |
2864 | + if err != nil { |
2865 | + logger.Errorf("os reported group %s has gid %s which is not a valid int, not changing socket ownership. this should never happen", grp, group.Gid) |
2866 | + gid = -1 |
2867 | + } |
2868 | + } |
2869 | + } |
2870 | + // socket owner group does not need to have permissions to everything in the directory containing it, whatever user and group we are should own that |
2871 | + user, err := user.Current() |
2872 | + if err != nil { |
2873 | + return nil, fmt.Errorf("could not lookup current user") |
2874 | + } |
2875 | + currentuid, err := strconv.Atoi(user.Uid) |
2876 | + if err != nil { |
2877 | + return nil, fmt.Errorf("os reported user %s has uid %s which is not a valid int, can't determine directory owner. this should never happen", user.Username, user.Uid) |
2878 | + } |
2879 | + currentgid, err := strconv.Atoi(user.Gid) |
2880 | + if err != nil { |
2881 | + return nil, fmt.Errorf("os reported user %s has gid %s which is not a valid int, can't determine directory owner. this should never happen", user.Username, user.Gid) |
2882 | + } |
2883 | + if err := mkdirpWithPerms(path.Dir(pipe), os.FileMode(filemode), currentuid, currentgid); err != nil { |
2884 | + return nil, err |
2885 | + } |
2886 | + // Mutating the umask of the process for this is not ideal, but tightening permissions with chown after creation is not really secure. |
2887 | + // Lock OS thread while mutating umask so we don't lose a thread with a mutated mask. |
2888 | + runtime.LockOSThread() |
2889 | + oldmask := syscall.Umask(777 - filemode) |
2890 | + var lc net.ListenConfig |
2891 | + l, err := lc.Listen(ctx, "unix", pipe) |
2892 | + syscall.Umask(oldmask) |
2893 | + runtime.UnlockOSThread() |
2894 | + if err != nil { |
2895 | + return nil, err |
2896 | + } |
2897 | + // But we need to chown anyway to loosen permissions to include whatever group the user has configured |
2898 | + err = os.Chown(pipe, int(currentuid), gid) |
2899 | + if err != nil { |
2900 | + l.Close() |
2901 | + return nil, err |
2902 | + } |
2903 | + return l, nil |
2904 | +} |
2905 | + |
// dialPipe connects to the unix socket at pipe, honouring ctx for
// cancellation, and returns the resulting connection.
func dialPipe(ctx context.Context, pipe string) (net.Conn, error) {
	d := new(net.Dialer)
	return d.DialContext(ctx, "unix", pipe)
}
2910 | diff --git a/google_guest_agent/command/command_monitor.go b/google_guest_agent/command/command_monitor.go |
2911 | new file mode 100644 |
2912 | index 0000000..6e1ee86 |
2913 | --- /dev/null |
2914 | +++ b/google_guest_agent/command/command_monitor.go |
2915 | @@ -0,0 +1,228 @@ |
2916 | +// Copyright 2023 Google Inc. All Rights Reserved. |
2917 | +// |
2918 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
2919 | +// you may not use this file except in compliance with the License. |
2920 | +// You may obtain a copy of the License at |
2921 | +// |
2922 | +// http://www.apache.org/licenses/LICENSE-2.0 |
2923 | +// |
2924 | +// Unless required by applicable law or agreed to in writing, software |
2925 | +// distributed under the License is distributed on an "AS IS" BASIS, |
2926 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
2927 | +// See the License for the specific language governing permissions and |
2928 | +// limitations under the License. |
2929 | + |
2930 | +/* |
2931 | + * This file contains the details of command's internal communication protocol |
2932 | + * listener. Most callers should not need to call anything in this file. The |
2933 | + * command handler and caller API is contained in command.go. |
2934 | + */ |
2935 | + |
2936 | +package command |
2937 | + |
2938 | +import ( |
2939 | + "bufio" |
2940 | + "context" |
2941 | + "encoding/json" |
2942 | + "errors" |
2943 | + "net" |
2944 | + "os" |
2945 | + "strconv" |
2946 | + "sync" |
2947 | + "time" |
2948 | + |
2949 | + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
2950 | + "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
2951 | +) |
2952 | + |
2953 | +var cmdMonitor *Monitor = &Monitor{ |
2954 | + handlersMu: new(sync.RWMutex), |
2955 | + handlers: make(map[string]Handler), |
2956 | +} |
2957 | + |
2958 | +// Init starts an internally managed command server. The agent configuration |
2959 | +// will decide the server options. Returns a reference to the internally managed |
2960 | +// command monitor which the caller can Close() when appropriate. |
2961 | +func Init(ctx context.Context) { |
2962 | + if cmdMonitor.srv != nil { |
2963 | + return |
2964 | + } |
2965 | + pipe := cfg.Get().Unstable.CommandPipePath |
2966 | + if pipe == "" { |
2967 | + pipe = DefaultPipePath |
2968 | + } |
2969 | + to, err := time.ParseDuration(cfg.Get().Unstable.CommandRequestTimeout) |
2970 | + if err != nil { |
2971 | + logger.Errorf("commmand request timeout configuration is not a valid duration string, falling back to 10s timeout") |
2972 | + to = time.Duration(10) * time.Second |
2973 | + } |
2974 | + var pipemode int64 = 0770 |
2975 | + pipemode, err = strconv.ParseInt(cfg.Get().Unstable.CommandPipeMode, 8, 32) |
2976 | + if err != nil { |
2977 | + logger.Errorf("could not parse command_pipe_mode as octal integer: %v falling back to mode 0770", err) |
2978 | + } |
2979 | + cmdMonitor.srv = &Server{ |
2980 | + pipe: pipe, |
2981 | + pipeMode: int(pipemode), |
2982 | + pipeGroup: cfg.Get().Unstable.CommandPipeGroup, |
2983 | + timeout: to, |
2984 | + monitor: cmdMonitor, |
2985 | + } |
2986 | + err = cmdMonitor.srv.start(ctx) |
2987 | + if err != nil { |
2988 | + logger.Errorf("failed to start command server: %s", err) |
2989 | + } |
2990 | +} |
2991 | + |
2992 | +// Close will close the internally managed command server, if it was initialized. |
2993 | +func Close() error { |
2994 | + if cmdMonitor.srv != nil { |
2995 | + return cmdMonitor.srv.Close() |
2996 | + } |
2997 | + return nil |
2998 | +} |
2999 | + |
// Monitor is the structure which handles command registration and deregistration.
type Monitor struct {
	// srv is the command server attached to this monitor; Close and Start
	// delegate to it.
	srv *Server
	// handlersMu guards handlers.
	handlersMu *sync.RWMutex
	// handlers maps a command name to the Handler registered for it.
	handlers map[string]Handler
}
3006 | + |
3007 | +// Close stops the server from listening to commands. |
3008 | +func (m *Monitor) Close() error { return m.srv.Close() } |
3009 | + |
3010 | +// Start begins listening for commands. |
3011 | +func (m *Monitor) Start(ctx context.Context) error { return m.srv.start(ctx) } |
3012 | + |
// Server is the server structure which will listen for command requests and
// route them to handlers. Most callers should not interact with this directly.
type Server struct {
	// pipe is the path handed to listen when the server starts.
	pipe string
	// pipeMode is the file mode handed to listen for the pipe.
	pipeMode int
	// pipeGroup is the group (name or ID) handed to listen for the pipe.
	pipeGroup string
	// timeout bounds how long a connection may take to deliver its request.
	timeout time.Duration
	// srv is the active listener; it is nil until start succeeds.
	srv net.Listener
	// monitor supplies the handlers that requests are routed to.
	monitor *Monitor
}
3023 | + |
3024 | +// Close signals the server to stop listening for commands and stop waiting to |
3025 | +// listen. |
3026 | +func (c *Server) Close() error { |
3027 | + if c.srv != nil { |
3028 | + return c.srv.Close() |
3029 | + } |
3030 | + return nil |
3031 | +} |
3032 | + |
3033 | +func (c *Server) start(ctx context.Context) error { |
3034 | + if c.srv != nil { |
3035 | + return errors.New("server already listening") |
3036 | + } |
3037 | + srv, err := listen(ctx, c.pipe, c.pipeMode, c.pipeGroup) |
3038 | + if err != nil { |
3039 | + return err |
3040 | + } |
3041 | + go func() { |
3042 | + defer srv.Close() |
3043 | + for { |
3044 | + if ctx.Err() != nil { |
3045 | + return |
3046 | + } |
3047 | + conn, err := srv.Accept() |
3048 | + if err != nil { |
3049 | + if err == net.ErrClosed { |
3050 | + break |
3051 | + } |
3052 | + logger.Infof("error on connection to pipe %s: %v", c.pipe, err) |
3053 | + continue |
3054 | + } |
3055 | + go func(conn net.Conn) { |
3056 | + defer conn.Close() |
3057 | + // Go has lots of helpers to do this for us but none of them return the byte |
3058 | + // slice afterwards, and we need it for the handler |
3059 | + var b []byte |
3060 | + r := bufio.NewReader(conn) |
3061 | + var depth int |
3062 | + deadline := time.Now().Add(c.timeout) |
3063 | + e := conn.SetReadDeadline(deadline) |
3064 | + if e != nil { |
3065 | + logger.Infof("could not set read deadline on command request: %v", e) |
3066 | + return |
3067 | + } |
3068 | + for { |
3069 | + if time.Now().After(deadline) { |
3070 | + if b, err := json.Marshal(TimeoutError); err != nil { |
3071 | + conn.Write(internalError) |
3072 | + } else { |
3073 | + conn.Write(b) |
3074 | + } |
3075 | + return |
3076 | + } |
3077 | + rune, _, err := r.ReadRune() |
3078 | + if err != nil { |
3079 | + logger.Debugf("connection read error: %v", err) |
3080 | + if errors.Is(err, os.ErrDeadlineExceeded) { |
3081 | + if b, err := json.Marshal(TimeoutError); err != nil { |
3082 | + conn.Write(internalError) |
3083 | + } else { |
3084 | + conn.Write(b) |
3085 | + } |
3086 | + } else { |
3087 | + if b, err := json.Marshal(ConnError); err != nil { |
3088 | + conn.Write(internalError) |
3089 | + } else { |
3090 | + conn.Write(b) |
3091 | + } |
3092 | + } |
3093 | + return |
3094 | + } |
3095 | + b = append(b, byte(rune)) |
3096 | + switch rune { |
3097 | + case '{': |
3098 | + depth++ |
3099 | + case '}': |
3100 | + depth-- |
3101 | + } |
3102 | + // Must check here because the first pass always depth = 0 |
3103 | + if depth == 0 { |
3104 | + break |
3105 | + } |
3106 | + } |
3107 | + var req Request |
3108 | + err := json.Unmarshal(b, &req) |
3109 | + if err != nil { |
3110 | + if b, err := json.Marshal(BadRequestError); err != nil { |
3111 | + conn.Write(internalError) |
3112 | + } else { |
3113 | + conn.Write(b) |
3114 | + } |
3115 | + return |
3116 | + } |
3117 | + c.monitor.handlersMu.RLock() |
3118 | + defer c.monitor.handlersMu.RUnlock() |
3119 | + handler, ok := c.monitor.handlers[req.Command] |
3120 | + if !ok { |
3121 | + if b, err := json.Marshal(CmdNotFoundError); err != nil { |
3122 | + conn.Write(internalError) |
3123 | + } else { |
3124 | + conn.Write(b) |
3125 | + } |
3126 | + return |
3127 | + } |
3128 | + resp, err := handler(b) |
3129 | + if err != nil { |
3130 | + re := Response{Status: HandlerError.Status, StatusMessage: err.Error()} |
3131 | + if b, err := json.Marshal(re); err != nil { |
3132 | + resp = internalError |
3133 | + } else { |
3134 | + resp = b |
3135 | + } |
3136 | + } |
3137 | + conn.Write(resp) |
3138 | + }(conn) |
3139 | + } |
3140 | + }() |
3141 | + c.srv = srv |
3142 | + return nil |
3143 | +} |
3144 | diff --git a/google_guest_agent/command/command_test.go b/google_guest_agent/command/command_test.go |
3145 | new file mode 100644 |
3146 | index 0000000..96c2d9f |
3147 | --- /dev/null |
3148 | +++ b/google_guest_agent/command/command_test.go |
3149 | @@ -0,0 +1,209 @@ |
3150 | +// Copyright 2023 Google Inc. All Rights Reserved. |
3151 | +// |
3152 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
3153 | +// you may not use this file except in compliance with the License. |
3154 | +// You may obtain a copy of the License at |
3155 | +// |
3156 | +// http://www.apache.org/licenses/LICENSE-2.0 |
3157 | +// |
3158 | +// Unless required by applicable law or agreed to in writing, software |
3159 | +// distributed under the License is distributed on an "AS IS" BASIS, |
3160 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
3161 | +// See the License for the specific language governing permissions and |
3162 | +// limitations under the License. |
3163 | + |
3164 | +package command |
3165 | + |
3166 | +import ( |
3167 | + "context" |
3168 | + "encoding/json" |
3169 | + "fmt" |
3170 | + "io" |
3171 | + "math/rand" |
3172 | + "os/user" |
3173 | + "path" |
3174 | + "runtime" |
3175 | + "sync" |
3176 | + "testing" |
3177 | + "time" |
3178 | + |
3179 | + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
3180 | +) |
3181 | + |
3182 | +func cmdServerForTest(t *testing.T, pipeMode int, pipeGroup string, timeout time.Duration) *Server { |
3183 | + cs := &Server{ |
3184 | + pipe: getTestPipePath(t), |
3185 | + pipeMode: pipeMode, |
3186 | + pipeGroup: pipeGroup, |
3187 | + timeout: timeout, |
3188 | + monitor: &Monitor{ |
3189 | + handlersMu: new(sync.RWMutex), |
3190 | + handlers: make(map[string]Handler), |
3191 | + }, |
3192 | + } |
3193 | + cs.monitor.srv = cs |
3194 | + err := cs.start(testctx(t)) |
3195 | + if err != nil { |
3196 | + t.Fatal(err) |
3197 | + } |
3198 | + t.Cleanup(func() { |
3199 | + err := cs.Close() |
3200 | + if err != nil { |
3201 | + t.Errorf("error closing command server: %v", err) |
3202 | + } |
3203 | + }) |
3204 | + return cs |
3205 | +} |
3206 | + |
// getTestPipePath returns a pipe path for the running test: a named pipe
// suffixed with the test name on windows, otherwise run/pipe beneath a fresh
// temporary directory.
func getTestPipePath(t *testing.T) string {
	switch runtime.GOOS {
	case "windows":
		return `\\.\pipe\google-guest-agent-commands-test-` + t.Name()
	default:
		return path.Join(t.TempDir(), "run", "pipe")
	}
}
3213 | + |
// testctx returns a context that is cancelled when the test finishes. If the
// test has a deadline, the context carries that deadline as well.
func testctx(t *testing.T) context.Context {
	var (
		ctx    context.Context
		cancel context.CancelFunc
	)
	if d, ok := t.Deadline(); ok {
		ctx, cancel = context.WithDeadline(context.Background(), d)
	} else {
		ctx, cancel = context.WithCancel(context.Background())
	}
	t.Cleanup(cancel)
	return ctx
}
3225 | + |
// testRequest is the request shape decoded by the test handlers in this file.
type testRequest struct {
	// Command is the name of the command being requested.
	Command string
	// ArbitraryData is an extra payload field the TestListen handler checks.
	ArbitraryData int
}
3230 | + |
// TestInit checks that Init creates the package-level command server on a
// test-specific pipe and that Close shuts it down without error.
func TestInit(t *testing.T) {
	cfg.Load(nil)
	// Point the agent configuration at a per-test pipe path.
	cfg.Get().Unstable.CommandPipePath = getTestPipePath(t)
	// Init is a no-op when a server already exists, so the check below would
	// be meaningless in that case.
	if cmdMonitor.srv != nil {
		t.Fatal("internal command server already exists")
	}
	Init(testctx(t))
	if cmdMonitor.srv == nil {
		t.Errorf("could not start internally managed command server")
	}
	if err := Close(); err != nil {
		t.Errorf("could not close managed command server: %s", err)
	}
}
3245 | + |
3246 | +func TestListen(t *testing.T) { |
3247 | + cu, err := user.Current() |
3248 | + if err != nil { |
3249 | + t.Fatalf("could not get current user: %v", err) |
3250 | + } |
3251 | + ug, err := cu.GroupIds() |
3252 | + if err != nil { |
3253 | + t.Fatalf("could not get user groups for %s: %v", cu.Name, err) |
3254 | + } |
3255 | + resp := []byte(`{"Status":0,"StatusMessage":"OK"}`) |
3256 | + errresp := []byte(`{"Status":1,"StatusMessage":"ERR"}`) |
3257 | + req := []byte(`{"ArbitraryData":1234,"Command":"TestListen"}`) |
3258 | + h := func(b []byte) ([]byte, error) { |
3259 | + var r testRequest |
3260 | + err := json.Unmarshal(b, &r) |
3261 | + if err != nil || r.ArbitraryData != 1234 { |
3262 | + return errresp, nil |
3263 | + } |
3264 | + return resp, nil |
3265 | + } |
3266 | + |
3267 | + testcases := []struct { |
3268 | + name string |
3269 | + filemode int |
3270 | + group string |
3271 | + }{ |
3272 | + { |
3273 | + name: "world read/writeable", |
3274 | + filemode: 0777, |
3275 | + group: "-1", |
3276 | + }, |
3277 | + { |
3278 | + name: "group read/writeable", |
3279 | + filemode: 0770, |
3280 | + group: "-1", |
3281 | + }, |
3282 | + { |
3283 | + name: "user read/writeable", |
3284 | + filemode: 0700, |
3285 | + group: "-1", |
3286 | + }, |
3287 | + { |
3288 | + name: "additional user group as group owner", |
3289 | + filemode: 0770, |
3290 | + group: ug[rand.Intn(len(ug))], |
3291 | + }, |
3292 | + } |
3293 | + for _, tc := range testcases { |
3294 | + t.Run(tc.name, func(t *testing.T) { |
3295 | + cs := cmdServerForTest(t, tc.filemode, tc.group, time.Second) |
3296 | + err := cs.monitor.RegisterHandler("TestListen", h) |
3297 | + if err != nil { |
3298 | + t.Errorf("could not register handler: %v", err) |
3299 | + } |
3300 | + d := SendCmdPipe(testctx(t), cs.pipe, req) |
3301 | + var r Response |
3302 | + err = json.Unmarshal(d, &r) |
3303 | + if err != nil { |
3304 | + t.Error(err) |
3305 | + } |
3306 | + if r.Status != 0 || r.StatusMessage != "OK" { |
3307 | + t.Errorf("unexpected status from test-cmd, want 0, \"OK\" but got %d, %q", r.Status, r.StatusMessage) |
3308 | + } |
3309 | + }) |
3310 | + } |
3311 | +} |
3312 | + |
3313 | +func TestHandlerFailure(t *testing.T) { |
3314 | + req := []byte(`{"Command":"TestHandlerFailure"}`) |
3315 | + h := func(b []byte) ([]byte, error) { |
3316 | + return nil, fmt.Errorf("always fail") |
3317 | + } |
3318 | + |
3319 | + cs := cmdServerForTest(t, 0777, "-1", time.Second) |
3320 | + cs.monitor.RegisterHandler("TestHandlerFailure", h) |
3321 | + d := SendCmdPipe(testctx(t), cs.pipe, req) |
3322 | + var r Response |
3323 | + err := json.Unmarshal(d, &r) |
3324 | + if err != nil { |
3325 | + t.Error(err) |
3326 | + } |
3327 | + if r.Status != HandlerError.Status || r.StatusMessage != "always fail" { |
3328 | + t.Errorf("unexpected status from TestHandlerFailure, want %d, \"always fail\" but got %d, %q", HandlerError.Status, r.Status, r.StatusMessage) |
3329 | + } |
3330 | +} |
3331 | + |
3332 | +func TestListenTimeout(t *testing.T) { |
3333 | + expect, err := json.Marshal(TimeoutError) |
3334 | + if err != nil { |
3335 | + t.Fatal(err) |
3336 | + } |
3337 | + if runtime.GOOS == "windows" { |
3338 | + // winio library does not surface timeouts from the underlying net.Conn as |
3339 | + // timeouts, but as generic errors. Timeouts still work they just can't be |
3340 | + // detected as timeouts, so they are generic connErrors here. |
3341 | + expect, err = json.Marshal(ConnError) |
3342 | + if err != nil { |
3343 | + t.Fatal(err) |
3344 | + } |
3345 | + } |
3346 | + cs := cmdServerForTest(t, 0770, "-1", time.Millisecond) |
3347 | + conn, err := dialPipe(testctx(t), cs.pipe) |
3348 | + if err != nil { |
3349 | + t.Errorf("could not connect to command server: %v", err) |
3350 | + } |
3351 | + data, err := io.ReadAll(conn) |
3352 | + if err != nil { |
3353 | + t.Errorf("error reading response from command server: %v", err) |
3354 | + } |
3355 | + if string(data) != string(expect) { |
3356 | + t.Errorf("unexpected response from timed out connection, got %s but want %s", data, expect) |
3357 | + } |
3358 | +} |
3359 | diff --git a/google_guest_agent/command/command_windows.go b/google_guest_agent/command/command_windows.go |
3360 | new file mode 100644 |
3361 | index 0000000..0ac7213 |
3362 | --- /dev/null |
3363 | +++ b/google_guest_agent/command/command_windows.go |
3364 | @@ -0,0 +1,104 @@ |
3365 | +// Copyright 2023 Google Inc. All Rights Reserved. |
3366 | +// |
3367 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
3368 | +// you may not use this file except in compliance with the License. |
3369 | +// You may obtain a copy of the License at |
3370 | +// |
3371 | +// http://www.apache.org/licenses/LICENSE-2.0 |
3372 | +// |
3373 | +// Unless required by applicable law or agreed to in writing, software |
3374 | +// distributed under the License is distributed on an "AS IS" BASIS, |
3375 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
3376 | +// See the License for the specific language governing permissions and |
3377 | +// limitations under the License. |
3378 | + |
3379 | +package command |
3380 | + |
3381 | +import ( |
3382 | + "context" |
3383 | + "fmt" |
3384 | + "net" |
3385 | + "os/user" |
3386 | + |
3387 | + "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
3388 | + "github.com/Microsoft/go-winio" |
3389 | +) |
3390 | + |
const (
	// DefaultPipePath is the default named pipe path for windows.
	DefaultPipePath = `\\.\pipe\google-guest-agent-commands`
	// nullSID is used in the security descriptor for an owner or group that is
	// not granted read/write access.
	nullSID = "S-1-0-0"
	// worldSID is used as both owner and group when the mode grants world
	// read/write access.
	worldSID = "S-1-1-0"
	// creatorOwnerSID is the owner used when the mode grants the user
	// read/write access.
	creatorOwnerSID = "S-1-3-0"
	// creatorGroupSID is the group used when the mode grants the group
	// read/write access.
	creatorGroupSID = "S-1-3-1"
)
3399 | + |
3400 | +func genSecurityDescriptor(filemode int, grp string) string { |
3401 | + // This function translates the intention of a unix file mode and owner group into an appropriate SDDL security descriptor for a windows named pipe. |
3402 | + owner := creatorOwnerSID |
3403 | + group := creatorGroupSID |
3404 | + |
3405 | + wPerm := filemode % 010 |
3406 | + filemode /= 010 |
3407 | + gPerm := filemode % 010 |
3408 | + filemode /= 010 |
3409 | + uPerm := filemode % 010 |
3410 | + |
3411 | + // Having only read or only write access to a bidirectional pipe is pointless so we treat access for user/group as yes or no based on whether the permission grants RW access |
3412 | + if uPerm < 06 { |
3413 | + owner = nullSID |
3414 | + } |
3415 | + if gPerm < 06 { |
3416 | + group = nullSID |
3417 | + } |
3418 | + // If permissions grant world RW, make world the owner |
3419 | + if wPerm > 05 { |
3420 | + owner = worldSID |
3421 | + group = worldSID |
3422 | + } |
3423 | + |
3424 | + // Group is handled as supplemental DACL, but ignore it if user specified no group rw permission |
3425 | + var dacl string |
3426 | + if gPerm > 05 { |
3427 | + g, err := user.LookupGroup(grp) |
3428 | + if err != nil { |
3429 | + logger.Errorf("Could not lookup group %s SID, this group will not be included in the command server security descriptor: %v", grp, err) |
3430 | + } else { |
3431 | + // Allow access;Protected DACL;Allow all general access;Empty object guid;Empty inherit object guid;group sid from lookup |
3432 | + dacl = fmt.Sprintf("D:(A;P;GA;;;%s)", g.Gid) |
3433 | + } |
3434 | + } |
3435 | + |
3436 | + sddl := "O:%sG:%s%s" |
3437 | + return fmt.Sprintf(sddl, owner, group, dacl) |
3438 | +} |
3439 | + |
3440 | +func listen(ctx context.Context, path string, filemode int, group string) (net.Listener, error) { |
3441 | + // Winio library does not provide any method to listen on context. Failing to |
3442 | + // specify a pipeconfig (or using the zero value) results in flaky ACCESS_DENIED |
3443 | + // errors when re-opening the same pipe (~1/10). |
3444 | + // https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#remarks |
3445 | + // Even with a pipeconfig, this flakes ~1/200 runs, hence the retry until the |
3446 | + // context is expired or listen is successful. |
3447 | + var l net.Listener |
3448 | + var lastError error |
3449 | + for { |
3450 | + if ctx.Err() != nil { |
3451 | + return nil, fmt.Errorf("context expired: %v before successful listen (last error: %v)", ctx.Err(), lastError) |
3452 | + } |
3453 | + config := &winio.PipeConfig{ |
3454 | + MessageMode: false, |
3455 | + InputBufferSize: 1024, |
3456 | + OutputBufferSize: 1024, |
3457 | + SecurityDescriptor: genSecurityDescriptor(filemode, group), |
3458 | + } |
3459 | + l, lastError = winio.ListenPipe(path, config) |
3460 | + if lastError == nil { |
3461 | + return l, lastError |
3462 | + } |
3463 | + } |
3464 | +} |
3465 | + |
// dialPipe connects to the windows named pipe at pipe using winio, passing ctx
// through to the dial call, and returns the resulting connection.
func dialPipe(ctx context.Context, pipe string) (net.Conn, error) {
	return winio.DialPipeContext(ctx, pipe)
}
3469 | diff --git a/google_guest_agent/command/command_windows_test.go b/google_guest_agent/command/command_windows_test.go |
3470 | new file mode 100644 |
3471 | index 0000000..b4789e2 |
3472 | --- /dev/null |
3473 | +++ b/google_guest_agent/command/command_windows_test.go |
3474 | @@ -0,0 +1,73 @@ |
3475 | +//go:build windows |
3476 | + |
3477 | +// Copyright 2023 Google Inc. All Rights Reserved. |
3478 | +// |
3479 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
3480 | +// you may not use this file except in compliance with the License. |
3481 | +// You may obtain a copy of the License at |
3482 | +// |
3483 | +// http://www.apache.org/licenses/LICENSE-2.0 |
3484 | +// |
3485 | +// Unless required by applicable law or agreed to in writing, software |
3486 | +// distributed under the License is distributed on an "AS IS" BASIS, |
3487 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
3488 | +// See the License for the specific language governing permissions and |
3489 | +// limitations under the License. |
3490 | +package command |
3491 | + |
3492 | +import ( |
3493 | + "os/user" |
3494 | + "testing" |
3495 | +) |
3496 | + |
// TestGenSecurityDescriptor checks the SDDL string produced by
// genSecurityDescriptor for a range of file modes and groups.
func TestGenSecurityDescriptor(t *testing.T) {
	// The "Guests" group is looked up so the expected DACL entry can carry the
	// same group ID that genSecurityDescriptor will resolve.
	guest, err := user.LookupGroup("Guests")
	if err != nil {
		t.Fatal(err)
	}
	testcases := []struct {
		name     string
		filemode int
		group    string
		output   string
	}{
		{
			// World read/write makes world both owner and group.
			name:     "world writeable",
			filemode: 0777,
			group:    nullSID,
			output:   "O:" + worldSID + "G:" + worldSID,
		},
		{
			// The empty group name is expected to fail lookup, leaving no DACL entry.
			name:     "user+group writable",
			filemode: 0770,
			group:    "",
			output:   "O:" + creatorOwnerSID + "G:" + creatorGroupSID,
		},
		{
			name:     "user writable",
			filemode: 0700,
			group:    nullSID,
			output:   "O:" + creatorOwnerSID + "G:" + nullSID,
		},
		{
			name:     "no write permissions",
			filemode: 000,
			group:    nullSID,
			output:   "O:" + nullSID + "G:" + nullSID,
		},
		{
			// A resolvable group with group read/write adds a DACL entry.
			name:     "custom named group",
			filemode: 0770,
			group:    "Guests",
			output:   "O:" + creatorOwnerSID + "G:" + creatorGroupSID + "D:(A;P;GA;;;" + guest.Gid + ")",
		},
	}
	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			sd := genSecurityDescriptor(tc.filemode, tc.group)
			if sd != tc.output {
				t.Errorf("unexpected output from genSecurityDescriptor(%d, %s), got %s want %s", tc.filemode, tc.group, sd, tc.output)
			}
		})
	}
}
3548 | diff --git a/google_guest_agent/diagnostics.go b/google_guest_agent/diagnostics.go |
3549 | index b2d239b..62cde39 100644 |
3550 | --- a/google_guest_agent/diagnostics.go |
3551 | +++ b/google_guest_agent/diagnostics.go |
3552 | @@ -19,6 +19,7 @@ import ( |
3553 | "encoding/json" |
3554 | "reflect" |
3555 | "runtime" |
3556 | + "slices" |
3557 | "sync/atomic" |
3558 | |
3559 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
3560 | @@ -94,7 +95,7 @@ func (d *diagnosticsMgr) Set(ctx context.Context) error { |
3561 | } |
3562 | |
3563 | strEntry := newMetadata.Instance.Attributes.Diagnostics |
3564 | - if utils.ContainsString(strEntry, diagnosticsEntries) { |
3565 | + if slices.Contains(diagnosticsEntries, strEntry) { |
3566 | return nil |
3567 | } |
3568 | diagnosticsEntries = append(diagnosticsEntries, strEntry) |
3569 | diff --git a/google_guest_agent/events/events.go b/google_guest_agent/events/events.go |
3570 | index 2bf57e3..24d10dd 100644 |
3571 | --- a/google_guest_agent/events/events.go |
3572 | +++ b/google_guest_agent/events/events.go |
3573 | @@ -21,13 +21,14 @@ import ( |
3574 | "sync" |
3575 | |
3576 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/metadata" |
3577 | - "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/sshtrustedca" |
3578 | "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
3579 | ) |
3580 | |
3581 | var ( |
3582 | - // availableWatchers mapps all kown available event watchers. |
3583 | - availableWatchers = make(map[string]Watcher) |
3584 | + defaultWatchers = []Watcher{ |
3585 | + metadata.New(), |
3586 | + } |
3587 | + instance *Manager |
3588 | ) |
3589 | |
3590 | // Watcher defines the interface between the events manager and the actual |
3591 | @@ -48,14 +49,65 @@ type Watcher interface { |
3592 | // Manager defines the interface between events management layer and the |
3593 | // core guest agent implementation. |
3594 | type Manager struct { |
3595 | - watchers map[string]Watcher |
3596 | + // watcherEvents maps the registered watchers and their events. |
3597 | + watcherEvents []*WatcherEventType |
3598 | + |
3599 | + // watchersMap is a convenient manager's mapping of registered watcher instances. |
3600 | + watchersMap map[string]bool |
3601 | + |
3602 | + // watchersMutex protects the watchers map. |
3603 | + watchersMutex sync.Mutex |
3604 | + |
3605 | + // removingWatcherEvents is a map of watchers being removed. |
3606 | + removingWatcherEvents map[string]bool |
3607 | + |
3608 | + // running is a flag indicating if the Run() was previously called. |
3609 | + running bool |
3610 | + |
3611 | + // runningMutex protects the running flag. |
3612 | + runningMutex sync.RWMutex |
3613 | + |
3614 | + // subscribers maps the subscribed callbacks. |
3615 | subscribers map[string][]*eventSubscriber |
3616 | + |
3617 | + // subscribersMutex protects subscribers member/map of the manager object. |
3618 | + subscribersMutex sync.Mutex |
3619 | + |
3620 | + // queue manages the running watchers; when its length gets down to |
3621 | + // zero it means all watchers are done and we can signal the other |
3622 | + // control go routines to leave (given we don't have any more jobs left to |
3623 | + // process). |
3624 | + queue *watcherQueue |
3625 | } |
3626 | |
3627 | -// Config offers a mechanism for the consumers to configure the Manager behavior. |
3628 | -type Config struct { |
3629 | - // Watchers lists the enabled watchers, of not provided all available watchers will be enabled. |
3630 | - Watchers []string |
3631 | +// watcherQueue wraps the watchers <-> callbacks communication as well as the |
3632 | +// communication/coordination of the multiple control go routines, e.g. the one |
3633 | +// responsible for calling callbacks after an event is produced by the watcher. |
3634 | +type watcherQueue struct { |
3635 | + // queueMutex protects the access to watchersMap |
3636 | + queueMutex sync.RWMutex |
3637 | + |
3638 | + // watchersMap maps the currently running watchers. |
3639 | + watchersMap map[string]bool |
3640 | + |
3641 | + // finishContextHandler is a channel used to communicate with the context handling |
3642 | + // go routine that it should finish/end its job (usually after all watchers are done). |
3643 | + finishContextHandler chan bool |
3644 | + |
3645 | + // finishCallbackHandler is a channel used to communicate with the callback handling |
3646 | + // go routine that it should finish/end its job (usually after all watchers are done). |
3647 | + finishCallbackHandler chan bool |
3648 | + |
3649 | + // watcherDone is a channel used to communicate that a given watcher is finished/done. |
3650 | + watcherDone chan string |
3651 | + |
3652 | + // dataBus is the channel used to communicate between watchers (event producer) and the |
3653 | + // callback handler (event consumer managing go routine). |
3654 | + dataBus chan eventBusData |
3655 | + |
3656 | + // leaving is a flag that indicates no more job should be processed as we are done |
3657 | + // with all watchers and callbacks. |
3658 | + leaving bool |
3659 | } |
3660 | |
3661 | // EventData wraps the data communicated from a Watcher to a Subscriber. |
3662 | @@ -72,6 +124,20 @@ type WatcherEventType struct { |
3663 | watcher Watcher |
3664 | // evType idenfities the event type this object refences to. |
3665 | evType string |
3666 | + // removed is a channel used to communicate with the running watcher go routine |
3667 | + // that it shouldn't renew even if the watcher requested a renew (in response to a |
3668 | + // RemoveWatcher() call). |
3669 | + removed chan bool |
3670 | +} |
3671 | + |
3672 | +type eventSubscriber struct { |
3673 | + data interface{} |
3674 | + cb *EventCb |
3675 | +} |
3676 | + |
3677 | +type eventBusData struct { |
3678 | + evType string |
3679 | + data *EventData |
3680 | } |
3681 | |
3682 | // EventCb defines the callback interface between watchers and subscribers. The arguments are: |
3683 | @@ -84,224 +150,328 @@ type WatcherEventType struct { |
3684 | // to be unregistered/unsubscribed. |
3685 | type EventCb func(ctx context.Context, evType string, data interface{}, evData *EventData) bool |
3686 | |
3687 | -type eventSubscriber struct { |
3688 | - data interface{} |
3689 | - cb EventCb |
3690 | +// length returns how many watchers are currently running. |
3691 | +func (ep *watcherQueue) length() int { |
3692 | + ep.queueMutex.RLock() |
3693 | + defer ep.queueMutex.RUnlock() |
3694 | + return len(ep.watchersMap) |
3695 | } |
3696 | |
3697 | -type eventBusData struct { |
3698 | - evType string |
3699 | - data *EventData |
3700 | +// add adds a new watcher to the queue. |
3701 | +func (ep *watcherQueue) add(evType string) { |
3702 | + ep.queueMutex.Lock() |
3703 | + defer ep.queueMutex.Unlock() |
3704 | + ep.watchersMap[evType] = true |
3705 | } |
3706 | |
3707 | -func init() { |
3708 | - err := initWatchers([]Watcher{ |
3709 | - metadata.New(), |
3710 | - sshtrustedca.New(sshtrustedca.DefaultPipePath), |
3711 | - }) |
3712 | - if err != nil { |
3713 | - logger.Errorf("Failed to initialize watchers: %+v", err) |
3714 | - } |
3715 | +// del removes a watcher from the queue. |
3716 | +func (ep *watcherQueue) del(evType string) int { |
3717 | + ep.queueMutex.Lock() |
3718 | + defer ep.queueMutex.Unlock() |
3719 | + delete(ep.watchersMap, evType) |
3720 | + return len(ep.watchersMap) |
3721 | } |
3722 | |
3723 | -// init initializes the known available event watchers. |
3724 | -func initWatchers(watchers []Watcher) error { |
3725 | - for _, curr := range watchers { |
3726 | - // Error if we are accidentaly not properly setting the id. |
3727 | - if curr.ID() == "" { |
3728 | - return fmt.Errorf("invalid event watcher id, skipping") |
3729 | +// AddDefaultWatchers adds the default watchers: |
3730 | +// - metadata |
3731 | +func (mngr *Manager) AddDefaultWatchers(ctx context.Context) error { |
3732 | + for _, curr := range defaultWatchers { |
3733 | + if err := mngr.AddWatcher(ctx, curr); err != nil { |
3734 | + return err |
3735 | } |
3736 | - availableWatchers[curr.ID()] = curr |
3737 | } |
3738 | return nil |
3739 | } |
3740 | |
3741 | -// New allocates and initializes a events Manager based on provided cfg. |
3742 | -func New(cfg *Config) (*Manager, error) { |
3743 | - res := &Manager{ |
3744 | - watchers: availableWatchers, |
3745 | - subscribers: make(map[string][]*eventSubscriber), |
3746 | +// newManager allocates and initializes a events Manager. |
3747 | +func newManager() *Manager { |
3748 | + return &Manager{ |
3749 | + watchersMap: make(map[string]bool), |
3750 | + removingWatcherEvents: make(map[string]bool), |
3751 | + subscribers: make(map[string][]*eventSubscriber), |
3752 | + queue: &watcherQueue{ |
3753 | + watchersMap: make(map[string]bool), |
3754 | + dataBus: make(chan eventBusData), |
3755 | + finishCallbackHandler: make(chan bool), |
3756 | + finishContextHandler: make(chan bool), |
3757 | + watcherDone: make(chan string), |
3758 | + }, |
3759 | } |
3760 | +} |
3761 | |
3762 | - // Align manager's config based on consumers provided watchers If it is |
3763 | - // passing in wanted/expected watchers, otherwise use all available ones. |
3764 | - if cfg != nil && len(cfg.Watchers) > 0 { |
3765 | - res.watchers = make(map[string]Watcher) |
3766 | - for _, curr := range cfg.Watchers { |
3767 | - // Report back if we don't know the provided watcher id. |
3768 | - if _, found := availableWatchers[curr]; !found { |
3769 | - return nil, fmt.Errorf("invalid/unknown watcher id: %s", curr) |
3770 | - } |
3771 | - res.watchers[curr] = availableWatchers[curr] |
3772 | - } |
3773 | - } |
3774 | +func init() { |
3775 | + instance = newManager() |
3776 | +} |
3777 | |
3778 | - return res, nil |
3779 | +// Get returns the manager instance allocated by init(); it panics if that instance is nil. |
3780 | +func Get() *Manager { |
3781 | + if instance == nil { |
3782 | + panic("The event's manager instance should had being initialized.") |
3783 | + } |
3784 | + return instance |
3785 | } |
3786 | |
3787 | // Subscribe registers an event consumer/subscriber callback to a given event type, data |
3788 | // is a context pointer provided by the caller to be passed down when calling cb when |
3789 | // a new event happens. |
3790 | func (mngr *Manager) Subscribe(evType string, data interface{}, cb EventCb) { |
3791 | + mngr.subscribersMutex.Lock() |
3792 | + defer mngr.subscribersMutex.Unlock() |
3793 | mngr.subscribers[evType] = append(mngr.subscribers[evType], |
3794 | &eventSubscriber{ |
3795 | data: data, |
3796 | - cb: cb, |
3797 | + cb: &cb, |
3798 | }, |
3799 | ) |
3800 | } |
3801 | |
3802 | -func (mngr *Manager) eventTypes() []*WatcherEventType { |
3803 | - var res []*WatcherEventType |
3804 | - for _, watcher := range mngr.watchers { |
3805 | - for _, evType := range watcher.Events() { |
3806 | - res = append(res, &WatcherEventType{watcher, evType}) |
3807 | +func (mngr *Manager) unsubscribe(evType string, cb *EventCb) { |
3808 | + var keepMe []*eventSubscriber |
3809 | + for _, curr := range mngr.subscribers[evType] { |
3810 | + if curr.cb != cb { |
3811 | + keepMe = append(keepMe, curr) |
3812 | } |
3813 | } |
3814 | - return res |
3815 | + |
3816 | + mngr.subscribers[evType] = keepMe |
3817 | + |
3818 | + if len(keepMe) == 0 { |
3819 | + logger.Debugf("No more subscribers left for evType: %s", evType) |
3820 | + delete(mngr.subscribers, evType) |
3821 | + } |
3822 | } |
3823 | |
3824 | -type watcherQueue struct { |
3825 | - mutex sync.Mutex |
3826 | - watchersMap map[string]bool |
3827 | +// Unsubscribe removes the subscription of a given callback for a given event type. |
3828 | +func (mngr *Manager) Unsubscribe(evType string, cb EventCb) { |
3829 | + mngr.subscribersMutex.Lock() |
3830 | + defer mngr.subscribersMutex.Unlock() |
3831 | + mngr.unsubscribe(evType, &cb) |
3832 | } |
3833 | |
3834 | -func (ep *watcherQueue) add(evType string) { |
3835 | - ep.mutex.Lock() |
3836 | - defer ep.mutex.Unlock() |
3837 | - ep.watchersMap[evType] = true |
3838 | +// RemoveWatcher removes a watcher from the event manager. Each running watcher has its own |
3839 | +// context (derived from the one provided in the AddWatcher() call) and will have it canceled |
3840 | +// after calling this method. |
3841 | +func (mngr *Manager) RemoveWatcher(ctx context.Context, watcher Watcher) error { |
3842 | + mngr.watchersMutex.Lock() |
3843 | + defer mngr.watchersMutex.Unlock() |
3844 | + |
3845 | + id := watcher.ID() |
3846 | + logger.Debugf("Got a request to remove watcher: %s", id) |
3847 | + if _, found := mngr.watchersMap[id]; !found { |
3848 | + return fmt.Errorf("unknown Watcher(%s)", id) |
3849 | + } |
3850 | + |
3851 | + for _, curr := range mngr.watcherEvents { |
3852 | + if _, found := mngr.removingWatcherEvents[curr.evType]; found { |
3853 | + logger.Debugf("Watcher(%s) is being removed, skipping removal request: %s", id, curr.evType) |
3854 | + continue |
3855 | + } |
3856 | + |
3857 | + if curr.watcher.ID() == id { |
3858 | + mngr.removingWatcherEvents[curr.evType] = true |
3859 | + logger.Debugf("Removing watcher: %s, event type: %s", id, curr.evType) |
3860 | + curr.removed <- true |
3861 | + } |
3862 | + } |
3863 | + |
3864 | + return nil |
3865 | } |
3866 | |
3867 | -func (ep *watcherQueue) del(evType string) int { |
3868 | - ep.mutex.Lock() |
3869 | - defer ep.mutex.Unlock() |
3870 | - delete(ep.watchersMap, evType) |
3871 | - return len(ep.watchersMap) |
3872 | +// AddWatcher adds/enables a new watcher. The watcher will be fired up right away if the |
3873 | +// event manager is already running, otherwise it's scheduled to run when Run() is called. |
3874 | +func (mngr *Manager) AddWatcher(ctx context.Context, watcher Watcher) error { |
3875 | + mngr.watchersMutex.Lock() |
3876 | + defer mngr.watchersMutex.Unlock() |
3877 | + id := watcher.ID() |
3878 | + if _, found := mngr.watchersMap[id]; found { |
3879 | + return fmt.Errorf("watcher(%s) was previously added", id) |
3880 | + } |
3881 | + |
3882 | + // Add the watchers and its events to internal mappings. |
3883 | + evTypes := make(map[string]*WatcherEventType) |
3884 | + mngr.watchersMap[id] = true |
3885 | + |
3886 | + for _, curr := range watcher.Events() { |
3887 | + evType := &WatcherEventType{ |
3888 | + watcher: watcher, |
3889 | + evType: curr, |
3890 | + removed: make(chan bool), |
3891 | + } |
3892 | + |
3893 | + evTypes[curr] = evType |
3894 | + mngr.watcherEvents = append(mngr.watcherEvents, evType) |
3895 | + } |
3896 | + |
3897 | + mngr.runningMutex.RLock() |
3898 | + defer mngr.runningMutex.RUnlock() |
3899 | + // If we are not running don't bother "running" the watcher, Run() will do it later. |
3900 | + if !mngr.running { |
3901 | + return nil |
3902 | + } |
3903 | + |
3904 | + // If we are already running then run/launch the watcher right away. |
3905 | + for _, curr := range watcher.Events() { |
3906 | + logger.Debugf("Adding watcher for event: %s", curr) |
3907 | + mngr.queue.add(curr) |
3908 | + go func(watcher Watcher, evType string, removed chan bool) { |
3909 | + mngr.runWatcher(ctx, watcher, evType, removed) |
3910 | + }(watcher, curr, evTypes[curr].removed) |
3911 | + } |
3912 | + |
3913 | + return nil |
3914 | +} |
3915 | + |
3916 | +func (mngr *Manager) runWatcher(ctx context.Context, watcher Watcher, evType string, removed chan bool) { |
3917 | + nCtx, cancel := context.WithCancel(ctx) |
3918 | + abort := false |
3919 | + id := watcher.ID() |
3920 | + |
3921 | + go func() { |
3922 | + abort = <-removed |
3923 | + logger.Debugf("Got a request to abort watcher(%s) for event: %s", id, evType) |
3924 | + cancel() |
3925 | + }() |
3926 | + |
3927 | + for renew := true; renew; { |
3928 | + var evData interface{} |
3929 | + var err error |
3930 | + |
3931 | + renew, evData, err = watcher.Run(nCtx, evType) |
3932 | + |
3933 | + logger.Debugf("Watcher(%s) returned event: %q, should renew?: %t", id, evType, renew) |
3934 | + |
3935 | + if abort || mngr.queue.leaving { |
3936 | + logger.Debugf("Watcher(%s), either are aborting(%t) or leaving(%t), breaking renew cycle", |
3937 | + id, abort, mngr.queue.leaving) |
3938 | + break |
3939 | + } |
3940 | + |
3941 | + mngr.queue.dataBus <- eventBusData{ |
3942 | + evType: evType, |
3943 | + data: &EventData{ |
3944 | + Data: evData, |
3945 | + Error: err, |
3946 | + }, |
3947 | + } |
3948 | + } |
3949 | + |
3950 | + logger.Debugf("watcher finishing: %s", evType) |
3951 | + if !abort { |
3952 | + removed <- true |
3953 | + } |
3954 | + |
3955 | + mngr.queue.watcherDone <- evType |
3956 | } |
3957 | |
3958 | // Run runs the event manager, it will block until all watchers have given up/failed. |
3959 | -func (mngr *Manager) Run(ctx context.Context) { |
3960 | +// The event manager is meant to be started right after the early initialization code |
3961 | +// and live until the application ends, the event manager can not be restarted - the Run() |
3962 | +// method will return an error if one tries to run it twice. |
3963 | +func (mngr *Manager) Run(ctx context.Context) error { |
3964 | var wg sync.WaitGroup |
3965 | - var leaving bool |
3966 | - |
3967 | - syncBus := make(chan eventBusData) |
3968 | - defer close(syncBus) |
3969 | |
3970 | - cancelContext := make(chan bool) |
3971 | - cancelCallback := make(chan bool) |
3972 | + mngr.runningMutex.Lock() |
3973 | + if mngr.running { |
3974 | + mngr.runningMutex.Unlock() |
3975 | + return fmt.Errorf("tried calling event manager's Run() twice") |
3976 | + } |
3977 | + mngr.running = true |
3978 | + mngr.runningMutex.Unlock() |
3979 | |
3980 | - defer close(cancelContext) |
3981 | - defer close(cancelCallback) |
3982 | + queue := mngr.queue |
3983 | |
3984 | // Manages the context's done signal, pass it down to the other go routines to |
3985 | // finish its job and leave. Additionally, if the remaining go routines are leaving |
3986 | - // we get it handled via syncBus channel and drop this go routine as well. |
3987 | + // we get it handled via dataBus channel and drop this go routine as well. |
3988 | wg.Add(1) |
3989 | - go func(done <-chan struct{}, cancelContext <-chan bool, cancelCallback chan<- bool) { |
3990 | + go func(done <-chan struct{}, finishContextHandler <-chan bool, finishCallbackHandler chan<- bool) { |
3991 | defer wg.Done() |
3992 | |
3993 | for { |
3994 | select { |
3995 | case <-done: |
3996 | logger.Debugf("Got context's Done() signal, leaving.") |
3997 | - leaving = true |
3998 | - cancelCallback <- true |
3999 | + queue.leaving = true |
4000 | + finishCallbackHandler <- true |
4001 | return |
4002 | - case <-cancelContext: |
4003 | - leaving = true |
4004 | + case <-finishContextHandler: |
4005 | + logger.Debugf("Got context handler finish signal, leaving.") |
4006 | + queue.leaving = true |
4007 | return |
4008 | } |
4009 | } |
4010 | - }(ctx.Done(), cancelContext, cancelCallback) |
4011 | + }(ctx.Done(), queue.finishContextHandler, queue.finishCallbackHandler) |
4012 | |
4013 | // Manages the event processing avoiding blocking the watcher's go routines. |
4014 | - // This will listen to syncBus and call the events handlers/callbacks. |
4015 | + // This will listen to dataBus and call the events handlers/callbacks. |
4016 | wg.Add(1) |
4017 | - go func(bus <-chan eventBusData, cancelCallback <-chan bool) { |
4018 | + go func(bus <-chan eventBusData, finishCallbackHandler <-chan bool) { |
4019 | defer wg.Done() |
4020 | |
4021 | for { |
4022 | select { |
4023 | - case <-cancelCallback: |
4024 | + case <-finishCallbackHandler: |
4025 | return |
4026 | case busData := <-bus: |
4027 | - subscribers, found := mngr.subscribers[busData.evType] |
4028 | - if !found || len(subscribers) == 0 { |
4029 | + subscribers := mngr.subscribers[busData.evType] |
4030 | + if subscribers == nil { |
4031 | logger.Debugf("No subscriber found for event: %s, returning.", busData.evType) |
4032 | continue |
4033 | } |
4034 | |
4035 | - keepMe := make([]*eventSubscriber, 0) |
4036 | - for _, curr := range mngr.subscribers[busData.evType] { |
4037 | + deleteMe := make([]*eventSubscriber, 0) |
4038 | + for _, curr := range subscribers { |
4039 | logger.Debugf("Running registered callback for event: %s", busData.evType) |
4040 | - renew := curr.cb(ctx, busData.evType, curr.data, busData.data) |
4041 | - if renew { |
4042 | - keepMe = append(keepMe, curr) |
4043 | + renew := (*curr.cb)(ctx, busData.evType, curr.data, busData.data) |
4044 | + if !renew { |
4045 | + deleteMe = append(deleteMe, curr) |
4046 | } |
4047 | logger.Debugf("Returning from event %q subscribed callback, should renew?: %t", busData.evType, renew) |
4048 | } |
4049 | |
4050 | - mngr.subscribers[busData.evType] = keepMe |
4051 | - |
4052 | - // No more subscribers for this event type, delete it from the subscribers map. |
4053 | - if len(keepMe) == 0 { |
4054 | - logger.Debugf("No more subscribers left for evType: %s", busData.evType) |
4055 | - delete(mngr.subscribers, busData.evType) |
4056 | + mngr.subscribersMutex.Lock() |
4057 | + for _, curr := range deleteMe { |
4058 | + mngr.unsubscribe(busData.evType, curr.cb) |
4059 | } |
4060 | + leave := mngr.subscribers[busData.evType] == nil |
4061 | + mngr.subscribersMutex.Unlock() |
4062 | |
4063 | // No more subscribers at all, we have nothing more left to do here. |
4064 | - if len(mngr.subscribers) == 0 { |
4065 | + if leave { |
4066 | logger.Debugf("No subscribers left, leaving") |
4067 | break |
4068 | } |
4069 | } |
4070 | } |
4071 | - }(syncBus, cancelCallback) |
4072 | - |
4073 | - // This control struct manages the registered watchers, when it gets to len() |
4074 | - // down to zero means all watchers are done and we can signal the other 2 control |
4075 | - // go routines to leave(given we don't have any more job left to process). |
4076 | - control := &watcherQueue{ |
4077 | - watchersMap: make(map[string]bool), |
4078 | - } |
4079 | + }(queue.dataBus, queue.finishCallbackHandler) |
4080 | |
4081 | // Creates a goroutine for each registered watcher's event and keep handling its |
4082 | // execution until they give up/finishes their job by returning renew = false. |
4083 | - for _, curr := range mngr.eventTypes() { |
4084 | - control.add(curr.evType) |
4085 | - wg.Add(1) |
4086 | - |
4087 | - go func(bus chan<- eventBusData, watcher Watcher, evType string, cancelContext chan<- bool, cancelCallback chan<- bool) { |
4088 | - var evData interface{} |
4089 | - var err error |
4090 | - |
4091 | - defer wg.Done() |
4092 | - |
4093 | - for renew := true; renew; { |
4094 | - renew, evData, err = watcher.Run(ctx, evType) |
4095 | - |
4096 | - logger.Debugf("Watcher(%s) returned event: %q, should renew?: %t", watcher.ID(), evType, renew) |
4097 | - |
4098 | - if leaving { |
4099 | - break |
4100 | - } |
4101 | + for _, curr := range mngr.watcherEvents { |
4102 | + queue.add(curr.evType) |
4103 | + go func(watcher Watcher, evType string, removed chan bool) { |
4104 | + mngr.runWatcher(ctx, watcher, evType, removed) |
4105 | + }(curr.watcher, curr.evType, curr.removed) |
4106 | + } |
4107 | |
4108 | - bus <- eventBusData{ |
4109 | - evType: evType, |
4110 | - data: &EventData{ |
4111 | - Data: evData, |
4112 | - Error: err, |
4113 | - }, |
4114 | - } |
4115 | - } |
4116 | + // Controls the completion of the watcher go routines, their removal from the queue |
4117 | + // and signals to context & callback control go routines about watchers completion. |
4118 | + wg.Add(1) |
4119 | + go func() { |
4120 | + defer wg.Done() |
4121 | |
4122 | - if !leaving && control.del(evType) == 0 { |
4123 | + for len := queue.length(); len > 0; { |
4124 | + doneStr := <-queue.watcherDone |
4125 | + len = queue.del(doneStr) |
4126 | + delete(mngr.removingWatcherEvents, doneStr) |
4127 | + if !queue.leaving && len == 0 { |
4128 | logger.Debugf("All watchers are finished, signaling to leave.") |
4129 | - cancelContext <- true |
4130 | - cancelCallback <- true |
4131 | + queue.finishContextHandler <- true |
4132 | + queue.finishCallbackHandler <- true |
4133 | } |
4134 | - }(syncBus, curr.watcher, curr.evType, cancelContext, cancelCallback) |
4135 | - } |
4136 | + } |
4137 | + }() |
4138 | |
4139 | wg.Wait() |
4140 | + return nil |
4141 | } |
4142 | diff --git a/google_guest_agent/events/events_test.go b/google_guest_agent/events/events_test.go |
4143 | index 6fcb2c3..a77bfb9 100644 |
4144 | --- a/google_guest_agent/events/events_test.go |
4145 | +++ b/google_guest_agent/events/events_test.go |
4146 | @@ -17,48 +17,26 @@ package events |
4147 | import ( |
4148 | "context" |
4149 | "fmt" |
4150 | + "sync" |
4151 | "testing" |
4152 | "time" |
4153 | |
4154 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/events/metadata" |
4155 | ) |
4156 | |
4157 | -func TestConstructor(t *testing.T) { |
4158 | - tests := []struct { |
4159 | - config *Config |
4160 | - success bool |
4161 | - }{ |
4162 | - {config: nil, success: true}, |
4163 | - {config: &Config{Watchers: []string{metadata.WatcherID}}, success: true}, |
4164 | - {config: &Config{Watchers: []string{"foobar"}}, success: false}, |
4165 | - } |
4166 | +func TestAddWatcher(t *testing.T) { |
4167 | + eventManager := newManager() |
4168 | + metadataWatcher := metadata.New() |
4169 | + ctx := context.Background() |
4170 | |
4171 | - for i, tt := range tests { |
4172 | - t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { |
4173 | - _, err := New(tt.config) |
4174 | - if err != nil && tt.success { |
4175 | - t.Errorf("expected success, got error: %+v", err) |
4176 | - } |
4177 | - }) |
4178 | - } |
4179 | -} |
4180 | - |
4181 | -func TestInitWatcers(t *testing.T) { |
4182 | - tests := []struct { |
4183 | - watchers []Watcher |
4184 | - success bool |
4185 | - }{ |
4186 | - {watchers: []Watcher{metadata.New()}, success: true}, |
4187 | - {watchers: []Watcher{&testWatcher{}}, success: false}, |
4188 | + err := eventManager.AddWatcher(ctx, metadataWatcher) |
4189 | + if err != nil { |
4190 | + t.Errorf("expected success, got error: %+v", err) |
4191 | } |
4192 | |
4193 | - for i, tt := range tests { |
4194 | - t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { |
4195 | - err := initWatchers(tt.watchers) |
4196 | - if err != nil && tt.success { |
4197 | - t.Errorf("expected success, got error: %+v", err) |
4198 | - } |
4199 | - }) |
4200 | + err = eventManager.AddWatcher(ctx, metadataWatcher) |
4201 | + if err == nil { |
4202 | + t.Errorf("expected error, had success, event manager shouldn't add same watcher twice") |
4203 | } |
4204 | } |
4205 | |
4206 | @@ -91,20 +69,16 @@ func TestRun(t *testing.T) { |
4207 | watcherID := "test-watcher" |
4208 | maxCount := 10 |
4209 | |
4210 | - err := initWatchers([]Watcher{ |
4211 | - &testWatcher{ |
4212 | - watcherID: watcherID, |
4213 | - maxCount: maxCount, |
4214 | - }, |
4215 | - }) |
4216 | + ctx := context.Background() |
4217 | + eventManager := newManager() |
4218 | |
4219 | - if err != nil { |
4220 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4221 | - } |
4222 | + err := eventManager.AddWatcher(ctx, &testWatcher{ |
4223 | + watcherID: watcherID, |
4224 | + maxCount: maxCount, |
4225 | + }) |
4226 | |
4227 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4228 | if err != nil { |
4229 | - t.Fatalf("Failed to init event manager: %+v", err) |
4230 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4231 | } |
4232 | |
4233 | counter := 0 |
4234 | @@ -114,7 +88,9 @@ func TestRun(t *testing.T) { |
4235 | return true |
4236 | }) |
4237 | |
4238 | - eventManager.Run(context.Background()) |
4239 | + if err := eventManager.Run(ctx); err != nil { |
4240 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4241 | + } |
4242 | |
4243 | if counter != maxCount { |
4244 | t.Errorf("Failed to increment callback counter, expected: %d, got: %d", maxCount, counter) |
4245 | @@ -126,20 +102,16 @@ func TestUnsubscribe(t *testing.T) { |
4246 | maxCount := 10 |
4247 | unsubscribeAt := 2 |
4248 | |
4249 | - err := initWatchers([]Watcher{ |
4250 | - &testWatcher{ |
4251 | - watcherID: watcherID, |
4252 | - maxCount: maxCount, |
4253 | - }, |
4254 | - }) |
4255 | + ctx := context.Background() |
4256 | + eventManager := newManager() |
4257 | |
4258 | - if err != nil { |
4259 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4260 | - } |
4261 | + err := eventManager.AddWatcher(ctx, &testWatcher{ |
4262 | + watcherID: watcherID, |
4263 | + maxCount: maxCount, |
4264 | + }) |
4265 | |
4266 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4267 | if err != nil { |
4268 | - t.Fatalf("Failed to init event manager: %+v", err) |
4269 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4270 | } |
4271 | |
4272 | counter := 0 |
4273 | @@ -151,7 +123,9 @@ func TestUnsubscribe(t *testing.T) { |
4274 | return true |
4275 | }) |
4276 | |
4277 | - eventManager.Run(context.Background()) |
4278 | + if err := eventManager.Run(ctx); err != nil { |
4279 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4280 | + } |
4281 | |
4282 | if counter != unsubscribeAt { |
4283 | t.Errorf("Failed to unsubscribe callback, expected: %d, got: %d", unsubscribeAt, counter) |
4284 | @@ -162,20 +136,16 @@ func TestCancelBeforeCallbacks(t *testing.T) { |
4285 | watcherID := "test-watcher" |
4286 | timeout := (1 * time.Second) / 100 |
4287 | |
4288 | - err := initWatchers([]Watcher{ |
4289 | - &testCancel{ |
4290 | - watcherID: watcherID, |
4291 | - timeout: timeout, |
4292 | - }, |
4293 | - }) |
4294 | + ctx, cancel := context.WithCancel(context.Background()) |
4295 | + eventManager := newManager() |
4296 | |
4297 | - if err != nil { |
4298 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4299 | - } |
4300 | + err := eventManager.AddWatcher(ctx, &testCancel{ |
4301 | + watcherID: watcherID, |
4302 | + timeout: timeout, |
4303 | + }) |
4304 | |
4305 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4306 | if err != nil { |
4307 | - t.Fatalf("Failed to init event manager: %+v", err) |
4308 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4309 | } |
4310 | |
4311 | eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4312 | @@ -183,13 +153,14 @@ func TestCancelBeforeCallbacks(t *testing.T) { |
4313 | return true |
4314 | }) |
4315 | |
4316 | - ctx, cancel := context.WithCancel(context.Background()) |
4317 | go func() { |
4318 | time.Sleep(timeout / 2) |
4319 | cancel() |
4320 | }() |
4321 | |
4322 | - eventManager.Run(ctx) |
4323 | + if err := eventManager.Run(ctx); err != nil { |
4324 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4325 | + } |
4326 | } |
4327 | |
4328 | type testCancel struct { |
4329 | @@ -214,33 +185,30 @@ func TestCancelAfterCallbacks(t *testing.T) { |
4330 | watcherID := "test-watcher" |
4331 | timeout := (1 * time.Second) / 100 |
4332 | |
4333 | - err := initWatchers([]Watcher{ |
4334 | - &testCancel{ |
4335 | - watcherID: watcherID, |
4336 | - timeout: timeout, |
4337 | - }, |
4338 | - }) |
4339 | + ctx, cancel := context.WithCancel(context.Background()) |
4340 | + eventManager := newManager() |
4341 | |
4342 | - if err != nil { |
4343 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4344 | - } |
4345 | + err := eventManager.AddWatcher(ctx, &testCancel{ |
4346 | + watcherID: watcherID, |
4347 | + timeout: timeout, |
4348 | + }) |
4349 | |
4350 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4351 | if err != nil { |
4352 | - t.Fatalf("Failed to init event manager: %+v", err) |
4353 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4354 | } |
4355 | |
4356 | eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4357 | return true |
4358 | }) |
4359 | |
4360 | - ctx, cancel := context.WithCancel(context.Background()) |
4361 | go func() { |
4362 | time.Sleep(timeout * 10) |
4363 | cancel() |
4364 | }() |
4365 | |
4366 | - eventManager.Run(ctx) |
4367 | + if err := eventManager.Run(ctx); err != nil { |
4368 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4369 | + } |
4370 | } |
4371 | |
4372 | type testCancelWatcher struct { |
4373 | @@ -285,20 +253,16 @@ func TestCancelCallbacksAndWatchers(t *testing.T) { |
4374 | t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { |
4375 | cancelSubscriberAfter := curr.cancelSubscriberAfter |
4376 | |
4377 | - err := initWatchers([]Watcher{ |
4378 | - &testCancelWatcher{ |
4379 | - watcherID: watcherID, |
4380 | - after: curr.cancelWatcherAfter, |
4381 | - }, |
4382 | - }) |
4383 | + ctx := context.Background() |
4384 | + eventManager := newManager() |
4385 | |
4386 | - if err != nil { |
4387 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4388 | - } |
4389 | + err := eventManager.AddWatcher(ctx, &testCancelWatcher{ |
4390 | + watcherID: watcherID, |
4391 | + after: curr.cancelWatcherAfter, |
4392 | + }) |
4393 | |
4394 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4395 | if err != nil { |
4396 | - t.Fatalf("Failed to init event manager: %+v", err) |
4397 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4398 | } |
4399 | |
4400 | eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4401 | @@ -310,7 +274,9 @@ func TestCancelCallbacksAndWatchers(t *testing.T) { |
4402 | return true |
4403 | }) |
4404 | |
4405 | - eventManager.Run(context.Background()) |
4406 | + if err := eventManager.Run(ctx); err != nil { |
4407 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4408 | + } |
4409 | }) |
4410 | } |
4411 | } |
4412 | @@ -320,20 +286,16 @@ func TestMultipleEvents(t *testing.T) { |
4413 | firstEvent := "multiple-events,first-event" |
4414 | secondEvent := "multiple-events,second-event" |
4415 | |
4416 | - err := initWatchers([]Watcher{ |
4417 | - &testMultipleEvents{ |
4418 | - watcherID: watcherID, |
4419 | - eventIDS: []string{firstEvent, secondEvent}, |
4420 | - }, |
4421 | - }) |
4422 | + ctx := context.Background() |
4423 | + eventManager := newManager() |
4424 | |
4425 | - if err != nil { |
4426 | - t.Fatalf("Failed to init/register watcher: %+v", err) |
4427 | - } |
4428 | + err := eventManager.AddWatcher(ctx, &testMultipleEvents{ |
4429 | + watcherID: watcherID, |
4430 | + eventIDS: []string{firstEvent, secondEvent}, |
4431 | + }) |
4432 | |
4433 | - eventManager, err := New(&Config{Watchers: []string{watcherID}}) |
4434 | if err != nil { |
4435 | - t.Fatalf("Failed to init event manager: %+v", err) |
4436 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4437 | } |
4438 | |
4439 | var hitFirstEvent bool |
4440 | @@ -348,7 +310,9 @@ func TestMultipleEvents(t *testing.T) { |
4441 | return false |
4442 | }) |
4443 | |
4444 | - eventManager.Run(context.Background()) |
4445 | + if err := eventManager.Run(ctx); err != nil { |
4446 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4447 | + } |
4448 | |
4449 | if !hitFirstEvent || !hitSecondEvent { |
4450 | t.Errorf("Failed to call back events, first event hit? (%t), second event hit? (%t)", hitFirstEvent, hitSecondEvent) |
4451 | @@ -371,3 +335,304 @@ func (tt *testMultipleEvents) Events() []string { |
4452 | func (tt *testMultipleEvents) Run(ctx context.Context, evType string) (bool, interface{}, error) { |
4453 | return false, nil, nil |
4454 | } |
4455 | + |
4456 | +func TestAddWatcherAfterRun(t *testing.T) { |
4457 | + firstWatcher := &genericWatcher{ |
4458 | + watcherID: "first-watcher", |
4459 | + shouldRenew: true, |
4460 | + } |
4461 | + |
4462 | + secondWatcher := &genericWatcher{ |
4463 | + watcherID: "second-watcher", |
4464 | + } |
4465 | + |
4466 | + ctx := context.Background() |
4467 | + eventManager := newManager() |
4468 | + |
4469 | + err := eventManager.AddWatcher(ctx, firstWatcher) |
4470 | + |
4471 | + if err != nil { |
4472 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4473 | + } |
4474 | + |
4475 | + eventManager.Subscribe(firstWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4476 | + if err := eventManager.AddWatcher(ctx, secondWatcher); err != nil { |
4477 | + t.Errorf("Failed to add a second watcher: %+v, expected success", err) |
4478 | + } |
4479 | + firstWatcher.shouldRenew = false |
4480 | + return false |
4481 | + }) |
4482 | + |
4483 | + var hitSecondEvent bool |
4484 | + eventManager.Subscribe(secondWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4485 | + hitSecondEvent = true |
4486 | + return false |
4487 | + }) |
4488 | + |
4489 | + if err := eventManager.Run(ctx); err != nil { |
4490 | + t.Errorf("Failed to run event managed, expected success, got error: %+v", err) |
4491 | + } |
4492 | + |
4493 | + if !hitSecondEvent { |
4494 | + t.Errorf("Failed registering second watcher, expected hitSecondEvent: false, got: %t", hitSecondEvent) |
4495 | + } |
4496 | +} |
4497 | + |
4498 | +type genericWatcher struct { |
4499 | + watcherID string |
4500 | + shouldRenew bool |
4501 | + wait time.Duration |
4502 | +} |
4503 | + |
4504 | +func (gw *genericWatcher) eventID() string { |
4505 | + return gw.watcherID + ",test-event" |
4506 | +} |
4507 | + |
4508 | +func (gw *genericWatcher) ID() string { |
4509 | + return gw.watcherID |
4510 | +} |
4511 | + |
4512 | +func (gw *genericWatcher) Events() []string { |
4513 | + return []string{gw.eventID()} |
4514 | +} |
4515 | + |
4516 | +func (gw *genericWatcher) Run(ctx context.Context, evType string) (bool, interface{}, error) { |
4517 | + if gw.wait > 0 { |
4518 | + time.Sleep(gw.wait) |
4519 | + } |
4520 | + return gw.shouldRenew, nil, nil |
4521 | +} |
4522 | + |
4523 | +func TestAddDefaultWatchers(t *testing.T) { |
4524 | + firstWatcher := &genericWatcher{ |
4525 | + watcherID: "first-watcher", |
4526 | + shouldRenew: false, |
4527 | + } |
4528 | + |
4529 | + defaultWatchers = []Watcher{ |
4530 | + firstWatcher, |
4531 | + } |
4532 | + |
4533 | + ctx := context.Background() |
4534 | + eventManager := newManager() |
4535 | + |
4536 | + err := eventManager.AddDefaultWatchers(ctx) |
4537 | + |
4538 | + if err != nil { |
4539 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4540 | + } |
4541 | + |
4542 | + if len(eventManager.watchersMap) == 0 { |
4543 | + t.Fatalf("Failed to add default watchers, expected: %d, got: %d", len(defaultWatchers), |
4544 | + len(eventManager.watchersMap)) |
4545 | + } |
4546 | + |
4547 | + if len(eventManager.watcherEvents) == 0 { |
4548 | + t.Fatalf("Failed to add default watchers, expected: %d, got: %d", len(defaultWatchers), |
4549 | + len(eventManager.watcherEvents)) |
4550 | + } |
4551 | +} |
4552 | + |
4553 | +func TestCallingRunTwice(t *testing.T) { |
4554 | + firstWatcher := &genericWatcher{ |
4555 | + watcherID: "first-watcher", |
4556 | + shouldRenew: false, |
4557 | + } |
4558 | + |
4559 | + defaultWatchers = []Watcher{ |
4560 | + firstWatcher, |
4561 | + } |
4562 | + |
4563 | + timeout := (1 * time.Second) / 100 |
4564 | + ctx, cancel := context.WithCancel(context.Background()) |
4565 | + eventManager := newManager() |
4566 | + |
4567 | + err := eventManager.AddDefaultWatchers(ctx) |
4568 | + |
4569 | + if err != nil { |
4570 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4571 | + } |
4572 | + |
4573 | + var wg sync.WaitGroup |
4574 | + wg.Add(1) |
4575 | + go func() { |
4576 | + defer wg.Done() |
4577 | + time.Sleep(timeout) |
4578 | + cancel() |
4579 | + }() |
4580 | + |
4581 | + errors := []error{} |
4582 | + wg.Add(1) |
4583 | + go func() { |
4584 | + defer wg.Done() |
4585 | + if err := eventManager.Run(ctx); err != nil { |
4586 | + errors = append(errors, err) |
4587 | + } |
4588 | + }() |
4589 | + |
4590 | + wg.Add(1) |
4591 | + go func() { |
4592 | + defer wg.Done() |
4593 | + if err := eventManager.Run(ctx); err != nil { |
4594 | + errors = append(errors, err) |
4595 | + } |
4596 | + }() |
4597 | + |
4598 | + wg.Wait() |
4599 | + |
4600 | + if len(errors) == 0 { |
4601 | + t.Errorf("Executing Run() twice should fail, we got not failure") |
4602 | + } |
4603 | + |
4604 | + if len(errors) > 1 { |
4605 | + t.Errorf("Executing Run() twice should produce a single error, got: %+v", errors) |
4606 | + } |
4607 | +} |
4608 | + |
4609 | +type testRemoveWatcher struct { |
4610 | + watcherID string |
4611 | + timeout time.Duration |
4612 | +} |
4613 | + |
4614 | +func (tc *testRemoveWatcher) ID() string { |
4615 | + return tc.watcherID |
4616 | +} |
4617 | + |
4618 | +func (tc *testRemoveWatcher) Events() []string { |
4619 | + return []string{tc.watcherID + ",test-event"} |
4620 | +} |
4621 | + |
4622 | +func (tc *testRemoveWatcher) Run(ctx context.Context, evType string) (bool, interface{}, error) { |
4623 | + select { |
4624 | + case <-ctx.Done(): |
4625 | + return false, nil, nil |
4626 | + case <-time.After(tc.timeout): |
4627 | + return true, nil, nil |
4628 | + } |
4629 | +} |
4630 | + |
4631 | +func TestRemoveWatcherBeforeCallbacks(t *testing.T) { |
4632 | + watcherID := "test-watcher" |
4633 | + timeout := (1 * time.Second) / 100 |
4634 | + |
4635 | + ctx := context.Background() |
4636 | + eventManager := newManager() |
4637 | + |
4638 | + watcher := &testRemoveWatcher{ |
4639 | + watcherID: watcherID, |
4640 | + timeout: timeout, |
4641 | + } |
4642 | + |
4643 | + err := eventManager.AddWatcher(ctx, watcher) |
4644 | + |
4645 | + if err != nil { |
4646 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4647 | + } |
4648 | + |
4649 | + eventManager.Subscribe("test-watcher,test-event", nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4650 | + t.Errorf("Expected to have canceled before calling callback") |
4651 | + return false |
4652 | + }) |
4653 | + |
4654 | + go func() { |
4655 | + time.Sleep(timeout / 2) |
4656 | + if err := eventManager.RemoveWatcher(ctx, watcher); err != nil { |
4657 | + t.Errorf("Failed to remove watcher: %+v", err) |
4658 | + } |
4659 | + }() |
4660 | + |
4661 | + if err := eventManager.Run(ctx); err != nil { |
4662 | + t.Errorf("Failed running event manager, expected success, got error: %+v", err) |
4663 | + } |
4664 | +} |
4665 | + |
4666 | +func TestRemoveWatcherFromCallback(t *testing.T) { |
4667 | + watcher := &genericWatcher{ |
4668 | + watcherID: "first-watcher", |
4669 | + shouldRenew: true, |
4670 | + } |
4671 | + |
4672 | + ctx := context.Background() |
4673 | + eventManager := newManager() |
4674 | + |
4675 | + err := eventManager.AddWatcher(ctx, watcher) |
4676 | + |
4677 | + if err != nil { |
4678 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4679 | + } |
4680 | + |
4681 | + eventManager.Subscribe(watcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4682 | + if err := eventManager.RemoveWatcher(ctx, watcher); err != nil { |
4683 | + t.Fatalf("Failed to remove watcher, it should have succeeded: %+v", err) |
4684 | + } |
4685 | + return true |
4686 | + }) |
4687 | + |
4688 | + if err := eventManager.Run(ctx); err != nil { |
4689 | + t.Errorf("Failed running event manager, expected success, got error: %+v", err) |
4690 | + } |
4691 | +} |
4692 | + |
4693 | +func TestCrossWatcherRemovalFromCallback(t *testing.T) { |
4694 | + firstWatcher := &genericWatcher{ |
4695 | + watcherID: "first-watcher", |
4696 | + shouldRenew: true, |
4697 | + } |
4698 | + |
4699 | + secondWatcher := &genericWatcher{ |
4700 | + watcherID: "second-watcher", |
4701 | + shouldRenew: true, |
4702 | + } |
4703 | + |
4704 | + thirdWatcher := &genericWatcher{ |
4705 | + watcherID: "third-watcher", |
4706 | + shouldRenew: true, |
4707 | + wait: (1 * time.Second) / 3, |
4708 | + } |
4709 | + |
4710 | + ctx := context.Background() |
4711 | + eventManager := newManager() |
4712 | + |
4713 | + watchers := []Watcher{ |
4714 | + firstWatcher, |
4715 | + secondWatcher, |
4716 | + thirdWatcher, |
4717 | + } |
4718 | + |
4719 | + for _, curr := range watchers { |
4720 | + err := eventManager.AddWatcher(ctx, curr) |
4721 | + |
4722 | + if err != nil { |
4723 | + t.Fatalf("Failed to add watcher to event manager: %+v", err) |
4724 | + } |
4725 | + } |
4726 | + |
4727 | + removed := false |
4728 | + eventManager.Subscribe(thirdWatcher.eventID(), nil, func(ctx context.Context, evType string, data interface{}, evData *EventData) bool { |
4729 | + if !removed { |
4730 | + if err := eventManager.RemoveWatcher(ctx, firstWatcher); err != nil { |
4731 | + t.Errorf("Failed to remove firstWatcher, it should have succeeded: %+v", err) |
4732 | + } |
4733 | + if err := eventManager.RemoveWatcher(ctx, secondWatcher); err != nil { |
4734 | + t.Errorf("Failed to remove secondWatcher, it should have succeeded: %+v", err) |
4735 | + } |
4736 | + removed = true |
4737 | + return true |
4738 | + } |
4739 | + |
4740 | + queueLen := eventManager.queue.length() |
4741 | + if queueLen != 1 { |
4742 | + t.Errorf("Failed to remove watcher, expected remaining watchers: 1, got: %d", queueLen) |
4743 | + } |
4744 | + |
4745 | + if err := eventManager.RemoveWatcher(ctx, thirdWatcher); err != nil { |
4746 | + t.Errorf("Failed to remove thirdWatcher, it should have succeeded: %+v", err) |
4747 | + } |
4748 | + |
4749 | + return false |
4750 | + }) |
4751 | + |
4752 | + if err := eventManager.Run(ctx); err != nil { |
4753 | + t.Errorf("Failed running event manager, expected success, got error: %+v", err) |
4754 | + } |
4755 | +} |
4756 | diff --git a/google_guest_agent/events/metadata/metadata.go b/google_guest_agent/events/metadata/metadata.go |
4757 | index 9b93cb3..80a9b15 100644 |
4758 | --- a/google_guest_agent/events/metadata/metadata.go |
4759 | +++ b/google_guest_agent/events/metadata/metadata.go |
4760 | @@ -19,7 +19,6 @@ import ( |
4761 | "context" |
4762 | "net" |
4763 | "net/url" |
4764 | - "time" |
4765 | |
4766 | "github.com/GoogleCloudPlatform/guest-agent/metadata" |
4767 | "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
4768 | @@ -32,11 +31,6 @@ const ( |
4769 | LongpollEvent = "metadata-watcher,longpoll" |
4770 | ) |
4771 | |
4772 | -var ( |
4773 | - // arbitrarily defined wait duration(keeps behavioral backward compatibility). |
4774 | - retryWaitDuration = 5 * time.Second |
4775 | -) |
4776 | - |
4777 | // Watcher is the metadata event watcher implementation. |
4778 | type Watcher struct { |
4779 | client metadata.MDSClientInterface |
4780 | diff --git a/google_guest_agent/events/metadata/metadata_test.go b/google_guest_agent/events/metadata/metadata_test.go |
4781 | index 1b8aa5a..291f656 100644 |
4782 | --- a/google_guest_agent/events/metadata/metadata_test.go |
4783 | +++ b/google_guest_agent/events/metadata/metadata_test.go |
4784 | @@ -39,6 +39,10 @@ func (mds *mdsClient) GetKey(ctx context.Context, key string, headers map[string |
4785 | return "", fmt.Errorf("GetKey() not yet implemented") |
4786 | } |
4787 | |
4788 | +func (mds *mdsClient) GetKeyRecursive(ctx context.Context, key string) (string, error) { |
4789 | + return "", fmt.Errorf("GetKeyRecursive() not yet implemented") |
4790 | +} |
4791 | + |
4792 | func (mds *mdsClient) Watch(ctx context.Context) (*metadata.Descriptor, error) { |
4793 | if !mds.disableUnknownFailure { |
4794 | return nil, errUnknown |
4795 | diff --git a/google_guest_agent/fakes/fake_mds.go b/google_guest_agent/fakes/fake_mds.go |
4796 | index b6e25b1..0e8213a 100644 |
4797 | --- a/google_guest_agent/fakes/fake_mds.go |
4798 | +++ b/google_guest_agent/fakes/fake_mds.go |
4799 | @@ -37,6 +37,11 @@ func NewFakeMDSClient() *MDSClient { |
4800 | return &MDSClient{} |
4801 | } |
4802 | |
4803 | +// GetKeyRecursive implements fake GetKeyRecursive MDS method. |
4804 | +func (s MDSClient) GetKeyRecursive(ctx context.Context, key string) (string, error) { |
4805 | + return "", fmt.Errorf("GetKeyRecursive() not yet implemented") |
4806 | +} |
4807 | + |
4808 | // GetKey implements fake GetKey MDS method. |
4809 | func (s MDSClient) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) { |
4810 | valid := ` |
4811 | diff --git a/google_guest_agent/instance_setup.go b/google_guest_agent/instance_setup.go |
4812 | index 38c834a..6327821 100644 |
4813 | --- a/google_guest_agent/instance_setup.go |
4814 | +++ b/google_guest_agent/instance_setup.go |
4815 | @@ -25,8 +25,10 @@ import ( |
4816 | "strings" |
4817 | "time" |
4818 | |
4819 | + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/agentcrypto" |
4820 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
4821 | "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run" |
4822 | + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/scheduler" |
4823 | "github.com/GoogleCloudPlatform/guest-logging-go/logger" |
4824 | "github.com/go-ini/ini" |
4825 | ) |
4826 | @@ -148,13 +150,14 @@ func agentInit(ctx context.Context) { |
4827 | return |
4828 | } |
4829 | |
4830 | - // The below actions require metadata to be set, so if it |
4831 | - // hasn't yet been set, wait on it here. In instances without |
4832 | - // network access, this will become an indefinite wait. |
4833 | - // TODO: split agentInit into needs-network and no-network functions. |
4834 | - for newMetadata == nil { |
4835 | - logger.Debugf("populate first time metadata...") |
4836 | - newMetadata, _ = mdsClient.Get(ctx) |
4837 | + if newMetadata == nil { |
4838 | + var err error |
4839 | + logger.Debugf("populate metadata for the first time...") |
4840 | + newMetadata, err = mdsClient.Get(ctx) |
4841 | + if err != nil { |
4842 | + logger.Errorf("Failed to reach MDS(all retries exhausted): %+v", err) |
4843 | + os.Exit(1) |
4844 | + } |
4845 | } |
4846 | |
4847 | // Disable overcommit accounting; e2 instances only. |
4848 | @@ -204,6 +207,14 @@ func agentInit(ctx context.Context) { |
4849 | } |
4850 | } |
4851 | } |
4852 | + // Schedules jobs that need to be started before notifying systemd Agent process has started. |
4853 | + // We want to generate MDS credentials as early as possible so that any process in the Guest can |
4854 | + // use them. Processes may depend on the Guest Agent at startup to ensure that the credentials are |
4855 | + // available for use. By generating the credentials before notifying the systemd, we ensure that |
4856 | + // they are generated for any process that depends on the Guest Agent. |
4857 | + if config.MDS.MTLSBootstrappingEnabled { |
4858 | + scheduler.ScheduleJobs(ctx, []scheduler.Job{agentcrypto.New()}, true) |
4859 | + } |
4860 | } |
4861 | |
4862 | func generateSSHKeys(ctx context.Context) error { |
4863 | diff --git a/google_guest_agent/instance_setup_integ_test.go b/google_guest_agent/instance_setup_integ_test.go |
4864 | deleted file mode 100644 |
4865 | index 6c343bf..0000000 |
4866 | --- a/google_guest_agent/instance_setup_integ_test.go |
4867 | +++ /dev/null |
4868 | @@ -1,208 +0,0 @@ |
4869 | -// Copyright 2021 Google LLC |
4870 | -// |
4871 | -// Licensed under the Apache License, Version 2.0 (the "License"); |
4872 | -// you may not use this file except in compliance with the License. |
4873 | -// You may obtain a copy of the License at |
4874 | -// |
4875 | -// http://www.apache.org/licenses/LICENSE-2.0 |
4876 | -// |
4877 | -// Unless required by applicable law or agreed to in writing, software |
4878 | -// distributed under the License is distributed on an "AS IS" BASIS, |
4879 | -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
4880 | -// See the License for the specific language governing permissions and |
4881 | -// limitations under the License. |
4882 | - |
4883 | -//go:build integration |
4884 | -// +build integration |
4885 | - |
4886 | -package main |
4887 | - |
4888 | -import ( |
4889 | - "context" |
4890 | - "os" |
4891 | - "path/filepath" |
4892 | - "strings" |
4893 | - "testing" |
4894 | - |
4895 | - "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" |
4896 | -) |
4897 | - |
4898 | -const ( |
4899 | - botoCfg = "/etc/boto.cfg" |
4900 | -) |
4901 | - |
4902 | -func getConfig(t *testing.T) (*cfg.Sections, string) { |
4903 | - t.Helper() |
4904 | - |
4905 | - if err := cfg.Load(nil); err != nil { |
4906 | - t.Fatalf("Failed to load configuration: %+v", err) |
4907 | - } |
4908 | - |
4909 | - config, err := cfg.Get() |
4910 | - if err != nil { |
4911 | - t.Fatalf("Failed to get config: %+v", err) |
4912 | - } |
4913 | - |
4914 | - if config == nil { |
4915 | - t.Fatal("cfg.Get() returned a nil config") |
4916 | - } |
4917 | - |
4918 | - tempDir := filepath.Join(t.TempDir(), "test_instance_setup") |
4919 | - err := os.Mkdir(tempDir, 0755) |
4920 | - if err != nil { |
4921 | - t.Fatalf("Failed to create working dir: %+v", err) |
4922 | - } |
4923 | - |
4924 | - // Configure a non-standard instance ID dir for us to play with. |
4925 | - config.Instance.InstanceIDDir = tempDir |
4926 | - config.InstanceSetup.HostKeyDir = tempDir |
4927 | - |
4928 | - return config, tempDir |
4929 | -} |
4930 | - |
4931 | -// TestInstanceSetupSSHKeys validates SSH keys are generated on first boot and not changed afterward. |
4932 | -func TestInstanceSetupSSHKeys(t *testing.T) { |
4933 | - config, tempDir := getConfig(t) |
4934 | - ctx := context.Background() |
4935 | - agentInit(ctx) |
4936 | - |
4937 | - if _, err := os.Stat(tempDir + "/google_instance_id"); err != nil { |
4938 | - t.Fatal("instance ID File was not created by agentInit") |
4939 | - } |
4940 | - |
4941 | - dir, err := os.Open(tempDir) |
4942 | - if err != nil { |
4943 | - t.Fatal("failed to open working dir") |
4944 | - } |
4945 | - defer dir.Close() |
4946 | - |
4947 | - files, err := dir.Readdirnames(0) |
4948 | - if err != nil { |
4949 | - t.Fatal("failed to read files") |
4950 | - } |
4951 | - |
4952 | - var keys []string |
4953 | - for _, file := range files { |
4954 | - if strings.HasPrefix(file, "ssh_host_") { |
4955 | - keys = append(keys, file) |
4956 | - } |
4957 | - } |
4958 | - |
4959 | - if len(keys) == 0 { |
4960 | - t.Fatal("instance setup didn't create SSH host keys") |
4961 | - } |
4962 | - |
4963 | - // Remove one key file and run again to confirm SSH keys have not |
4964 | - // changed because the instance ID file has not changed. |
4965 | - if err := os.Remove(tempDir + "/" + keys[0]); err != nil { |
4966 | - t.Fatal("failed to remove key file") |
4967 | - } |
4968 | - |
4969 | - agentInit(ctx) |
4970 | - |
4971 | - if _, err := dir.Seek(0, 0); err != nil { |
4972 | - t.Fatal("failed to rewind dir for second check") |
4973 | - } |
4974 | - files2, err := dir.Readdirnames(0) |
4975 | - if err != nil { |
4976 | - t.Fatal("failed to read files") |
4977 | - } |
4978 | - |
4979 | - var keys2 []string |
4980 | - for _, file := range files2 { |
4981 | - if strings.HasPrefix(file, "ssh_host_") { |
4982 | - keys2 = append(keys2, file) |
4983 | - } |
4984 | - if file == keys[0] { |
4985 | - t.Fatalf("agentInit recreated key %s", file) |
4986 | - } |
4987 | - } |
4988 | - |
4989 | - if len(keys) == len(keys2) { |
4990 | - t.Fatal("agentInit recreated SSH host keys") |
4991 | - } |
4992 | -} |
4993 | - |
4994 | -// TestInstanceSetupSSHKeysDisabled validates the config option to disable host |
4995 | -// key generation is respected. |
4996 | -func TestInstanceSetupSSHKeysDisabled(t *testing.T) { |
4997 | - config, tempDir := getConfig(t) |
4998 | - |
4999 | - // Disable SSH host key generation. |
5000 | - config.InstanceSetup.SetHostKeys = false |
The diff has been truncated for viewing.
$ dput ubuntu ../google-guest-agent_20240213.00-0ubuntu1_source.changes
Uploading google-guest-agent using ftp to ubuntu (host: upload.ubuntu.com; directory: /ubuntu)
running supported-distribution: check whether the target distribution is currently supported (using distro-info)
{'allowed': ['release', 'proposed', 'backports', 'security'], 'known': ['release', 'proposed', 'updates', 'backports', 'security']}
running required-fields: check whether a field is present and non-empty in the changes file
running checksum: verify checksums before uploading
running suite-mismatch: check the target distribution for common errors
running check-debs: makes sure the upload contains a binary package
running gpg: check GnuPG signatures before the upload
Uploading google-guest-agent_20240213.00-0ubuntu1.dsc
Uploading google-guest-agent_20240213.00.orig.tar.gz
Uploading google-guest-agent_20240213.00-0ubuntu1.debian.tar.xz
Uploading google-guest-agent_20240213.00-0ubuntu1_source.buildinfo
Uploading google-guest-agent_20240213.00-0ubuntu1_source.changes