Merge ~blake-rouse/maas:faster-tftp into maas:master

Proposed by Blake Rouse
Status: Rejected
Rejected by: Adam Collard
Proposed branch: ~blake-rouse/maas:faster-tftp
Merge into: maas:master
Diff against target: 495 lines (+251/-113)
4 files modified
required-packages/dev (+1/-0)
src/provisioningserver/monkey.py (+0/-94)
src/provisioningserver/rackdservices/tests/test_tftp.py (+58/-4)
src/provisioningserver/rackdservices/tftp.py (+192/-15)
Reviewer Review Type Date Requested Status
MAAS Maintainers Pending
Review via email: mp+364042@code.launchpad.net

Commit message

Fix time tracking with TFTP requests. No longer use monkey patching for the TFTP service, as it's not required. Use sendfile to send data instead of reading the data into userspace.

To post a comment you must log in.
~blake-rouse/maas:faster-tftp updated
be76a50... by Blake Rouse

Try using a threadpool to send the binary data.

f0bd30e... by Blake Rouse

Use a new eventloop on its own thread to handle read requests.

Unmerged commits

f0bd30e... by Blake Rouse

Use a new eventloop on its own thread to handle read requests.

be76a50... by Blake Rouse

Try using a threadpool to send the binary data.

af10523... by Blake Rouse

Fix time tracking with TFTP requests. Use sendfile to send data instead of reading the data into userspace.

229af11... by Blake Rouse

Try to read tftp request from the filesystem before processing the request with the boot methods.

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1diff --git a/required-packages/dev b/required-packages/dev
2index 0fc3b6d..d799095 100644
3--- a/required-packages/dev
4+++ b/required-packages/dev
5@@ -36,4 +36,5 @@ python-tempita
6 python-twisted
7 python-yaml
8 socat
9+tftp-hpa
10 xvfb
11diff --git a/src/provisioningserver/monkey.py b/src/provisioningserver/monkey.py
12index f69da32..0f23b75 100644
13--- a/src/provisioningserver/monkey.py
14+++ b/src/provisioningserver/monkey.py
15@@ -24,99 +24,6 @@ def add_term_error_code_to_tftp():
16 "Terminate transfer due to option negotiation")
17
18
19-def fix_tftp_requests():
20- """Use intelligence in determining IPv4 vs IPv6 when creatinging a session.
21-
22- Specifically, look at addr[0] and pass iface to listenUDP based on that.
23-
24- See https://bugs.launchpad.net/ubuntu/+source/python-tx-tftp/1614581
25- """
26- import tftp.protocol
27-
28- from tftp.datagram import (
29- OP_WRQ,
30- ERRORDatagram,
31- ERR_NOT_DEFINED,
32- ERR_ACCESS_VIOLATION,
33- ERR_FILE_EXISTS,
34- ERR_ILLEGAL_OP,
35- OP_RRQ,
36- ERR_FILE_NOT_FOUND
37- )
38- from tftp.bootstrap import (
39- RemoteOriginWriteSession,
40- RemoteOriginReadSession,
41- )
42- from tftp.netascii import NetasciiReceiverProxy, NetasciiSenderProxy
43- from twisted.internet import reactor
44- from twisted.internet.defer import inlineCallbacks, returnValue
45- from twisted.python.context import call
46- from tftp.errors import (
47- FileExists,
48- Unsupported,
49- AccessViolation,
50- BackendError,
51- FileNotFound,
52- )
53- from netaddr import IPAddress
54-
55- @inlineCallbacks
56- def new_startSession(self, datagram, addr, mode):
57- # Set up a call context so that we can pass extra arbitrary
58- # information to interested backends without adding extra call
59- # arguments, or switching to using a request object, for example.
60- context = {}
61- if self.transport is not None:
62- # Add the local and remote addresses to the call context.
63- local = self.transport.getHost()
64- context["local"] = local.host, local.port
65- context["remote"] = addr
66- try:
67- if datagram.opcode == OP_WRQ:
68- fs_interface = yield call(
69- context, self.backend.get_writer, datagram.filename)
70- elif datagram.opcode == OP_RRQ:
71- fs_interface = yield call(
72- context, self.backend.get_reader, datagram.filename)
73- except Unsupported as e:
74- self.transport.write(ERRORDatagram.from_code(
75- ERR_ILLEGAL_OP,
76- u"{}".format(e).encode("ascii", "replace")).to_wire(), addr)
77- except AccessViolation:
78- self.transport.write(
79- ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
80- except FileExists:
81- self.transport.write(
82- ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
83- except FileNotFound:
84- self.transport.write(
85- ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
86- except BackendError as e:
87- self.transport.write(ERRORDatagram.from_code(
88- ERR_NOT_DEFINED,
89- u"{}".format(e).encode("ascii", "replace")).to_wire(), addr)
90- else:
91- if IPAddress(addr[0]).version == 6:
92- iface = '::'
93- else:
94- iface = ''
95- if datagram.opcode == OP_WRQ:
96- if mode == b'netascii':
97- fs_interface = NetasciiReceiverProxy(fs_interface)
98- session = RemoteOriginWriteSession(
99- addr, fs_interface, datagram.options, _clock=self._clock)
100- reactor.listenUDP(0, session, iface)
101- returnValue(session)
102- elif datagram.opcode == OP_RRQ:
103- if mode == b'netascii':
104- fs_interface = NetasciiSenderProxy(fs_interface)
105- session = RemoteOriginReadSession(
106- addr, fs_interface, datagram.options, _clock=self._clock)
107- reactor.listenUDP(0, session, iface)
108- returnValue(session)
109- tftp.protocol.TFTP._startSession = new_startSession
110-
111-
112 def get_patched_URI():
113 """Create the patched `twisted.web.client.URI` to handle IPv6."""
114 import re
115@@ -343,7 +250,6 @@ def augment_twisted_deferToThreadPool():
116
117 def add_patches_to_txtftp():
118 add_term_error_code_to_tftp()
119- fix_tftp_requests()
120
121
122 def add_patches_to_twisted():
123diff --git a/src/provisioningserver/rackdservices/tests/test_tftp.py b/src/provisioningserver/rackdservices/tests/test_tftp.py
124index 64a89e7..26d21af 100644
125--- a/src/provisioningserver/rackdservices/tests/test_tftp.py
126+++ b/src/provisioningserver/rackdservices/tests/test_tftp.py
127@@ -50,8 +50,8 @@ from provisioningserver.rackdservices.tftp import (
128 Port,
129 TFTPBackend,
130 TFTPService,
131+ TransferTimeTrackingIPv6TFTP,
132 TransferTimeTrackingSession,
133- TransferTimeTrackingTFTP,
134 UDPServer,
135 )
136 from provisioningserver.rpc.exceptions import BootConfigNoResponse
137@@ -93,6 +93,7 @@ from twisted.internet.defer import (
138 )
139 from twisted.internet.protocol import Protocol
140 from twisted.internet.task import Clock
141+from twisted.internet.utils import getProcessOutputAndValue
142 from twisted.python import context
143 from zope.interface.verify import verifyObject
144
145@@ -973,6 +974,59 @@ class TestTFTPService(MAASTestCase):
146 })
147
148
149+class TestTFTPService_Sendfile(MAASTestCase):
150+
151+ run_tests_with = MAASTwistedRunTest.make_factory(timeout=30)
152+
153+ scenarios = (
154+ ('div_block_size', {
155+ 'data': factory.make_string(size=1024 * 1024).encode('utf-8')
156+ }),
157+ ('non_div_block_size', {
158+ 'data': factory.make_string(size=(1024 * 1024) + 1).encode('utf-8')
159+ }),
160+ )
161+
162+ def test_tftp_service_with_sendfile(self):
163+ # Make the fake file to read.
164+ example_root = self.make_dir()
165+ filename = factory.make_name('file')
166+ factory.make_file(example_root, filename, self.data)
167+
168+ # Create the TFTP service.
169+ example_client_service = Mock()
170+ example_port = factory.pick_port()
171+ tftp_service = TFTPService(
172+ resource_root=example_root, client_service=example_client_service,
173+ port=example_port)
174+ tftp_service.updateServers()
175+ tftp_service.startService()
176+
177+ # Run subprocess to make request for file over TFTP.
178+ output_dir = self.make_dir()
179+ cmd = [
180+ 'tftp', '-m', 'binary',
181+ 'localhost', str(example_port),
182+ '-c', 'get', filename,
183+ ]
184+
185+ def tftp_cb(result):
186+ tftp_service.stopService()
187+
188+ out, err, code = result
189+ self.assertEquals(0, code, "\nstdout: %s\nstrerr: %s" % (out, err))
190+
191+ with open(os.path.join(output_dir, filename), 'rb') as fp:
192+ content = fp.read()
193+ self.assertEqual(self.data, content)
194+
195+ self.assertFalse(os.path.exists(os.path.join(output_dir, filename)))
196+ d = getProcessOutputAndValue(
197+ cmd[0], cmd[1:], path=output_dir, reactor=reactor)
198+ d.addCallbacks(tftp_cb, tftp_cb)
199+ return d
200+
201+
202 class TestTransferTimeTrackingSession(MAASTestCase):
203
204 def test_track_time(self):
205@@ -991,12 +1045,12 @@ class TestTransferTimeTrackingSession(MAASTestCase):
206 metrics)
207
208
209-class TestTransferTimeTrackingTFTP(MAASTestCase):
210- """Tests for `TransferTimeTrackingTFTP`."""
211+class TestTransferTimeTrackingIPv6TFTP(MAASTestCase):
212+ """Tests for `TransferTimeTrackingIPv6TFTP`."""
213
214 def clean_filename(self, path):
215 datagram = RQDatagram(path, b'octet', {})
216- tftp = TransferTimeTrackingTFTP(sentinel.backend)
217+ tftp = TransferTimeTrackingIPv6TFTP(sentinel.backend)
218 return tftp._clean_filename(datagram)
219
220 def test_clean_filename(self):
221diff --git a/src/provisioningserver/rackdservices/tftp.py b/src/provisioningserver/rackdservices/tftp.py
222index 6254283..d86289c 100644
223--- a/src/provisioningserver/rackdservices/tftp.py
224+++ b/src/provisioningserver/rackdservices/tftp.py
225@@ -8,11 +8,16 @@ __all__ = [
226 "TFTPService",
227 ]
228
229+import asyncio
230+import threading
231 from functools import partial
232+import os
233+import socket
234 from socket import (
235 AF_INET,
236 AF_INET6,
237 )
238+import struct
239 from time import time
240
241 from netaddr import IPAddress
242@@ -46,12 +51,36 @@ from provisioningserver.utils.twisted import (
243 RPCFetcher,
244 )
245 from tftp.backend import FilesystemSynchronousBackend
246+from tftp.bootstrap import (
247+ RemoteOriginReadSession,
248+ RemoteOriginWriteSession,
249+ TFTPBootstrap,
250+)
251+from tftp.datagram import (
252+ ERR_ACCESS_VIOLATION,
253+ ERR_FILE_EXISTS,
254+ ERR_FILE_NOT_FOUND,
255+ ERR_ILLEGAL_OP,
256+ ERR_NOT_DEFINED,
257+ ERRORDatagram,
258+ OP_DATA,
259+ OP_RRQ,
260+ OP_WRQ,
261+)
262 from tftp.errors import (
263+ AccessViolation,
264 BackendError,
265+ FileExists,
266 FileNotFound,
267+ Unsupported,
268+)
269+from tftp.netascii import (
270+ NetasciiReceiverProxy,
271+ NetasciiSenderProxy,
272 )
273 from tftp.protocol import TFTP
274 from tftp.session import ReadSession
275+from tftp.util import timedCaller
276 from twisted.application import internet
277 from twisted.application.service import MultiService
278 from twisted.internet import (
279@@ -69,7 +98,9 @@ from twisted.internet.defer import (
280 returnValue,
281 succeed,
282 )
283+from twisted.internet.asyncioreactor import AsyncioSelectorReactor
284 from twisted.internet.task import deferLater
285+from twisted.python.context import call
286 from twisted.python.filepath import FilePath
287
288
289@@ -372,10 +403,20 @@ class TFTPBackend(FilesystemSynchronousBackend):
290 # unix compatiable.
291 file_name = file_name.replace(b'\\', b'/')
292 log_request(file_name)
293- d = self.get_boot_method(file_name)
294- d.addCallback(partial(self.handle_boot_method, file_name))
295- d.addErrback(self.no_response_errback, file_name)
296- d.addErrback(self.all_is_lost_errback)
297+
298+ def boot_methods(_):
299+ # Failure most likely because file doesn't exist. So send the
300+ # request to the boot methods.
301+ d = self.get_boot_method(file_name)
302+ d.addCallback(partial(self.handle_boot_method, file_name))
303+ d.addErrback(self.no_response_errback, file_name)
304+ d.addErrback(self.all_is_lost_errback)
305+ return d
306+
307+ # Try to load the file directly from the filesystem first, failure will
308+ # pass the request onto the boot methods for processing.
309+ d = super(TFTPBackend, self).get_reader(file_name)
310+ d.addErrback(boot_methods)
311 return d
312
313
314@@ -412,7 +453,66 @@ class UDPServer(internet.UDPServer):
315 return p
316
317
318-class TransferTimeTrackingSession(ReadSession):
319+class ReadSendfileSession(ReadSession):
320+ """A `ReadSession` that can use `os.sendfile`."""
321+
322+ def __init__(self, reader, _clock=None):
323+ super().__init__(reader, _clock)
324+
325+ def startProtocol(self):
326+ self.blocknum_noroll = 0
327+ self.use_sendfile = getattr(self.reader, 'file_obj', None) is not None
328+ if self.use_sendfile:
329+ self.size = self.reader.size
330+ super().startProtocol()
331+
332+ def nextBlock(self):
333+ """ACK datagram for the previous block has been received. Attempt to
334+ read the next block, that will be sent."""
335+ self.blocknum += 1
336+ self.blocknum_noroll += 1
337+ if self.use_sendfile:
338+ self.timeout_watchdog = timedCaller(
339+ (0,) + self.timeout, self.sendWithSendfile,
340+ self.timedOut, clock=self._clock)
341+ return succeed(None)
342+ else:
343+ d = maybeDeferred(self.reader.read, self.block_size)
344+ d.addCallbacks(
345+ callback=self.dataFromReader, errback=self.readFailed)
346+ return d
347+
348+ def sendWithSendfile(self):
349+ d = maybeDeferred(self._sendWithSendfile)
350+ d.addErrback(self.readFailed)
351+ return d
352+
353+ def _sendWithSendfile(self):
354+ # Reached maximum number of blocks. Rolling over
355+ if self.blocknum == 65536:
356+ self.blocknum = 0
357+
358+ offset = (self.blocknum_noroll - 1) * self.block_size
359+ count = self.size - offset
360+ if count > self.block_size:
361+ count = self.block_size
362+ if count == 0:
363+ self.transport.socket.send(
364+ struct.pack(b'!HH', OP_DATA, self.blocknum))
365+ self.completed = True
366+ else:
367+ self.transport.socket.send(
368+ struct.pack(b'!HH', OP_DATA, self.blocknum),
369+ socket.MSG_MORE)
370+ os.sendfile(
371+ self.transport.socket.fileno(), self.reader.file_obj.fileno(),
372+ offset, count)
373+ if count < self.block_size:
374+ self.completed = True
375+
376+
377+class TransferTimeTrackingSession(ReadSendfileSession):
378+ """A `ReadSendfileSession` that tracks the latency of sending."""
379
380 def __init__(
381 self, filename, reader, _clock=None,
382@@ -435,18 +535,95 @@ class TransferTimeTrackingSession(ReadSession):
383 super().cancel()
384
385
386-class TransferTimeTrackingTFTP(TFTP):
387+class RemoteOriginReadTimeTrackingSession(RemoteOriginReadSession):
388+ """A `RemoteOriginReadSession` that uses a
389+ `TransferTimeTrackingSession`."""
390+
391+ def __init__(
392+ self, filename, remote, reader, options=None, _clock=None):
393+ TFTPBootstrap.__init__(self, remote, reader, options, _clock)
394+ self.session = TransferTimeTrackingSession(
395+ filename, reader, self._clock)
396+
397+ def stopProtocol(self):
398+ super().stopProtocol()
399+ self._clock.stop()
400+
401+
402+def asyncio_worker(session, iface):
403+ loop = asyncio.new_event_loop()
404+ read_reactor = AsyncioSelectorReactor(eventloop=loop)
405+ session._clock = read_reactor
406+ session.session._clock = read_reactor
407+ read_reactor.listenUDP(0, session, iface)
408+ read_reactor.run(installSignalHandlers=False)
409+
410+
411+class TransferTimeTrackingIPv6TFTP(TFTP):
412+ """A `TFTP` that tracks the TFTP transfer times and supports IPv6."""
413+
414+ def __init__(self, backend, _clock=None):
415+ super().__init__(backend, _clock)
416
417 @inlineCallbacks
418 def _startSession(self, datagram, addr, mode):
419- session = yield super()._startSession(datagram, addr, mode)
420- stream_session = getattr(session, 'session', None)
421- # replace the standard ReadSession with one that tracks transfer time
422- if stream_session is not None:
423- filename = self._clean_filename(datagram)
424- session.session = TransferTimeTrackingSession(
425- filename, stream_session.reader, _clock=stream_session._clock)
426- returnValue(session)
427+ # Set up a call context so that we can pass extra arbitrary
428+ # information to interested backends without adding extra call
429+ # arguments, or switching to using a request object, for example.
430+ context = {}
431+ if self.transport is not None:
432+ # Add the local and remote addresses to the call context.
433+ local = self.transport.getHost()
434+ context["local"] = local.host, local.port
435+ context["remote"] = addr
436+ try:
437+ if datagram.opcode == OP_WRQ:
438+ fs_interface = yield call(
439+ context, self.backend.get_writer, datagram.filename)
440+ elif datagram.opcode == OP_RRQ:
441+ fs_interface = yield call(
442+ context, self.backend.get_reader, datagram.filename)
443+ except Unsupported as e:
444+ self.transport.write(ERRORDatagram.from_code(
445+ ERR_ILLEGAL_OP,
446+ u"{}".format(e).encode("ascii", "replace")).to_wire(), addr)
447+ except AccessViolation:
448+ self.transport.write(
449+ ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
450+ except FileExists:
451+ self.transport.write(
452+ ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
453+ except FileNotFound:
454+ self.transport.write(
455+ ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
456+ except BackendError as e:
457+ self.transport.write(ERRORDatagram.from_code(
458+ ERR_NOT_DEFINED,
459+ u"{}".format(e).encode("ascii", "replace")).to_wire(), addr)
460+ else:
461+ if IPAddress(addr[0]).version == 6:
462+ iface = '::'
463+ else:
464+ iface = ''
465+ if datagram.opcode == OP_WRQ:
466+ if mode == b'netascii':
467+ fs_interface = NetasciiReceiverProxy(fs_interface)
468+ session = RemoteOriginWriteSession(
469+ addr, fs_interface, datagram.options, _clock=self._clock)
470+ reactor.listenUDP(0, session, iface)
471+ returnValue(session)
472+ elif datagram.opcode == OP_RRQ:
473+ if mode == b'netascii':
474+ fs_interface = NetasciiSenderProxy(fs_interface)
475+ # Spawn a new reactor in its own thread to handle the read
476+ # request. The session will stop the reactor once complete.
477+ session = RemoteOriginReadTimeTrackingSession(
478+ self._clean_filename(datagram), addr, fs_interface,
479+ datagram.options)
480+ t = threading.Thread(
481+ target=asyncio_worker, args=(session, iface))
482+ t.start()
483+ returnValue(session)
484
485 def _clean_filename(self, datagram):
486 filename = datagram.filename.decode('ascii')
487@@ -522,7 +699,7 @@ class TFTPService(MultiService, object):
488 for address in addrs_desired - addrs_established:
489 if not IPAddress(address).is_link_local():
490 tftp_service = UDPServer(
491- self.port, TransferTimeTrackingTFTP(self.backend),
492+ self.port, TransferTimeTrackingIPv6TFTP(self.backend),
493 interface=address)
494 tftp_service.setName(address)
495 tftp_service.setServiceParent(self)

Subscribers

People subscribed via source and target branches