Merge lp:~djfroofy/txaws/modernize-924459 into lp:txaws
- modernize-924459
- Merge into trunk
Proposed by
Drew Smathers
Status: | Merged |
---|---|
Approved by: | Duncan McGreggor |
Approved revision: | 146 |
Merged at revision: | 145 |
Proposed branch: | lp:~djfroofy/txaws/modernize-924459 |
Merge into: | lp:txaws |
Diff against target: |
670 lines (+459/-28) 5 files modified
txaws/client/_producers.py (+122/-0) txaws/client/base.py (+178/-13) txaws/client/tests/test_base.py (+116/-11) txaws/client/tests/test_ssl.py (+20/-4) txaws/testing/producers.py (+23/-0) |
To merge this branch: | bzr merge lp:~djfroofy/txaws/modernize-924459 |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Duncan McGreggor | Approve | ||
Review via email: mp+92404@code.launchpad.net |
Commit message
Description of the change
Need to consider how much we care about compatibility, since: (1) this requires Twisted >= 11.1.0, as far as I know; and (2) some public members of BaseQuery no longer exist, because they are not applicable when using Agent.
To post a comment you must log in.
- 145. By Drew Smathers
-
cherry pick: change StringIOReceiver to generic StreamingBodyReceiver
- 146. By Drew Smathers
-
check for response code in _handle_response and errback if >= 400 with response body
Revision history for this message
Duncan McGreggor (oubiwann) wrote : | # |
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === added file 'txaws/client/_producers.py' |
2 | --- txaws/client/_producers.py 1970-01-01 00:00:00 +0000 |
3 | +++ txaws/client/_producers.py 2012-03-15 05:30:25 +0000 |
4 | @@ -0,0 +1,122 @@ |
5 | +import os |
6 | + |
7 | +from zope.interface import implements |
8 | + |
9 | +from twisted.internet import defer, task |
10 | +from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer |
11 | + |
12 | + |
13 | +# Code below for FileBodyProducer cut-and-paste from twisted source. |
14 | +# Currently this is not released so here temporarily for forward compat. |
15 | + |
16 | + |
17 | +class FileBodyProducer(object): |
18 | + """ |
19 | + L{FileBodyProducer} produces bytes from an input file object incrementally |
20 | + and writes them to a consumer. |
21 | + |
22 | + Since file-like objects cannot be read from in an event-driven manner, |
23 | + L{FileBodyProducer} uses a L{Cooperator} instance to schedule reads from |
24 | + the file. This process is also paused and resumed based on notifications |
25 | + from the L{IConsumer} provider being written to. |
26 | + |
27 | + The file is closed after it has been read, or if the producer is stopped |
28 | + early. |
29 | + |
30 | + @ivar _inputFile: Any file-like object, bytes read from which will be |
31 | + written to a consumer. |
32 | + |
33 | + @ivar _cooperate: A method like L{Cooperator.cooperate} which is used to |
34 | + schedule all reads. |
35 | + |
36 | + @ivar _readSize: The number of bytes to read from C{_inputFile} at a time. |
37 | + """ |
38 | + implements(IBodyProducer) |
39 | + |
40 | + # Python 2.4 doesn't have these symbolic constants |
41 | + _SEEK_SET = getattr(os, 'SEEK_SET', 0) |
42 | + _SEEK_END = getattr(os, 'SEEK_END', 2) |
43 | + |
44 | + def __init__(self, inputFile, cooperator=task, readSize=2 ** 16): |
45 | + self._inputFile = inputFile |
46 | + self._cooperate = cooperator.cooperate |
47 | + self._readSize = readSize |
48 | + self.length = self._determineLength(inputFile) |
49 | + |
50 | + |
51 | + def _determineLength(self, fObj): |
52 | + """ |
53 | + Determine how many bytes can be read out of C{fObj} (assuming it is not |
54 | + modified from this point on). If the determination cannot be made, |
55 | + return C{UNKNOWN_LENGTH}. |
56 | + """ |
57 | + try: |
58 | + seek = fObj.seek |
59 | + tell = fObj.tell |
60 | + except AttributeError: |
61 | + return UNKNOWN_LENGTH |
62 | + originalPosition = tell() |
63 | + seek(0, self._SEEK_END) |
64 | + end = tell() |
65 | + seek(originalPosition, self._SEEK_SET) |
66 | + return end - originalPosition |
67 | + |
68 | + |
69 | + def stopProducing(self): |
70 | + """ |
71 | + Permanently stop writing bytes from the file to the consumer by |
72 | + stopping the underlying L{CooperativeTask}. |
73 | + """ |
74 | + self._inputFile.close() |
75 | + self._task.stop() |
76 | + |
77 | + |
78 | + def startProducing(self, consumer): |
79 | + """ |
80 | + Start a cooperative task which will read bytes from the input file and |
81 | + write them to C{consumer}. Return a L{Deferred} which fires after all |
82 | + bytes have been written. |
83 | + |
84 | + @param consumer: Any L{IConsumer} provider |
85 | + """ |
86 | + self._task = self._cooperate(self._writeloop(consumer)) |
87 | + d = self._task.whenDone() |
88 | + def maybeStopped(reason): |
89 | + # IBodyProducer.startProducing's Deferred isn't support to fire if |
90 | + # stopProducing is called. |
91 | + reason.trap(task.TaskStopped) |
92 | + return defer.Deferred() |
93 | + d.addCallbacks(lambda ignored: None, maybeStopped) |
94 | + return d |
95 | + |
96 | + |
97 | + def _writeloop(self, consumer): |
98 | + """ |
99 | + Return an iterator which reads one chunk of bytes from the input file |
100 | + and writes them to the consumer for each time it is iterated. |
101 | + """ |
102 | + while True: |
103 | + bytes = self._inputFile.read(self._readSize) |
104 | + if not bytes: |
105 | + self._inputFile.close() |
106 | + break |
107 | + consumer.write(bytes) |
108 | + yield None |
109 | + |
110 | + |
111 | + def pauseProducing(self): |
112 | + """ |
113 | + Temporarily suspend copying bytes from the input file to the consumer |
114 | + by pausing the L{CooperativeTask} which drives that activity. |
115 | + """ |
116 | + self._task.pause() |
117 | + |
118 | + |
119 | + def resumeProducing(self): |
120 | + """ |
121 | + Undo the effects of a previous C{pauseProducing} and resume copying |
122 | + bytes to the consumer by resuming the L{CooperativeTask} which drives |
123 | + the write activity. |
124 | + """ |
125 | + self._task.resume() |
126 | + |
127 | |
128 | === modified file 'txaws/client/base.py' |
129 | --- txaws/client/base.py 2011-11-29 08:17:54 +0000 |
130 | +++ txaws/client/base.py 2012-03-15 05:30:25 +0000 |
131 | @@ -3,10 +3,25 @@ |
132 | except ImportError: |
133 | from xml.parsers.expat import ExpatError as ParseError |
134 | |
135 | +import warnings |
136 | +from StringIO import StringIO |
137 | + |
138 | from twisted.internet.ssl import ClientContextFactory |
139 | +from twisted.internet.protocol import Protocol |
140 | +from twisted.internet.defer import Deferred, succeed, fail |
141 | +from twisted.python import failure |
142 | from twisted.web import http |
143 | +from twisted.web.iweb import UNKNOWN_LENGTH |
144 | from twisted.web.client import HTTPClientFactory |
145 | +from twisted.web.client import Agent |
146 | +from twisted.web.client import ResponseDone |
147 | +from twisted.web.http import NO_CONTENT |
148 | +from twisted.web.http_headers import Headers |
149 | from twisted.web.error import Error as TwistedWebError |
150 | +try: |
151 | + from twisted.web.client import FileBodyProducer |
152 | +except ImportError: |
153 | + from txaws.client._producers import FileBodyProducer |
154 | |
155 | from txaws.util import parse |
156 | from txaws.credentials import AWSCredentials |
157 | @@ -71,20 +86,122 @@ |
158 | self.query_factory = query_factory |
159 | self.parser = parser |
160 | |
161 | +class StreamingError(Exception): |
162 | + """ |
163 | + Raised if more data or less data is received than expected. |
164 | + """ |
165 | + |
166 | + |
167 | +class StreamingBodyReceiver(Protocol): |
168 | + """ |
169 | + Streaming HTTP response body receiver. |
170 | + |
171 | + TODO: perhaps there should be an interface specifying why |
172 | + finished (Deferred) and content_length are necessary and |
173 | + how to used them; eg. callback/errback finished on completion. |
174 | + """ |
175 | + finished = None |
176 | + content_length = None |
177 | + |
178 | + def __init__(self, fd=None, readback=True): |
179 | + """ |
180 | + @param fd: a file descriptor to write to |
181 | + @param readback: if True read back data from fd to callback finished |
182 | + with, otherwise we call back finish with fd itself |
183 | + with |
184 | + """ |
185 | + if fd is None: |
186 | + fd = StringIO() |
187 | + self._fd = fd |
188 | + self._received = 0 |
189 | + self._readback = readback |
190 | + |
191 | + def dataReceived(self, bytes): |
192 | + streaming = self.content_length is UNKNOWN_LENGTH |
193 | + if not streaming and (self._received > self.content_length): |
194 | + self.transport.loseConnection() |
195 | + raise StreamingError( |
196 | + "Buffer overflow - received more data than " |
197 | + "Content-Length dictated: %d" % self.content_length) |
198 | + # TODO should be some limit on how much we receive |
199 | + self._fd.write(bytes) |
200 | + self._received += len(bytes) |
201 | + |
202 | + def connectionLost(self, reason): |
203 | + reason.trap(ResponseDone) |
204 | + d = self.finished |
205 | + self.finished = None |
206 | + streaming = self.content_length is UNKNOWN_LENGTH |
207 | + if streaming or (self._received == self.content_length): |
208 | + if self._readback: |
209 | + self._fd.seek(0) |
210 | + data = self._fd.read() |
211 | + self._fd.close() |
212 | + self._fd = None |
213 | + d.callback(data) |
214 | + else: |
215 | + d.callback(self._fd) |
216 | + else: |
217 | + f = failure.Failure(StreamingError("Connection lost before " |
218 | + "receiving all data")) |
219 | + d.errback(f) |
220 | + |
221 | + |
222 | +class WebClientContextFactory(ClientContextFactory): |
223 | + |
224 | + def getContext(self, hostname, port): |
225 | + return ClientContextFactory.getContext(self) |
226 | + |
227 | + |
228 | +class WebVerifyingContextFactory(VerifyingContextFactory): |
229 | + |
230 | + def getContext(self, hostname, port): |
231 | + return VerifyingContextFactory.getContext(self) |
232 | + |
233 | + |
234 | +class FakeClient(object): |
235 | + """ |
236 | + XXX |
237 | + A fake client object for some degree of backwards compatability for |
238 | + code using the client attibute on BaseQuery to check url, status |
239 | + etc. |
240 | + """ |
241 | + url = None |
242 | + status = None |
243 | |
244 | class BaseQuery(object): |
245 | |
246 | - def __init__(self, action=None, creds=None, endpoint=None, reactor=None): |
247 | + def __init__(self, action=None, creds=None, endpoint=None, reactor=None, |
248 | + body_producer=None, receiver_factory=None): |
249 | if not action: |
250 | raise TypeError("The query requires an action parameter.") |
251 | - self.factory = HTTPClientFactory |
252 | self.action = action |
253 | self.creds = creds |
254 | self.endpoint = endpoint |
255 | if reactor is None: |
256 | from twisted.internet import reactor |
257 | self.reactor = reactor |
258 | - self.client = None |
259 | + self._client = None |
260 | + self.request_headers = None |
261 | + self.response_headers = None |
262 | + self.body_producer = body_producer |
263 | + self.receiver_factory = receiver_factory or StreamingBodyReceiver |
264 | + |
265 | + @property |
266 | + def client(self): |
267 | + if self._client is None: |
268 | + self._client_deprecation_warning() |
269 | + self._client = FakeClient() |
270 | + return self._client |
271 | + |
272 | + @client.setter |
273 | + def client(self, value): |
274 | + self._client_deprecation_warning() |
275 | + self._client = value |
276 | + |
277 | + def _client_deprecation_warning(self): |
278 | + warnings.warn('The client attribute on BaseQuery is deprecated and' |
279 | + ' will go away in future release.') |
280 | |
281 | def get_page(self, url, *args, **kwds): |
282 | """ |
283 | @@ -95,16 +212,39 @@ |
284 | """ |
285 | contextFactory = None |
286 | scheme, host, port, path = parse(url) |
287 | - self.client = self.factory(url, *args, **kwds) |
288 | + data = kwds.get('postdata', None) |
289 | + self._method = method = kwds.get('method', 'GET') |
290 | + self.request_headers = self._headers(kwds.get('headers', {})) |
291 | + if (self.body_producer is None) and (data is not None): |
292 | + self.body_producer = FileBodyProducer(StringIO(data)) |
293 | if scheme == "https": |
294 | if self.endpoint.ssl_hostname_verification: |
295 | - contextFactory = VerifyingContextFactory(host) |
296 | + contextFactory = WebVerifyingContextFactory(host) |
297 | else: |
298 | - contextFactory = ClientContextFactory() |
299 | - self.reactor.connectSSL(host, port, self.client, contextFactory) |
300 | + contextFactory = WebClientContextFactory() |
301 | + agent = Agent(self.reactor, contextFactory) |
302 | + self.client.url = url |
303 | + d = agent.request(method, url, self.request_headers, |
304 | + self.body_producer) |
305 | else: |
306 | - self.reactor.connectTCP(host, port, self.client) |
307 | - return self.client.deferred |
308 | + agent = Agent(self.reactor) |
309 | + d = agent.request(method, url, self.request_headers, |
310 | + self.body_producer) |
311 | + d.addCallback(self._handle_response) |
312 | + return d |
313 | + |
314 | + def _headers(self, headers_dict): |
315 | + """ |
316 | + Convert dictionary of headers into twisted.web.client.Headers object. |
317 | + """ |
318 | + return Headers(dict((k,[v]) for (k,v) in headers_dict.items())) |
319 | + |
320 | + def _unpack_headers(self, headers): |
321 | + """ |
322 | + Unpack twisted.web.client.Headers object to dict. This is to provide |
323 | + backwards compatability. |
324 | + """ |
325 | + return dict((k,v[0]) for (k,v) in headers.getAllRawHeaders()) |
326 | |
327 | def get_request_headers(self, *args, **kwds): |
328 | """ |
329 | @@ -114,8 +254,32 @@ |
330 | The AWS S3 API depends upon setting headers. This method is provided as |
331 | a convenience for debugging issues with the S3 communications. |
332 | """ |
333 | - if self.client: |
334 | - return self.client.headers |
335 | + if self.request_headers: |
336 | + return self._unpack_headers(self.request_headers) |
337 | + |
338 | + def _handle_response(self, response): |
339 | + """ |
340 | + Handle the HTTP response by memoing the headers and then delivering |
341 | + bytes. |
342 | + """ |
343 | + self.client.status = response.code |
344 | + self.response_headers = headers = response.headers |
345 | + # XXX This workaround (which needs to be improved at that) for possible |
346 | + # bug in Twisted with new client: |
347 | + # http://twistedmatrix.com/trac/ticket/5476 |
348 | + if self._method.upper() == 'HEAD' or response.code == NO_CONTENT: |
349 | + return succeed('') |
350 | + receiver = self.receiver_factory() |
351 | + receiver.finished = d = Deferred() |
352 | + receiver.content_length = response.length |
353 | + response.deliverBody(receiver) |
354 | + if response.code >= 400: |
355 | + d.addCallback(self._fail_response, response) |
356 | + return d |
357 | + |
358 | + def _fail_response(self, data, response): |
359 | + return fail(failure.Failure( |
360 | + TwistedWebError(response.code, response=data))) |
361 | |
362 | def get_response_headers(self, *args, **kwargs): |
363 | """ |
364 | @@ -125,5 +289,6 @@ |
365 | The AWS S3 API depends upon setting headers. This method is used by the |
366 | head_object API call for getting a S3 object's metadata. |
367 | """ |
368 | - if self.client: |
369 | - return self.client.response_headers |
370 | + if self.response_headers: |
371 | + return self._unpack_headers(self.response_headers) |
372 | + |
373 | |
374 | === modified file 'txaws/client/tests/test_base.py' |
375 | --- txaws/client/tests/test_base.py 2012-01-26 18:43:48 +0000 |
376 | +++ txaws/client/tests/test_base.py 2012-03-15 05:30:25 +0000 |
377 | @@ -1,6 +1,11 @@ |
378 | import os |
379 | |
380 | +from StringIO import StringIO |
381 | + |
382 | +from zope.interface import implements |
383 | + |
384 | from twisted.internet import reactor |
385 | +from twisted.internet.defer import succeed, Deferred |
386 | from twisted.internet.error import ConnectionRefusedError |
387 | from twisted.protocols.policies import WrappingFactory |
388 | from twisted.python import log |
389 | @@ -8,14 +13,18 @@ |
390 | from twisted.python.failure import Failure |
391 | from twisted.test.test_sslverify import makeCertificate |
392 | from twisted.web import server, static |
393 | +from twisted.web.iweb import IBodyProducer |
394 | from twisted.web.client import HTTPClientFactory |
395 | +from twisted.web.client import ResponseDone |
396 | +from twisted.web.resource import Resource |
397 | from twisted.web.error import Error as TwistedWebError |
398 | |
399 | from txaws.client import ssl |
400 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper |
401 | +from txaws.client.base import StreamingBodyReceiver |
402 | from txaws.service import AWSServiceEndpoint |
403 | from txaws.testing.base import TXAWSTestCase |
404 | - |
405 | +from txaws.testing.producers import StringBodyProducer |
406 | |
407 | class ErrorWrapperTestCase(TXAWSTestCase): |
408 | |
409 | @@ -63,6 +72,12 @@ |
410 | self.assertEquals(client.parser, "parser") |
411 | |
412 | |
413 | +class PuttableResource(Resource): |
414 | + |
415 | + def render_PUT(self, reuqest): |
416 | + return '' |
417 | + |
418 | + |
419 | class BaseQueryTestCase(TXAWSTestCase): |
420 | |
421 | def setUp(self): |
422 | @@ -71,6 +86,7 @@ |
423 | os.mkdir(name) |
424 | FilePath(name).child("file").setContent("0123456789") |
425 | r = static.File(name) |
426 | + r.putChild('thing_to_put', PuttableResource()) |
427 | self.site = server.Site(r, timeout=None) |
428 | self.wrapper = WrappingFactory(self.site) |
429 | self.port = self._listen(self.wrapper) |
430 | @@ -99,7 +115,6 @@ |
431 | |
432 | def test_creation(self): |
433 | query = BaseQuery("an action", "creds", "http://endpoint") |
434 | - self.assertEquals(query.factory, HTTPClientFactory) |
435 | self.assertEquals(query.action, "an action") |
436 | self.assertEquals(query.creds, "creds") |
437 | self.assertEquals(query.endpoint, "http://endpoint") |
438 | @@ -142,16 +157,58 @@ |
439 | def test_get_response_headers_with_client(self): |
440 | |
441 | def check_results(results): |
442 | + #self.assertEquals(sorted(results.keys()), [ |
443 | + # "accept-ranges", "content-length", "content-type", "date", |
444 | + # "last-modified", "server"]) |
445 | + # XXX I think newclient exludes content-length from headers? |
446 | + # Also the header names are capitalized ... do we need to worry |
447 | + # about backwards compat? |
448 | self.assertEquals(sorted(results.keys()), [ |
449 | - "accept-ranges", "content-length", "content-type", "date", |
450 | - "last-modified", "server"]) |
451 | - self.assertEquals(len(results.values()), 6) |
452 | + "Accept-Ranges", "Content-Type", "Date", |
453 | + "Last-Modified", "Server"]) |
454 | + self.assertEquals(len(results.values()), 5) |
455 | |
456 | query = BaseQuery("an action", "creds", "http://endpoint") |
457 | d = query.get_page(self._get_url("file")) |
458 | d.addCallback(query.get_response_headers) |
459 | return d.addCallback(check_results) |
460 | |
461 | + def test_errors(self): |
462 | + query = BaseQuery("an action", "creds", "http://endpoint") |
463 | + d = query.get_page(self._get_url("not_there")) |
464 | + self.assertFailure(d, TwistedWebError) |
465 | + return d |
466 | + |
467 | + def test_custom_body_producer(self): |
468 | + |
469 | + def check_producer_was_used(ignore): |
470 | + self.assertEqual(producer.written, 'test data') |
471 | + |
472 | + producer = StringBodyProducer('test data') |
473 | + query = BaseQuery("an action", "creds", "http://endpoint", |
474 | + body_producer=producer) |
475 | + d = query.get_page(self._get_url("thing_to_put"), method='PUT') |
476 | + return d.addCallback(check_producer_was_used) |
477 | + |
478 | + def test_custom_receiver_factory(self): |
479 | + |
480 | + class TestReceiverProtocol(StreamingBodyReceiver): |
481 | + used = False |
482 | + |
483 | + def __init__(self): |
484 | + StreamingBodyReceiver.__init__(self) |
485 | + TestReceiverProtocol.used = True |
486 | + |
487 | + def check_used(ignore): |
488 | + self.assert_(TestReceiverProtocol.used) |
489 | + |
490 | + query = BaseQuery("an action", "creds", "http://endpoint", |
491 | + receiver_factory=TestReceiverProtocol) |
492 | + d = query.get_page(self._get_url("file")) |
493 | + d.addCallback(self.assertEquals, "0123456789") |
494 | + d.addCallback(check_used) |
495 | + return d |
496 | + |
497 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test |
498 | # will fail; instead, let's create some certs in a temp directory and set |
499 | # the DEFAULT_CERT_PATH to point there. |
500 | @@ -167,8 +224,9 @@ |
501 | def __init__(self): |
502 | self.connects = [] |
503 | |
504 | - def connectSSL(self, host, port, client, factory): |
505 | - self.connects.append((host, port, client, factory)) |
506 | + def connectSSL(self, host, port, factory, contextFactory, timeout, |
507 | + bindAddress): |
508 | + self.connects.append((host, port, factory, contextFactory)) |
509 | |
510 | certs = makeCertificate(O="Test Certificate", CN="something")[1] |
511 | self.patch(ssl, "_ca_certs", certs) |
512 | @@ -176,9 +234,56 @@ |
513 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
514 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) |
515 | query.get_page("https://example.com/file") |
516 | - [(host, port, client, factory)] = fake_reactor.connects |
517 | + [(host, port, factory, contextFactory)] = fake_reactor.connects |
518 | self.assertEqual("example.com", host) |
519 | self.assertEqual(443, port) |
520 | - self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) |
521 | - self.assertEqual("example.com", factory.host) |
522 | - self.assertNotEqual([], factory.caCerts) |
523 | + wrappedFactory = contextFactory._webContext |
524 | + self.assertTrue(isinstance(wrappedFactory, ssl.VerifyingContextFactory)) |
525 | + self.assertEqual("example.com", wrappedFactory.host) |
526 | + self.assertNotEqual([], wrappedFactory.caCerts) |
527 | + |
528 | +class StreamingBodyReceiverTestCase(TXAWSTestCase): |
529 | + |
530 | + def test_readback_mode_on(self): |
531 | + """ |
532 | + Test that when readback mode is on inside connectionLost() data will |
533 | + be read back from the start of the file we're streaming and results |
534 | + passed to finished callback. |
535 | + """ |
536 | + |
537 | + receiver = StreamingBodyReceiver() |
538 | + d = Deferred() |
539 | + receiver.finished = d |
540 | + receiver.content_length = 5 |
541 | + fd = receiver._fd |
542 | + receiver.dataReceived('hello') |
543 | + why = Failure(ResponseDone('done')) |
544 | + receiver.connectionLost(why) |
545 | + self.assertEqual(d.result, 'hello') |
546 | + self.assert_(fd.closed) |
547 | + |
548 | + def test_readback_mode_off(self): |
549 | + """ |
550 | + Test that when readback mode is off connectionLost() will simply |
551 | + callback finished with the fd. |
552 | + """ |
553 | + |
554 | + receiver = StreamingBodyReceiver(readback=False) |
555 | + d = Deferred() |
556 | + receiver.finished = d |
557 | + receiver.content_length = 5 |
558 | + fd = receiver._fd |
559 | + receiver.dataReceived('hello') |
560 | + why = Failure(ResponseDone('done')) |
561 | + receiver.connectionLost(why) |
562 | + self.assertIdentical(d.result, fd) |
563 | + self.assertIdentical(receiver._fd, fd) |
564 | + self.failIf(fd.closed) |
565 | + |
566 | + def test_user_fd(self): |
567 | + """ |
568 | + Test that user's own file descriptor can be passed to init |
569 | + """ |
570 | + user_fd = StringIO() |
571 | + receiver = StreamingBodyReceiver(user_fd) |
572 | + self.assertIdentical(receiver._fd, user_fd) |
573 | |
574 | === modified file 'txaws/client/tests/test_ssl.py' |
575 | --- txaws/client/tests/test_ssl.py 2012-01-26 22:54:44 +0000 |
576 | +++ txaws/client/tests/test_ssl.py 2012-03-15 05:30:25 +0000 |
577 | @@ -12,6 +12,10 @@ |
578 | from twisted.python.filepath import FilePath |
579 | from twisted.test.test_sslverify import makeCertificate |
580 | from twisted.web import server, static |
581 | +try: |
582 | + from twisted.web.client import ResponseFailed |
583 | +except ImportError: |
584 | + from twisted.web._newclient import ResponseFailed |
585 | |
586 | from txaws import exception |
587 | from txaws.client import ssl |
588 | @@ -32,6 +36,11 @@ |
589 | PUBSANKEY = sibpath("public_san.ssl") |
590 | |
591 | |
592 | +class WebDefaultOpenSSLContextFactory(DefaultOpenSSLContextFactory): |
593 | + def getContext(self, hostname=None, port=None): |
594 | + return DefaultOpenSSLContextFactory.getContext(self) |
595 | + |
596 | + |
597 | class BaseQuerySSLTestCase(TXAWSTestCase): |
598 | |
599 | def setUp(self): |
600 | @@ -75,7 +84,7 @@ |
601 | The L{VerifyingContextFactory} properly allows to connect to the |
602 | endpoint if the certificates match. |
603 | """ |
604 | - context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
605 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
606 | self.port = reactor.listenSSL( |
607 | 0, self.site, context_factory, interface="127.0.0.1") |
608 | self.portno = self.port.getHost().port |
609 | @@ -90,7 +99,7 @@ |
610 | The L{VerifyingContextFactory} fails with a SSL error the certificates |
611 | can't be checked. |
612 | """ |
613 | - context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
614 | + context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
615 | self.port = reactor.listenSSL( |
616 | 0, self.site, context_factory, interface="127.0.0.1") |
617 | self.portno = self.port.getHost().port |
618 | @@ -98,7 +107,14 @@ |
619 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
620 | query = BaseQuery("an action", "creds", endpoint) |
621 | d = query.get_page(self._get_url("file")) |
622 | - return self.assertFailure(d, SSLError) |
623 | + def fail(ignore): |
624 | + self.fail('Expected SSLError') |
625 | + def check_exception(why): |
626 | + # XXX kind of a mess here ... need to unwrap the |
627 | + # exception and check |
628 | + root_exc = why.value[0][0].value |
629 | + self.assert_(isinstance(root_exc, SSLError)) |
630 | + return d.addCallbacks(fail, check_exception) |
631 | |
632 | def test_ssl_verification_bypassed(self): |
633 | """ |
634 | @@ -121,7 +137,7 @@ |
635 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the |
636 | certificate if it's available. |
637 | """ |
638 | - context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
639 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
640 | self.port = reactor.listenSSL( |
641 | 0, self.site, context_factory, interface="127.0.0.1") |
642 | self.portno = self.port.getHost().port |
643 | |
644 | === added file 'txaws/testing/producers.py' |
645 | --- txaws/testing/producers.py 1970-01-01 00:00:00 +0000 |
646 | +++ txaws/testing/producers.py 2012-03-15 05:30:25 +0000 |
647 | @@ -0,0 +1,23 @@ |
648 | +from zope.interface import implements |
649 | + |
650 | +from twisted.internet.defer import succeed |
651 | +from twisted.web.iweb import IBodyProducer |
652 | + |
653 | +class StringBodyProducer(object): |
654 | + implements(IBodyProducer) |
655 | + |
656 | + def __init__(self, data): |
657 | + self.data = data |
658 | + self.length = len(data) |
659 | + self.written = None |
660 | + |
661 | + def startProducing(self, consumer): |
662 | + consumer.write(self.data) |
663 | + self.written = self.data |
664 | + return succeed(None) |
665 | + |
666 | + def pauseProducing(self): |
667 | + pass |
668 | + |
669 | + def stopProducing(self): |
670 | + pass |
On Thu, Feb 9, 2012 at 10:23 PM, Drew Smathers <email address hidden> wrote:
> Drew Smathers has proposed merging lp:~djfroofy/txaws/modernize-924459 into lp:txaws.
>
> Requested reviews:
> Duncan McGreggor (oubiwann)
> Related bugs:
> Bug #924459 in txAWS: "Modernize txaws.client to use twisted.web.client.Agent"
> https://bugs.launchpad.net/txaws/+bug/924459
>
> For more details, see:
> https://code.launchpad.net/~djfroofy/txaws/modernize-924459/+merge/92404
>
> Need to consider how much we care about compatibility, since: (1) this requires Twisted >= 11.1.0, as far as I know; and (2) some public members of BaseQuery no longer exist, because they are not applicable when using Agent.
I've taken a quick look at the latest code tonight, and first glance
is a happy one :-)
I've branched the code on my laptop so that I can review it on the
plane tomorrow, run the unit tests, etc.
I hope to have more for you soon.