Merge lp:~mpontillo/maas/beaconing-packet-format into lp:~maas-committers/maas/trunk
- beaconing-packet-format
- Merge into trunk
Proposed by
Mike Pontillo
Status: | Merged |
---|---|
Approved by: | Mike Pontillo |
Approved revision: | no longer in the source branch. |
Merged at revision: | 6099 |
Proposed branch: | lp:~mpontillo/maas/beaconing-packet-format |
Merge into: | lp:~maas-committers/maas/trunk |
Diff against target: |
673 lines (+345/-74) 4 files modified
src/provisioningserver/security.py (+31/-9) src/provisioningserver/tests/test_security.py (+55/-44) src/provisioningserver/utils/beaconing.py (+135/-10) src/provisioningserver/utils/tests/test_beaconing.py (+124/-11) |
To merge this branch: | bzr merge lp:~mpontillo/maas/beaconing-packet-format |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Данило Шеган (community) | Approve | ||
Review via email: mp+325601@code.launchpad.net |
Commit message
Add methods to encode and decode beacon packets.
Also, change the name of the Fernet methods to make it clear that the encryption is done using a pre-shared key (the MAAS shared secret).
Minor test suite refactoring for SharedSecretTestCase.
Description of the change
To post a comment you must log in.
Revision history for this message
Mike Pontillo (mpontillo) wrote : | # |
Thanks for the review. Some replies below. You've given me some things to think about, so I'll make this WIP for now.
Revision history for this message
Данило Шеган (danilo) wrote : | # |
Looks great, thanks for the changes!
A few more things that you probably forgot to complete with it being past 1am :-)
Not blocking since the gist of the stuff is there (basically, you are not making use of the new exception and raw=True parameter).
review:
Approve
Revision history for this message
Mike Pontillo (mpontillo) wrote : | # |
Thanks much for the review; some replies below.
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === modified file 'src/provisioningserver/security.py' | |||
2 | --- src/provisioningserver/security.py 2017-06-02 17:28:04 +0000 | |||
3 | +++ src/provisioningserver/security.py 2017-06-21 15:18:45 +0000 | |||
4 | @@ -9,7 +9,10 @@ | |||
5 | 9 | "get_shared_secret_from_filesystem", | 9 | "get_shared_secret_from_filesystem", |
6 | 10 | ] | 10 | ] |
7 | 11 | 11 | ||
9 | 12 | from base64 import urlsafe_b64encode | 12 | from base64 import ( |
10 | 13 | urlsafe_b64decode, | ||
11 | 14 | urlsafe_b64encode, | ||
12 | 15 | ) | ||
13 | 13 | import binascii | 16 | import binascii |
14 | 14 | from binascii import ( | 17 | from binascii import ( |
15 | 15 | a2b_hex, | 18 | a2b_hex, |
16 | @@ -41,6 +44,10 @@ | |||
17 | 41 | ) | 44 | ) |
18 | 42 | 45 | ||
19 | 43 | 46 | ||
20 | 47 | class MissingSharedSecret(RuntimeError): | ||
21 | 48 | """Raised when the MAAS shared secret is missing.""" | ||
22 | 49 | |||
23 | 50 | |||
24 | 44 | def to_hex(b): | 51 | def to_hex(b): |
25 | 45 | """Convert byte string to hex encoding.""" | 52 | """Convert byte string to hex encoding.""" |
26 | 46 | assert isinstance(b, bytes), "%r is not a byte string" % (b,) | 53 | assert isinstance(b, bytes), "%r is not a byte string" % (b,) |
27 | @@ -140,7 +147,8 @@ | |||
28 | 140 | global _fernet_psk | 147 | global _fernet_psk |
29 | 141 | if _fernet_psk is None: | 148 | if _fernet_psk is None: |
30 | 142 | secret = get_shared_secret_from_filesystem() | 149 | secret = get_shared_secret_from_filesystem() |
32 | 143 | assert secret is not None, "MAAS shared secret not found." | 150 | if secret is None: |
33 | 151 | raise MissingSharedSecret("MAAS shared secret not found.") | ||
34 | 144 | # Keying material is required by PBKDF2 to be a byte string. | 152 | # Keying material is required by PBKDF2 to be a byte string. |
35 | 145 | kdf = PBKDF2HMAC( | 153 | kdf = PBKDF2HMAC( |
36 | 146 | algorithm=hashes.SHA256(), | 154 | algorithm=hashes.SHA256(), |
37 | @@ -171,7 +179,7 @@ | |||
38 | 171 | return f | 179 | return f |
39 | 172 | 180 | ||
40 | 173 | 181 | ||
42 | 174 | def fernet_encrypt(message): | 182 | def fernet_encrypt_psk(message, raw=False): |
43 | 175 | """Encrypts the specified message using the Fernet format. | 183 | """Encrypts the specified message using the Fernet format. |
44 | 176 | 184 | ||
45 | 177 | Returns the encrypted token, as a byte string. | 185 | Returns the encrypted token, as a byte string. |
46 | @@ -183,29 +191,43 @@ | |||
47 | 183 | 191 | ||
48 | 184 | :param message: The message to encrypt. | 192 | :param message: The message to encrypt. |
49 | 185 | :type message: Must be of type 'bytes' or a UTF-8 'str'. | 193 | :type message: Must be of type 'bytes' or a UTF-8 'str'. |
50 | 194 | :param raw: if True, returns the decoded base64 bytes representing the | ||
51 | 195 | Fernet token. The bytes must be converted back to base64 to be | ||
52 | 196 | decrypted. (Or the 'raw' argument on the corresponding | ||
53 | 197 | fernet_decrypt_psk() function can be used.) | ||
54 | 186 | :return: the encryption token, as a base64-encoded byte string. | 198 | :return: the encryption token, as a base64-encoded byte string. |
55 | 187 | """ | 199 | """ |
57 | 188 | f = _get_fernet_context() | 200 | fernet = _get_fernet_context() |
58 | 189 | if isinstance(message, str): | 201 | if isinstance(message, str): |
59 | 190 | message = message.encode("utf-8") | 202 | message = message.encode("utf-8") |
64 | 191 | return f.encrypt(message) | 203 | token = fernet.encrypt(message) |
65 | 192 | 204 | if raw is True: | |
66 | 193 | 205 | token = urlsafe_b64decode(token) | |
67 | 194 | def fernet_decrypt(token, ttl=None): | 206 | return token |
68 | 207 | |||
69 | 208 | |||
70 | 209 | def fernet_decrypt_psk(token, ttl=None, raw=False): | ||
71 | 195 | """Decrypts the specified Fernet token using the MAAS secret. | 210 | """Decrypts the specified Fernet token using the MAAS secret. |
72 | 196 | 211 | ||
73 | 197 | Returns the decrypted token as a byte string; the user is responsible for | 212 | Returns the decrypted token as a byte string; the user is responsible for |
74 | 198 | converting it to the correct format or encoding. | 213 | converting it to the correct format or encoding. |
75 | 199 | 214 | ||
76 | 200 | :param message: The token to decrypt. | 215 | :param message: The token to decrypt. |
78 | 201 | :type token: bytes | 216 | :type token: Must be of type 'bytes', or an ASCII base64 string. |
79 | 202 | :param ttl: Optional amount of time (in seconds) allowed to have elapsed | 217 | :param ttl: Optional amount of time (in seconds) allowed to have elapsed |
80 | 203 | before the message is rejected upon decryption. Note that the Fernet | 218 | before the message is rejected upon decryption. Note that the Fernet |
81 | 204 | library considers times up to 60 seconds into the future (beyond the | 219 | library considers times up to 60 seconds into the future (beyond the |
82 | 205 | TTL) to be valid. | 220 | TTL) to be valid. |
83 | 221 | :param raw: if True, treats the string as the decoded base64 bytes of a | ||
84 | 222 | Fernet token, and attempts to encode them (as expected by the Fernet | ||
85 | 223 | APIs) before decrypting. | ||
86 | 206 | :return: bytes | 224 | :return: bytes |
87 | 207 | """ | 225 | """ |
88 | 226 | if raw is True: | ||
89 | 227 | token = urlsafe_b64encode(token) | ||
90 | 208 | f = _get_fernet_context() | 228 | f = _get_fernet_context() |
91 | 229 | if isinstance(token, str): | ||
92 | 230 | token = token.encode("ascii") | ||
93 | 209 | return f.decrypt(token, ttl=ttl) | 231 | return f.decrypt(token, ttl=ttl) |
94 | 210 | 232 | ||
95 | 211 | 233 | ||
96 | 212 | 234 | ||
97 | === modified file 'src/provisioningserver/tests/test_security.py' | |||
98 | --- src/provisioningserver/tests/test_security.py 2017-06-02 17:28:04 +0000 | |||
99 | +++ src/provisioningserver/tests/test_security.py 2017-06-21 15:18:45 +0000 | |||
100 | @@ -28,8 +28,9 @@ | |||
101 | 28 | from provisioningserver import security | 28 | from provisioningserver import security |
102 | 29 | from provisioningserver.path import get_data_path | 29 | from provisioningserver.path import get_data_path |
103 | 30 | from provisioningserver.security import ( | 30 | from provisioningserver.security import ( |
106 | 31 | fernet_decrypt, | 31 | fernet_decrypt_psk, |
107 | 32 | fernet_encrypt, | 32 | fernet_encrypt_psk, |
108 | 33 | MissingSharedSecret, | ||
109 | 33 | ) | 34 | ) |
110 | 34 | from provisioningserver.utils.fs import ( | 35 | from provisioningserver.utils.fs import ( |
111 | 35 | FileLock, | 36 | FileLock, |
112 | @@ -37,7 +38,10 @@ | |||
113 | 37 | write_text_file, | 38 | write_text_file, |
114 | 38 | ) | 39 | ) |
115 | 39 | from testtools import ExpectedException | 40 | from testtools import ExpectedException |
117 | 40 | from testtools.matchers import Equals | 41 | from testtools.matchers import ( |
118 | 42 | Equals, | ||
119 | 43 | IsInstance, | ||
120 | 44 | ) | ||
121 | 41 | 45 | ||
122 | 42 | 46 | ||
123 | 43 | class SharedSecretTestCase(MAASTestCase): | 47 | class SharedSecretTestCase(MAASTestCase): |
124 | @@ -49,17 +53,24 @@ | |||
125 | 49 | # so that tests cannot interfere with each other. | 53 | # so that tests cannot interfere with each other. |
126 | 50 | get_secret.return_value = get_data_path( | 54 | get_secret.return_value = get_data_path( |
127 | 51 | "var", "lib", "maas", "secret-%s" % factory.make_string(16)) | 55 | "var", "lib", "maas", "secret-%s" % factory.make_string(16)) |
128 | 52 | secret_file = security.get_shared_secret_filesystem_path() | ||
129 | 53 | # Extremely unlikely, but just in case. | 56 | # Extremely unlikely, but just in case. |
132 | 54 | if os.path.isfile(secret_file): | 57 | self.delete_secret() |
133 | 55 | os.remove(secret_file) | 58 | self.addCleanup( |
134 | 59 | setattr, security, "DEFAULT_ITERATION_COUNT", | ||
135 | 60 | security.DEFAULT_ITERATION_COUNT) | ||
136 | 61 | # The default high iteration count would make the tests very slow. | ||
137 | 62 | security.DEFAULT_ITERATION_COUNT = 2 | ||
138 | 56 | super().setUp() | 63 | super().setUp() |
139 | 57 | 64 | ||
140 | 58 | def tearDown(self): | 65 | def tearDown(self): |
141 | 66 | self.delete_secret() | ||
142 | 67 | super().tearDown() | ||
143 | 68 | |||
144 | 69 | def delete_secret(self): | ||
145 | 70 | security._fernet_psk = None | ||
146 | 59 | secret_file = security.get_shared_secret_filesystem_path() | 71 | secret_file = security.get_shared_secret_filesystem_path() |
147 | 60 | if os.path.isfile(secret_file): | 72 | if os.path.isfile(secret_file): |
148 | 61 | os.remove(secret_file) | 73 | os.remove(secret_file) |
149 | 62 | super().tearDown() | ||
150 | 63 | 74 | ||
151 | 64 | def write_secret(self): | 75 | def write_secret(self): |
152 | 65 | secret = factory.make_bytes() | 76 | secret = factory.make_bytes() |
153 | @@ -121,6 +132,12 @@ | |||
154 | 121 | 132 | ||
155 | 122 | class TestSetSharedSecretOnFilesystem(MAASTestCase): | 133 | class TestSetSharedSecretOnFilesystem(MAASTestCase): |
156 | 123 | 134 | ||
157 | 135 | def test__default_iteration_count_is_reasonably_large(self): | ||
158 | 136 | # Ensure that the iteration count is high by default. This is very | ||
159 | 137 | # important so that the MAAS secret cannot be determined by | ||
160 | 138 | # brute-force. | ||
161 | 139 | self.assertThat(security.DEFAULT_ITERATION_COUNT, Equals(100000)) | ||
162 | 140 | |||
163 | 124 | def read_secret(self): | 141 | def read_secret(self): |
164 | 125 | secret_path = security.get_shared_secret_filesystem_path() | 142 | secret_path = security.get_shared_secret_filesystem_path() |
165 | 126 | secret_hex = read_text_file(secret_path) | 143 | secret_hex = read_text_file(secret_path) |
166 | @@ -299,40 +316,23 @@ | |||
167 | 299 | 316 | ||
168 | 300 | class TestFernetEncryption(SharedSecretTestCase): | 317 | class TestFernetEncryption(SharedSecretTestCase): |
169 | 301 | 318 | ||
170 | 302 | def setUp(self): | ||
171 | 303 | security._fernet_psk = None | ||
172 | 304 | # Ensure that the iteration count is high by default. This is very | ||
173 | 305 | # important so that the MAAS secret cannot be determined by | ||
174 | 306 | # brute-force. As a side effect, this ensures our tearDown (which | ||
175 | 307 | # resets the iteration count to its default) works properly. | ||
176 | 308 | self.assertThat(security.DEFAULT_ITERATION_COUNT, Equals(100000)) | ||
177 | 309 | self._previous_iteration_count = security.DEFAULT_ITERATION_COUNT | ||
178 | 310 | # The default high iteration count would make the tests very slow. | ||
179 | 311 | security.DEFAULT_ITERATION_COUNT = 2 | ||
180 | 312 | super().setUp() | ||
181 | 313 | |||
182 | 314 | def tearDown(self): | ||
183 | 315 | security._fernet_psk = None | ||
184 | 316 | security.DEFAULT_ITERATION_COUNT = self._previous_iteration_count | ||
185 | 317 | super().tearDown() | ||
186 | 318 | |||
187 | 319 | def test__first_encrypt_caches_psk(self): | 319 | def test__first_encrypt_caches_psk(self): |
188 | 320 | self.write_secret() | 320 | self.write_secret() |
189 | 321 | self.assertIsNone(security._fernet_psk) | 321 | self.assertIsNone(security._fernet_psk) |
190 | 322 | testdata = factory.make_string() | 322 | testdata = factory.make_string() |
192 | 323 | fernet_encrypt(testdata) | 323 | fernet_encrypt_psk(testdata) |
193 | 324 | self.assertIsNotNone(security._fernet_psk) | 324 | self.assertIsNotNone(security._fernet_psk) |
194 | 325 | 325 | ||
195 | 326 | def test__derives_identical_key_on_decrypt(self): | 326 | def test__derives_identical_key_on_decrypt(self): |
196 | 327 | self.write_secret() | 327 | self.write_secret() |
197 | 328 | self.assertIsNone(security._fernet_psk) | 328 | self.assertIsNone(security._fernet_psk) |
198 | 329 | testdata = factory.make_bytes() | 329 | testdata = factory.make_bytes() |
200 | 330 | token = fernet_encrypt(testdata) | 330 | token = fernet_encrypt_psk(testdata) |
201 | 331 | first_key = security._fernet_psk | 331 | first_key = security._fernet_psk |
202 | 332 | # Make it seem like we're decrypting something without ever encrypting | 332 | # Make it seem like we're decrypting something without ever encrypting |
203 | 333 | # anything first. | 333 | # anything first. |
204 | 334 | security._fernet_psk = None | 334 | security._fernet_psk = None |
206 | 335 | decrypted = fernet_decrypt(token) | 335 | decrypted = fernet_decrypt_psk(token) |
207 | 336 | second_key = security._fernet_psk | 336 | second_key = security._fernet_psk |
208 | 337 | self.assertEqual(first_key, second_key) | 337 | self.assertEqual(first_key, second_key) |
209 | 338 | self.assertEqual(testdata, decrypted) | 338 | self.assertEqual(testdata, decrypted) |
210 | @@ -340,29 +340,40 @@ | |||
211 | 340 | def test__can_encrypt_and_decrypt_string(self): | 340 | def test__can_encrypt_and_decrypt_string(self): |
212 | 341 | self.write_secret() | 341 | self.write_secret() |
213 | 342 | testdata = factory.make_string() | 342 | testdata = factory.make_string() |
217 | 343 | token = fernet_encrypt(testdata) | 343 | token = fernet_encrypt_psk(testdata) |
218 | 344 | decrypted = fernet_decrypt(token) | 344 | # Round-trip this to a string, since Fernet tokens are used inside |
219 | 345 | decrypted = decrypted.decode("utf-8") | 345 | # strings (such as JSON objects) typically. |
220 | 346 | token = token.decode("ascii") | ||
221 | 347 | decrypted = fernet_decrypt_psk(token) | ||
222 | 348 | decrypted = decrypted.decode("ascii") | ||
223 | 349 | self.assertThat(decrypted, Equals(testdata)) | ||
224 | 350 | |||
225 | 351 | def test__can_encrypt_and_decrypt_with_raw_bytes(self): | ||
226 | 352 | self.write_secret() | ||
227 | 353 | testdata = factory.make_bytes() | ||
228 | 354 | token = fernet_encrypt_psk(testdata, raw=True) | ||
229 | 355 | self.assertThat(token, IsInstance(bytes)) | ||
230 | 356 | decrypted = fernet_decrypt_psk(token, raw=True) | ||
231 | 346 | self.assertThat(decrypted, Equals(testdata)) | 357 | self.assertThat(decrypted, Equals(testdata)) |
232 | 347 | 358 | ||
233 | 348 | def test__can_encrypt_and_decrypt_bytes(self): | 359 | def test__can_encrypt_and_decrypt_bytes(self): |
234 | 349 | self.write_secret() | 360 | self.write_secret() |
235 | 350 | testdata = factory.make_bytes() | 361 | testdata = factory.make_bytes() |
238 | 351 | token = fernet_encrypt(testdata) | 362 | token = fernet_encrypt_psk(testdata) |
239 | 352 | decrypted = fernet_decrypt(token) | 363 | decrypted = fernet_decrypt_psk(token) |
240 | 353 | self.assertThat(decrypted, Equals(testdata)) | 364 | self.assertThat(decrypted, Equals(testdata)) |
241 | 354 | 365 | ||
242 | 355 | def test__raises_when_no_secret_exists(self): | 366 | def test__raises_when_no_secret_exists(self): |
243 | 356 | testdata = factory.make_bytes() | 367 | testdata = factory.make_bytes() |
248 | 357 | with ExpectedException(AssertionError): | 368 | with ExpectedException(MissingSharedSecret): |
249 | 358 | fernet_encrypt(testdata) | 369 | fernet_encrypt_psk(testdata) |
250 | 359 | with ExpectedException(AssertionError): | 370 | with ExpectedException(MissingSharedSecret): |
251 | 360 | fernet_decrypt(b"") | 371 | fernet_decrypt_psk(b"") |
252 | 361 | 372 | ||
253 | 362 | def test__assures_data_integrity(self): | 373 | def test__assures_data_integrity(self): |
254 | 363 | self.write_secret() | 374 | self.write_secret() |
255 | 364 | testdata = factory.make_bytes(size=10) | 375 | testdata = factory.make_bytes(size=10) |
257 | 365 | token = fernet_encrypt(testdata) | 376 | token = fernet_encrypt_psk(testdata) |
258 | 366 | bad_token = bytearray(token) | 377 | bad_token = bytearray(token) |
259 | 367 | # Flip a bit in the token, so we can ensure it won't decrypt if it | 378 | # Flip a bit in the token, so we can ensure it won't decrypt if it |
260 | 368 | # has been corrupted. Subtract 4 to avoid the end of the token; that | 379 | # has been corrupted. Subtract 4 to avoid the end of the token; that |
261 | @@ -374,30 +385,30 @@ | |||
262 | 374 | test_description = ("token=%s; token[%d] ^= 0x%02x" % ( | 385 | test_description = ("token=%s; token[%d] ^= 0x%02x" % ( |
263 | 375 | token.decode("utf-8"), byte_to_flip, bit_to_flip)) | 386 | token.decode("utf-8"), byte_to_flip, bit_to_flip)) |
264 | 376 | with ExpectedException(InvalidToken, msg=test_description): | 387 | with ExpectedException(InvalidToken, msg=test_description): |
266 | 377 | fernet_decrypt(bad_token) | 388 | fernet_decrypt_psk(bad_token) |
267 | 378 | 389 | ||
268 | 379 | def test__messages_from_up_to_a_minute_in_the_future_accepted(self): | 390 | def test__messages_from_up_to_a_minute_in_the_future_accepted(self): |
269 | 380 | self.write_secret() | 391 | self.write_secret() |
270 | 381 | testdata = factory.make_bytes() | 392 | testdata = factory.make_bytes() |
271 | 382 | now = time.time() | 393 | now = time.time() |
272 | 383 | self.patch(time, "time").side_effect = [now + 60, now] | 394 | self.patch(time, "time").side_effect = [now + 60, now] |
275 | 384 | token = fernet_encrypt(testdata) | 395 | token = fernet_encrypt_psk(testdata) |
276 | 385 | fernet_decrypt(token, ttl=1) | 396 | fernet_decrypt_psk(token, ttl=1) |
277 | 386 | 397 | ||
278 | 387 | def test__messages_from_the_past_exceeding_ttl_rejected(self): | 398 | def test__messages_from_the_past_exceeding_ttl_rejected(self): |
279 | 388 | self.write_secret() | 399 | self.write_secret() |
280 | 389 | testdata = factory.make_bytes() | 400 | testdata = factory.make_bytes() |
281 | 390 | now = time.time() | 401 | now = time.time() |
282 | 391 | self.patch(time, "time").side_effect = [now - 2, now] | 402 | self.patch(time, "time").side_effect = [now - 2, now] |
284 | 392 | token = fernet_encrypt(testdata) | 403 | token = fernet_encrypt_psk(testdata) |
285 | 393 | with ExpectedException(InvalidToken): | 404 | with ExpectedException(InvalidToken): |
287 | 394 | fernet_decrypt(token, ttl=1) | 405 | fernet_decrypt_psk(token, ttl=1) |
288 | 395 | 406 | ||
289 | 396 | def test__messages_from_future_exceeding_clock_skew_limit_rejected(self): | 407 | def test__messages_from_future_exceeding_clock_skew_limit_rejected(self): |
290 | 397 | self.write_secret() | 408 | self.write_secret() |
291 | 398 | testdata = factory.make_bytes() | 409 | testdata = factory.make_bytes() |
292 | 399 | now = time.time() | 410 | now = time.time() |
293 | 400 | self.patch(time, "time").side_effect = [now + 61, now] | 411 | self.patch(time, "time").side_effect = [now + 61, now] |
295 | 401 | token = fernet_encrypt(testdata) | 412 | token = fernet_encrypt_psk(testdata) |
296 | 402 | with ExpectedException(InvalidToken): | 413 | with ExpectedException(InvalidToken): |
298 | 403 | fernet_decrypt(token, ttl=1) | 414 | fernet_decrypt_psk(token, ttl=1) |
299 | 404 | 415 | ||
300 | === modified file 'src/provisioningserver/utils/beaconing.py' | |||
301 | --- src/provisioningserver/utils/beaconing.py 2017-06-13 05:26:53 +0000 | |||
302 | +++ src/provisioningserver/utils/beaconing.py 2017-06-21 15:18:45 +0000 | |||
303 | @@ -5,18 +5,35 @@ | |||
304 | 5 | 5 | ||
305 | 6 | __all__ = [ | 6 | __all__ = [ |
306 | 7 | "BeaconingPacket", | 7 | "BeaconingPacket", |
307 | 8 | "BeaconPayload", | ||
308 | 9 | "InvalidBeaconingPacket", | ||
309 | 10 | "create_beacon_payload", | ||
310 | 11 | "read_beacon_payload", | ||
311 | 8 | "add_arguments", | 12 | "add_arguments", |
312 | 9 | "run" | 13 | "run" |
313 | 10 | ] | 14 | ] |
314 | 11 | 15 | ||
315 | 16 | from collections import namedtuple | ||
316 | 17 | from gzip import ( | ||
317 | 18 | compress, | ||
318 | 19 | decompress, | ||
319 | 20 | ) | ||
320 | 12 | import json | 21 | import json |
321 | 13 | import os | 22 | import os |
322 | 14 | import stat | 23 | import stat |
323 | 24 | import struct | ||
324 | 15 | import subprocess | 25 | import subprocess |
325 | 16 | import sys | 26 | import sys |
326 | 17 | from textwrap import dedent | 27 | from textwrap import dedent |
327 | 28 | import uuid | ||
328 | 18 | 29 | ||
330 | 19 | import bson | 30 | from bson import BSON |
331 | 31 | from bson.errors import BSONError | ||
332 | 32 | from cryptography.fernet import InvalidToken | ||
333 | 33 | from provisioningserver.security import ( | ||
334 | 34 | fernet_decrypt_psk, | ||
335 | 35 | fernet_encrypt_psk, | ||
336 | 36 | ) | ||
337 | 20 | from provisioningserver.utils import sudo | 37 | from provisioningserver.utils import sudo |
338 | 21 | from provisioningserver.utils.network import format_eui | 38 | from provisioningserver.utils.network import format_eui |
339 | 22 | from provisioningserver.utils.pcap import ( | 39 | from provisioningserver.utils.pcap import ( |
340 | @@ -30,8 +47,119 @@ | |||
341 | 30 | ) | 47 | ) |
342 | 31 | 48 | ||
343 | 32 | 49 | ||
344 | 50 | BEACON_PORT = 5240 | ||
345 | 51 | |||
346 | 52 | BEACON_TYPES = { | ||
347 | 53 | "solicitation": 1, | ||
348 | 54 | "advertisement": 2 | ||
349 | 55 | } | ||
350 | 56 | |||
351 | 57 | BEACON_TYPE_VALUES = { | ||
352 | 58 | value: name for name, value in BEACON_TYPES.items() | ||
353 | 59 | } | ||
354 | 60 | |||
355 | 61 | PROTOCOL_VERSION = 1 | ||
356 | 62 | BEACON_HEADER_FORMAT_V1 = "!BBH" | ||
357 | 63 | BEACON_HEADER_LENGTH_V1 = 4 | ||
358 | 64 | |||
359 | 65 | |||
360 | 66 | BeaconPayload = namedtuple('BeaconPayload', ( | ||
361 | 67 | 'bytes', | ||
362 | 68 | 'version', | ||
363 | 69 | 'type', | ||
364 | 70 | 'payload', | ||
365 | 71 | )) | ||
366 | 72 | |||
367 | 73 | |||
368 | 74 | def create_beacon_payload(beacon_type, payload=None, version=PROTOCOL_VERSION): | ||
369 | 75 | """Creates a beacon payload of the specified type, with the given data. | ||
370 | 76 | |||
371 | 77 | :param beacon_type: The beacon packet type. Indicates the purpose of the | ||
372 | 78 | beacon to the receiver. | ||
373 | 79 | :param payload: Optional JSON-encodable dictionary. Will be converted to an | ||
374 | 80 | inner encrypted payload and presented in the "data" field in the | ||
375 | 81 | resulting dictionary. | ||
376 | 82 | :param version: Optional protocol version to use (defaults to most recent). | ||
377 | 83 | :return: BeaconPayload namedtuple representing the packet bytes, the outer | ||
378 | 84 | payload, and the inner encrypted data (if any). | ||
379 | 85 | """ | ||
380 | 86 | beacon_type_code = BEACON_TYPES[beacon_type] | ||
381 | 87 | if payload is not None: | ||
382 | 88 | payload = payload.copy() | ||
383 | 89 | payload["uuid"] = str(uuid.uuid1()) | ||
384 | 90 | payload["type"] = beacon_type_code | ||
385 | 91 | data_bytes = BSON.encode(payload) | ||
386 | 92 | compressed_bytes = compress(data_bytes, compresslevel=9) | ||
387 | 93 | payload_bytes = fernet_encrypt_psk(compressed_bytes, raw=True) | ||
388 | 94 | else: | ||
389 | 95 | payload_bytes = b'' | ||
390 | 96 | beacon_bytes = struct.pack( | ||
391 | 97 | BEACON_HEADER_FORMAT_V1 + "%ds" % len(payload_bytes), | ||
392 | 98 | version, beacon_type_code, len(payload_bytes), payload_bytes) | ||
393 | 99 | return BeaconPayload( | ||
394 | 100 | beacon_bytes, version, BEACON_TYPE_VALUES[beacon_type_code], payload) | ||
395 | 101 | |||
396 | 102 | |||
397 | 103 | def read_beacon_payload(beacon_bytes): | ||
398 | 104 | """Returns a BeaconPayload namedtuple representing the given beacon bytes. | ||
399 | 105 | |||
400 | 106 | Decrypts the inner beacon data if necessary. | ||
401 | 107 | |||
402 | 108 | :param beacon_bytes: beacon payload (bytes). | ||
403 | 109 | :return: dict | ||
404 | 110 | """ | ||
405 | 111 | if len(beacon_bytes) < BEACON_HEADER_LENGTH_V1: | ||
406 | 112 | raise InvalidBeaconingPacket( | ||
407 | 113 | "Beaconing packet must be at least %d bytes." % ( | ||
408 | 114 | BEACON_HEADER_LENGTH_V1)) | ||
409 | 115 | header = beacon_bytes[:BEACON_HEADER_LENGTH_V1] | ||
410 | 116 | version, beacon_type_code, expected_payload_length = struct.unpack( | ||
411 | 117 | BEACON_HEADER_FORMAT_V1, header) | ||
412 | 118 | actual_payload_length = len(beacon_bytes) - BEACON_HEADER_LENGTH_V1 | ||
413 | 119 | if len(beacon_bytes) - BEACON_HEADER_LENGTH_V1 < expected_payload_length: | ||
414 | 120 | raise InvalidBeaconingPacket( | ||
415 | 121 | "Invalid payload length: expected %d bytes, got %d bytes." % ( | ||
416 | 122 | expected_payload_length, actual_payload_length)) | ||
417 | 123 | payload_start = BEACON_HEADER_LENGTH_V1 | ||
418 | 124 | payload_end = BEACON_HEADER_LENGTH_V1 + expected_payload_length | ||
419 | 125 | payload_bytes = beacon_bytes[payload_start:payload_end] | ||
420 | 126 | payload = None | ||
421 | 127 | if version == 1: | ||
422 | 128 | if len(payload_bytes) == 0: | ||
423 | 129 | # No encrypted inner payload; nothing to do. | ||
424 | 130 | pass | ||
425 | 131 | else: | ||
426 | 132 | try: | ||
427 | 133 | decrypted_data = fernet_decrypt_psk( | ||
428 | 134 | payload_bytes, raw=True) | ||
429 | 135 | except InvalidToken: | ||
430 | 136 | raise InvalidBeaconingPacket( | ||
431 | 137 | "Failed to decrypt inner payload: check MAAS secret key.") | ||
432 | 138 | try: | ||
433 | 139 | decompressed_data = decompress(decrypted_data) | ||
434 | 140 | except OSError: | ||
435 | 141 | raise InvalidBeaconingPacket( | ||
436 | 142 | "Failed to decompress inner payload: %r" % decrypted_data) | ||
437 | 143 | try: | ||
438 | 144 | # Replace the data in the dictionary with its decrypted form. | ||
439 | 145 | payload = BSON.decode(decompressed_data) | ||
440 | 146 | except BSONError: | ||
441 | 147 | raise InvalidBeaconingPacket( | ||
442 | 148 | "Inner beacon payload is not BSON: %r" % decompressed_data) | ||
443 | 149 | else: | ||
444 | 150 | raise InvalidBeaconingPacket( | ||
445 | 151 | "Unknown beacon version: %d" % version) | ||
446 | 152 | beacon_type_code = payload["type"] if payload else beacon_type_code | ||
447 | 153 | return BeaconPayload( | ||
448 | 154 | beacon_bytes, version, BEACON_TYPE_VALUES[beacon_type_code], payload) | ||
449 | 155 | |||
450 | 156 | |||
451 | 33 | class InvalidBeaconingPacket(Exception): | 157 | class InvalidBeaconingPacket(Exception): |
453 | 34 | """Raised internally when a beaconing packet is not valid.""" | 158 | """Raised when a beaconing packet is not valid.""" |
454 | 159 | |||
455 | 160 | def __init__(self, invalid_reason): | ||
456 | 161 | self.invalid_reason = invalid_reason | ||
457 | 162 | super().__init__(invalid_reason) | ||
458 | 35 | 163 | ||
459 | 36 | 164 | ||
460 | 37 | class BeaconingPacket: | 165 | class BeaconingPacket: |
461 | @@ -57,12 +185,12 @@ | |||
462 | 57 | :param out: An object with `write(str)` and `flush()` methods. | 185 | :param out: An object with `write(str)` and `flush()` methods. |
463 | 58 | """ | 186 | """ |
464 | 59 | try: | 187 | try: |
466 | 60 | payload = bson.decode_all(self.packet) | 188 | payload = read_beacon_payload(self.packet) |
467 | 61 | self.valid = True | 189 | self.valid = True |
468 | 62 | return payload | 190 | return payload |
470 | 63 | except bson.InvalidBSON: | 191 | except InvalidBeaconingPacket as ibp: |
471 | 64 | self.valid = False | 192 | self.valid = False |
473 | 65 | self.invalid_reason = "Packet payload is not BSON." | 193 | self.invalid_reason = ibp.invalid_reason |
474 | 66 | return None | 194 | return None |
475 | 67 | 195 | ||
476 | 68 | 196 | ||
477 | @@ -90,6 +218,8 @@ | |||
478 | 90 | "destination_mac": format_eui(packet.l2.dst_eui), | 218 | "destination_mac": format_eui(packet.l2.dst_eui), |
479 | 91 | "source_ip": str(packet.l3.src_ip), | 219 | "source_ip": str(packet.l3.src_ip), |
480 | 92 | "destination_ip": str(packet.l3.dst_ip), | 220 | "destination_ip": str(packet.l3.dst_ip), |
481 | 221 | "source_port": packet.l4.packet.src_port, | ||
482 | 222 | "destination_port": packet.l4.packet.dst_port, | ||
483 | 93 | } | 223 | } |
484 | 94 | if packet.l2.vid is not None: | 224 | if packet.l2.vid is not None: |
485 | 95 | output_json["vid"] = packet.l2.vid | 225 | output_json["vid"] = packet.l2.vid |
486 | @@ -99,11 +229,6 @@ | |||
487 | 99 | out.write(json.dumps(output_json)) | 229 | out.write(json.dumps(output_json)) |
488 | 100 | out.write('\n') | 230 | out.write('\n') |
489 | 101 | out.flush() | 231 | out.flush() |
490 | 102 | else: | ||
491 | 103 | err.write( | ||
492 | 104 | "Invalid beacon payload (not BSON): %r.\n" % ( | ||
493 | 105 | beacon.packet)) | ||
494 | 106 | err.flush() | ||
495 | 107 | except PacketProcessingError as e: | 232 | except PacketProcessingError as e: |
496 | 108 | err.write(e.error) | 233 | err.write(e.error) |
497 | 109 | err.write("\n") | 234 | err.write("\n") |
498 | 110 | 235 | ||
499 | === modified file 'src/provisioningserver/utils/tests/test_beaconing.py' | |||
500 | --- src/provisioningserver/utils/tests/test_beaconing.py 2017-06-13 16:09:05 +0000 | |||
501 | +++ src/provisioningserver/utils/tests/test_beaconing.py 2017-06-21 15:18:45 +0000 | |||
502 | @@ -6,39 +6,152 @@ | |||
503 | 6 | __all__ = [] | 6 | __all__ = [] |
504 | 7 | 7 | ||
505 | 8 | from argparse import ArgumentParser | 8 | from argparse import ArgumentParser |
506 | 9 | from gzip import compress | ||
507 | 9 | import io | 10 | import io |
508 | 11 | import random | ||
509 | 12 | import struct | ||
510 | 10 | import subprocess | 13 | import subprocess |
511 | 11 | from tempfile import NamedTemporaryFile | 14 | from tempfile import NamedTemporaryFile |
512 | 12 | from unittest.mock import Mock | 15 | from unittest.mock import Mock |
513 | 16 | from uuid import UUID | ||
514 | 13 | 17 | ||
515 | 14 | from bson import BSON | ||
516 | 15 | from maastesting.factory import factory | 18 | from maastesting.factory import factory |
517 | 16 | from maastesting.matchers import MockCalledOnceWith | 19 | from maastesting.matchers import MockCalledOnceWith |
518 | 17 | from maastesting.testcase import MAASTestCase | 20 | from maastesting.testcase import MAASTestCase |
519 | 21 | from provisioningserver.security import ( | ||
520 | 22 | fernet_encrypt_psk, | ||
521 | 23 | MissingSharedSecret, | ||
522 | 24 | ) | ||
523 | 25 | from provisioningserver.tests.test_security import SharedSecretTestCase | ||
524 | 18 | from provisioningserver.utils import beaconing as beaconing_module | 26 | from provisioningserver.utils import beaconing as beaconing_module |
525 | 19 | from provisioningserver.utils.beaconing import ( | 27 | from provisioningserver.utils.beaconing import ( |
526 | 20 | add_arguments, | 28 | add_arguments, |
527 | 29 | BEACON_HEADER_FORMAT_V1, | ||
528 | 30 | BEACON_TYPES, | ||
529 | 21 | BeaconingPacket, | 31 | BeaconingPacket, |
530 | 32 | create_beacon_payload, | ||
531 | 33 | InvalidBeaconingPacket, | ||
532 | 34 | read_beacon_payload, | ||
533 | 22 | run, | 35 | run, |
534 | 23 | ) | 36 | ) |
535 | 24 | from provisioningserver.utils.script import ActionScriptError | 37 | from provisioningserver.utils.script import ActionScriptError |
536 | 38 | from testtools.matchers import ( | ||
537 | 39 | Equals, | ||
538 | 40 | Is, | ||
539 | 41 | IsInstance, | ||
540 | 42 | ) | ||
541 | 25 | from testtools.testcase import ExpectedException | 43 | from testtools.testcase import ExpectedException |
542 | 26 | 44 | ||
543 | 27 | 45 | ||
548 | 28 | def make_beaconing_packet(payload): | 46 | class TestCreateBeaconPayload(SharedSecretTestCase): |
549 | 29 | # Beaconing packets are BSON-encoded byte strings. | 47 | |
550 | 30 | beaconing_packet = BSON.encode(payload) | 48 | def test__requires_maas_shared_secret_for_inner_data_payload(self): |
551 | 31 | return beaconing_packet | 49 | with ExpectedException( |
552 | 50 | MissingSharedSecret, ".*shared secret not found.*"): | ||
553 | 51 | create_beacon_payload("solicitation", payload={}) | ||
554 | 52 | |||
555 | 53 | def test__returns_beaconpayload_namedtuple(self): | ||
556 | 54 | beacon = create_beacon_payload("solicitation") | ||
557 | 55 | self.assertThat(beacon.bytes, IsInstance(bytes)) | ||
558 | 56 | self.assertThat(beacon.payload, Is(None)) | ||
559 | 57 | self.assertThat(beacon.type, Equals("solicitation")) | ||
560 | 58 | self.assertThat(beacon.version, Equals(1)) | ||
561 | 59 | |||
562 | 60 | def test__succeeds_when_shared_secret_present(self): | ||
563 | 61 | self.write_secret() | ||
564 | 62 | beacon = create_beacon_payload( | ||
565 | 63 | "solicitation", payload={}) | ||
566 | 64 | self.assertThat(beacon.type, Equals("solicitation")) | ||
567 | 65 | self.assertThat( | ||
568 | 66 | beacon.payload['type'], Equals(BEACON_TYPES["solicitation"])) | ||
569 | 67 | |||
570 | 68 | def test__supplements_data_and_returns_complete_data(self): | ||
571 | 69 | self.write_secret() | ||
572 | 70 | random_type = random.choice(list(BEACON_TYPES.keys())) | ||
573 | 71 | random_key = factory.make_string(prefix="_") | ||
574 | 72 | random_value = factory.make_string() | ||
575 | 73 | beacon = create_beacon_payload( | ||
576 | 74 | random_type, payload={random_key: random_value}) | ||
577 | 75 | # Ensure a valid UUID was added. | ||
578 | 76 | self.assertIsNotNone(UUID(beacon.payload['uuid'])) | ||
579 | 77 | self.assertThat(beacon.type, Equals(random_type)) | ||
580 | 78 | # The type is replicated here for authentication purposes. | ||
581 | 79 | self.assertThat( | ||
582 | 80 | beacon.payload['type'], Equals(BEACON_TYPES[random_type])) | ||
583 | 81 | self.assertThat(beacon.payload[random_key], Equals(random_value)) | ||
584 | 82 | |||
585 | 83 | def test__creates_packet_that_can_decode(self): | ||
586 | 84 | self.write_secret() | ||
587 | 85 | random_type = random.choice(list(BEACON_TYPES.keys())) | ||
588 | 86 | random_key = factory.make_string(prefix="_") | ||
589 | 87 | random_value = factory.make_string() | ||
590 | 88 | packet_bytes, _, _, _ = create_beacon_payload( | ||
591 | 89 | random_type, payload={random_key: random_value}) | ||
592 | 90 | decrypted = read_beacon_payload(packet_bytes) | ||
593 | 91 | self.assertThat(decrypted.type, Equals(random_type)) | ||
594 | 92 | self.assertThat(decrypted.payload[random_key], Equals(random_value)) | ||
595 | 93 | |||
596 | 94 | |||
597 | 95 | def _make_beacon_payload(version=1, type_code=1, length=None, payload=None): | ||
598 | 96 | if payload is None: | ||
599 | 97 | payload = b'' | ||
600 | 98 | if length is None: | ||
601 | 99 | length = len(payload) | ||
602 | 100 | packet = struct.pack(BEACON_HEADER_FORMAT_V1, version, type_code, length) | ||
603 | 101 | return packet + payload | ||
604 | 102 | |||
605 | 103 | |||
606 | 104 | class TestReadBeaconPayload(SharedSecretTestCase): | ||
607 | 105 | |||
608 | 106 | def test__raises_if_packet_too_small(self): | ||
609 | 107 | with ExpectedException( | ||
610 | 108 | InvalidBeaconingPacket, ".*packet must be at least 4 bytes.*"): | ||
611 | 109 | read_beacon_payload(b"") | ||
612 | 110 | |||
613 | 111 | def test__raises_if_payload_too_small(self): | ||
614 | 112 | packet = _make_beacon_payload(payload=b'1234')[:6] | ||
615 | 113 | with ExpectedException( | ||
616 | 114 | InvalidBeaconingPacket, ".*expected 4 bytes, got 2 bytes.*"): | ||
617 | 115 | read_beacon_payload(packet) | ||
618 | 116 | |||
619 | 117 | def test__raises_when_version_incorrect(self): | ||
620 | 118 | packet = _make_beacon_payload(version=0xfe) | ||
621 | 119 | with ExpectedException( | ||
622 | 120 | InvalidBeaconingPacket, ".*Unknown beacon version.*"): | ||
623 | 121 | read_beacon_payload(packet) | ||
624 | 122 | |||
625 | 123 | def test__raises_when_inner_payload_does_not_decrypt(self): | ||
626 | 124 | self.write_secret() | ||
627 | 125 | packet = _make_beacon_payload(payload=b'\xfe') | ||
628 | 126 | with ExpectedException( | ||
629 | 127 | InvalidBeaconingPacket, ".*Failed to decrypt.*"): | ||
630 | 128 | read_beacon_payload(packet) | ||
631 | 129 | |||
632 | 130 | def test__raises_when_inner_encapsulation_does_not_decompress(self): | ||
633 | 131 | self.write_secret() | ||
634 | 132 | packet = _make_beacon_payload( | ||
635 | 133 | payload=fernet_encrypt_psk('\n\n', raw=True)) | ||
636 | 134 | with ExpectedException( | ||
637 | 135 | InvalidBeaconingPacket, ".*Failed to decompress.*"): | ||
638 | 136 | read_beacon_payload(packet) | ||
639 | 137 | |||
640 | 138 | def test__raises_when_inner_encapsulation_is_not_bson(self): | ||
641 | 139 | self.write_secret() | ||
642 | 140 | payload = fernet_encrypt_psk(compress(b"\n\n"), raw=True) | ||
643 | 141 | packet = _make_beacon_payload(payload=payload) | ||
644 | 142 | with ExpectedException( | ||
645 | 143 | InvalidBeaconingPacket, ".*beacon payload is not BSON.*"): | ||
646 | 144 | read_beacon_payload(packet) | ||
647 | 32 | 145 | ||
648 | 33 | 146 | ||
649 | 34 | class TestBeaconingPacket(MAASTestCase): | 147 | class TestBeaconingPacket(MAASTestCase): |
650 | 35 | 148 | ||
655 | 36 | def test__is_valid__succeeds_for_valid_bson(self): | 149 | def test__is_valid__succeeds_for_valid_payload(self): |
656 | 37 | packet = make_beaconing_packet({"testing": 123}) | 150 | beacon = create_beacon_payload("solicitation") |
657 | 38 | beacon = BeaconingPacket(packet) | 151 | beacon_packet = BeaconingPacket(beacon.bytes) |
658 | 39 | self.assertTrue(beacon.valid) | 152 | self.assertTrue(beacon_packet.valid) |
659 | 40 | 153 | ||
661 | 41 | def test__is_valid__fails_for_invalid_bson(self): | 154 | def test__is_valid__fails_for_invalid_payload(self): |
662 | 42 | beacon = BeaconingPacket(b"\n\n\n\n") | 155 | beacon = BeaconingPacket(b"\n\n\n\n") |
663 | 43 | self.assertFalse(beacon.valid) | 156 | self.assertFalse(beacon.valid) |
664 | 44 | 157 | ||
665 | @@ -48,7 +161,7 @@ | |||
666 | 48 | b'\x00@\x00\x00\x01\x00\x00\x00v\xe19Y\xadF\x08\x00^\x00\x00\x00^\x00\x00' | 161 | b'\x00@\x00\x00\x01\x00\x00\x00v\xe19Y\xadF\x08\x00^\x00\x00\x00^\x00\x00' |
667 | 49 | b'\x00\x01\x00^\x00\x00v\x00\x16>\x91zz\x08\x00E\x00\x00P\xe2E@\x00\x01' | 162 | b'\x00\x01\x00^\x00\x00v\x00\x16>\x91zz\x08\x00E\x00\x00P\xe2E@\x00\x01' |
668 | 50 | b'\x11\xe0\xce\xac\x10*\x02\xe0\x00\x00v\xda\xc2\x14x\x00<h(4\x00\x00\x00' | 163 | b'\x11\xe0\xce\xac\x10*\x02\xe0\x00\x00v\xda\xc2\x14x\x00<h(4\x00\x00\x00' |
670 | 51 | b'\x02uuid\x00%\x00\x00\x0078d1a4f0-4ca4-11e7-b2bb-00163e917a7a\x00\x00') | 164 | b'\x02uuid\x00%\x00\x00\x0000000000-0000-0000-0000-000000000000\x00\x00') |
671 | 52 | 165 | ||
672 | 53 | 166 | ||
673 | 54 | class TestObserveBeaconsCommand(MAASTestCase): | 167 | class TestObserveBeaconsCommand(MAASTestCase): |
I do have a few questions inline. Looks generally good, though.