Merge lp:~djfroofy/txaws/modernize-924459 into lp:txaws
- modernize-924459
- Merge into trunk
Proposed by
Drew Smathers
Status: | Merged |
---|---|
Approved by: | Duncan McGreggor |
Approved revision: | 146 |
Merged at revision: | 145 |
Proposed branch: | lp:~djfroofy/txaws/modernize-924459 |
Merge into: | lp:txaws |
Diff against target: |
670 lines (+459/-28) 5 files modified
txaws/client/_producers.py (+122/-0) txaws/client/base.py (+178/-13) txaws/client/tests/test_base.py (+116/-11) txaws/client/tests/test_ssl.py (+20/-4) txaws/testing/producers.py (+23/-0) |
To merge this branch: | bzr merge lp:~djfroofy/txaws/modernize-924459 |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Duncan McGreggor | Approve | ||
Review via email: mp+92404@code.launchpad.net |
Commit message
Description of the change
Need to consider how much we care about compatibility, since: (1) this requires Twisted >= 11.1.0, as far as I know, and (2) some public members on BaseQuery no longer exist because they are not applicable when using Agent.
To post a comment you must log in.
- 145. By Drew Smathers
-
cherry pick: change StringIOReceiver to generic StreamingBodyReceiver
- 146. By Drew Smathers
-
check the response code in _handle_response and, if it is >= 400, errback with the response body
Revision history for this message
Duncan McGreggor (oubiwann) wrote : | # |
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === added file 'txaws/client/_producers.py' | |||
2 | --- txaws/client/_producers.py 1970-01-01 00:00:00 +0000 | |||
3 | +++ txaws/client/_producers.py 2012-03-15 05:30:25 +0000 | |||
4 | @@ -0,0 +1,122 @@ | |||
5 | 1 | import os | ||
6 | 2 | |||
7 | 3 | from zope.interface import implements | ||
8 | 4 | |||
9 | 5 | from twisted.internet import defer, task | ||
10 | 6 | from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer | ||
11 | 7 | |||
12 | 8 | |||
13 | 9 | # Code below for FileBodyProducer cut-and-paste from twisted source. | ||
14 | 10 | # Currently this is not released so here temporarily for forward compat. | ||
15 | 11 | |||
16 | 12 | |||
17 | 13 | class FileBodyProducer(object): | ||
18 | 14 | """ | ||
19 | 15 | L{FileBodyProducer} produces bytes from an input file object incrementally | ||
20 | 16 | and writes them to a consumer. | ||
21 | 17 | |||
22 | 18 | Since file-like objects cannot be read from in an event-driven manner, | ||
23 | 19 | L{FileBodyProducer} uses a L{Cooperator} instance to schedule reads from | ||
24 | 20 | the file. This process is also paused and resumed based on notifications | ||
25 | 21 | from the L{IConsumer} provider being written to. | ||
26 | 22 | |||
27 | 23 | The file is closed after it has been read, or if the producer is stopped | ||
28 | 24 | early. | ||
29 | 25 | |||
30 | 26 | @ivar _inputFile: Any file-like object, bytes read from which will be | ||
31 | 27 | written to a consumer. | ||
32 | 28 | |||
33 | 29 | @ivar _cooperate: A method like L{Cooperator.cooperate} which is used to | ||
34 | 30 | schedule all reads. | ||
35 | 31 | |||
36 | 32 | @ivar _readSize: The number of bytes to read from C{_inputFile} at a time. | ||
37 | 33 | """ | ||
38 | 34 | implements(IBodyProducer) | ||
39 | 35 | |||
40 | 36 | # Python 2.4 doesn't have these symbolic constants | ||
41 | 37 | _SEEK_SET = getattr(os, 'SEEK_SET', 0) | ||
42 | 38 | _SEEK_END = getattr(os, 'SEEK_END', 2) | ||
43 | 39 | |||
44 | 40 | def __init__(self, inputFile, cooperator=task, readSize=2 ** 16): | ||
45 | 41 | self._inputFile = inputFile | ||
46 | 42 | self._cooperate = cooperator.cooperate | ||
47 | 43 | self._readSize = readSize | ||
48 | 44 | self.length = self._determineLength(inputFile) | ||
49 | 45 | |||
50 | 46 | |||
51 | 47 | def _determineLength(self, fObj): | ||
52 | 48 | """ | ||
53 | 49 | Determine how many bytes can be read out of C{fObj} (assuming it is not | ||
54 | 50 | modified from this point on). If the determination cannot be made, | ||
55 | 51 | return C{UNKNOWN_LENGTH}. | ||
56 | 52 | """ | ||
57 | 53 | try: | ||
58 | 54 | seek = fObj.seek | ||
59 | 55 | tell = fObj.tell | ||
60 | 56 | except AttributeError: | ||
61 | 57 | return UNKNOWN_LENGTH | ||
62 | 58 | originalPosition = tell() | ||
63 | 59 | seek(0, self._SEEK_END) | ||
64 | 60 | end = tell() | ||
65 | 61 | seek(originalPosition, self._SEEK_SET) | ||
66 | 62 | return end - originalPosition | ||
67 | 63 | |||
68 | 64 | |||
69 | 65 | def stopProducing(self): | ||
70 | 66 | """ | ||
71 | 67 | Permanently stop writing bytes from the file to the consumer by | ||
72 | 68 | stopping the underlying L{CooperativeTask}. | ||
73 | 69 | """ | ||
74 | 70 | self._inputFile.close() | ||
75 | 71 | self._task.stop() | ||
76 | 72 | |||
77 | 73 | |||
78 | 74 | def startProducing(self, consumer): | ||
79 | 75 | """ | ||
80 | 76 | Start a cooperative task which will read bytes from the input file and | ||
81 | 77 | write them to C{consumer}. Return a L{Deferred} which fires after all | ||
82 | 78 | bytes have been written. | ||
83 | 79 | |||
84 | 80 | @param consumer: Any L{IConsumer} provider | ||
85 | 81 | """ | ||
86 | 82 | self._task = self._cooperate(self._writeloop(consumer)) | ||
87 | 83 | d = self._task.whenDone() | ||
88 | 84 | def maybeStopped(reason): | ||
89 | 85 | # IBodyProducer.startProducing's Deferred isn't support to fire if | ||
90 | 86 | # stopProducing is called. | ||
91 | 87 | reason.trap(task.TaskStopped) | ||
92 | 88 | return defer.Deferred() | ||
93 | 89 | d.addCallbacks(lambda ignored: None, maybeStopped) | ||
94 | 90 | return d | ||
95 | 91 | |||
96 | 92 | |||
97 | 93 | def _writeloop(self, consumer): | ||
98 | 94 | """ | ||
99 | 95 | Return an iterator which reads one chunk of bytes from the input file | ||
100 | 96 | and writes them to the consumer for each time it is iterated. | ||
101 | 97 | """ | ||
102 | 98 | while True: | ||
103 | 99 | bytes = self._inputFile.read(self._readSize) | ||
104 | 100 | if not bytes: | ||
105 | 101 | self._inputFile.close() | ||
106 | 102 | break | ||
107 | 103 | consumer.write(bytes) | ||
108 | 104 | yield None | ||
109 | 105 | |||
110 | 106 | |||
111 | 107 | def pauseProducing(self): | ||
112 | 108 | """ | ||
113 | 109 | Temporarily suspend copying bytes from the input file to the consumer | ||
114 | 110 | by pausing the L{CooperativeTask} which drives that activity. | ||
115 | 111 | """ | ||
116 | 112 | self._task.pause() | ||
117 | 113 | |||
118 | 114 | |||
119 | 115 | def resumeProducing(self): | ||
120 | 116 | """ | ||
121 | 117 | Undo the effects of a previous C{pauseProducing} and resume copying | ||
122 | 118 | bytes to the consumer by resuming the L{CooperativeTask} which drives | ||
123 | 119 | the write activity. | ||
124 | 120 | """ | ||
125 | 121 | self._task.resume() | ||
126 | 122 | |||
127 | 0 | 123 | ||
128 | === modified file 'txaws/client/base.py' | |||
129 | --- txaws/client/base.py 2011-11-29 08:17:54 +0000 | |||
130 | +++ txaws/client/base.py 2012-03-15 05:30:25 +0000 | |||
131 | @@ -3,10 +3,25 @@ | |||
132 | 3 | except ImportError: | 3 | except ImportError: |
133 | 4 | from xml.parsers.expat import ExpatError as ParseError | 4 | from xml.parsers.expat import ExpatError as ParseError |
134 | 5 | 5 | ||
135 | 6 | import warnings | ||
136 | 7 | from StringIO import StringIO | ||
137 | 8 | |||
138 | 6 | from twisted.internet.ssl import ClientContextFactory | 9 | from twisted.internet.ssl import ClientContextFactory |
139 | 10 | from twisted.internet.protocol import Protocol | ||
140 | 11 | from twisted.internet.defer import Deferred, succeed, fail | ||
141 | 12 | from twisted.python import failure | ||
142 | 7 | from twisted.web import http | 13 | from twisted.web import http |
143 | 14 | from twisted.web.iweb import UNKNOWN_LENGTH | ||
144 | 8 | from twisted.web.client import HTTPClientFactory | 15 | from twisted.web.client import HTTPClientFactory |
145 | 16 | from twisted.web.client import Agent | ||
146 | 17 | from twisted.web.client import ResponseDone | ||
147 | 18 | from twisted.web.http import NO_CONTENT | ||
148 | 19 | from twisted.web.http_headers import Headers | ||
149 | 9 | from twisted.web.error import Error as TwistedWebError | 20 | from twisted.web.error import Error as TwistedWebError |
150 | 21 | try: | ||
151 | 22 | from twisted.web.client import FileBodyProducer | ||
152 | 23 | except ImportError: | ||
153 | 24 | from txaws.client._producers import FileBodyProducer | ||
154 | 10 | 25 | ||
155 | 11 | from txaws.util import parse | 26 | from txaws.util import parse |
156 | 12 | from txaws.credentials import AWSCredentials | 27 | from txaws.credentials import AWSCredentials |
157 | @@ -71,20 +86,122 @@ | |||
158 | 71 | self.query_factory = query_factory | 86 | self.query_factory = query_factory |
159 | 72 | self.parser = parser | 87 | self.parser = parser |
160 | 73 | 88 | ||
161 | 89 | class StreamingError(Exception): | ||
162 | 90 | """ | ||
163 | 91 | Raised if more data or less data is received than expected. | ||
164 | 92 | """ | ||
165 | 93 | |||
166 | 94 | |||
167 | 95 | class StreamingBodyReceiver(Protocol): | ||
168 | 96 | """ | ||
169 | 97 | Streaming HTTP response body receiver. | ||
170 | 98 | |||
171 | 99 | TODO: perhaps there should be an interface specifying why | ||
172 | 100 | finished (Deferred) and content_length are necessary and | ||
173 | 101 | how to used them; eg. callback/errback finished on completion. | ||
174 | 102 | """ | ||
175 | 103 | finished = None | ||
176 | 104 | content_length = None | ||
177 | 105 | |||
178 | 106 | def __init__(self, fd=None, readback=True): | ||
179 | 107 | """ | ||
180 | 108 | @param fd: a file descriptor to write to | ||
181 | 109 | @param readback: if True read back data from fd to callback finished | ||
182 | 110 | with, otherwise we call back finish with fd itself | ||
183 | 111 | with | ||
184 | 112 | """ | ||
185 | 113 | if fd is None: | ||
186 | 114 | fd = StringIO() | ||
187 | 115 | self._fd = fd | ||
188 | 116 | self._received = 0 | ||
189 | 117 | self._readback = readback | ||
190 | 118 | |||
191 | 119 | def dataReceived(self, bytes): | ||
192 | 120 | streaming = self.content_length is UNKNOWN_LENGTH | ||
193 | 121 | if not streaming and (self._received > self.content_length): | ||
194 | 122 | self.transport.loseConnection() | ||
195 | 123 | raise StreamingError( | ||
196 | 124 | "Buffer overflow - received more data than " | ||
197 | 125 | "Content-Length dictated: %d" % self.content_length) | ||
198 | 126 | # TODO should be some limit on how much we receive | ||
199 | 127 | self._fd.write(bytes) | ||
200 | 128 | self._received += len(bytes) | ||
201 | 129 | |||
202 | 130 | def connectionLost(self, reason): | ||
203 | 131 | reason.trap(ResponseDone) | ||
204 | 132 | d = self.finished | ||
205 | 133 | self.finished = None | ||
206 | 134 | streaming = self.content_length is UNKNOWN_LENGTH | ||
207 | 135 | if streaming or (self._received == self.content_length): | ||
208 | 136 | if self._readback: | ||
209 | 137 | self._fd.seek(0) | ||
210 | 138 | data = self._fd.read() | ||
211 | 139 | self._fd.close() | ||
212 | 140 | self._fd = None | ||
213 | 141 | d.callback(data) | ||
214 | 142 | else: | ||
215 | 143 | d.callback(self._fd) | ||
216 | 144 | else: | ||
217 | 145 | f = failure.Failure(StreamingError("Connection lost before " | ||
218 | 146 | "receiving all data")) | ||
219 | 147 | d.errback(f) | ||
220 | 148 | |||
221 | 149 | |||
222 | 150 | class WebClientContextFactory(ClientContextFactory): | ||
223 | 151 | |||
224 | 152 | def getContext(self, hostname, port): | ||
225 | 153 | return ClientContextFactory.getContext(self) | ||
226 | 154 | |||
227 | 155 | |||
228 | 156 | class WebVerifyingContextFactory(VerifyingContextFactory): | ||
229 | 157 | |||
230 | 158 | def getContext(self, hostname, port): | ||
231 | 159 | return VerifyingContextFactory.getContext(self) | ||
232 | 160 | |||
233 | 161 | |||
234 | 162 | class FakeClient(object): | ||
235 | 163 | """ | ||
236 | 164 | XXX | ||
237 | 165 | A fake client object for some degree of backwards compatability for | ||
238 | 166 | code using the client attibute on BaseQuery to check url, status | ||
239 | 167 | etc. | ||
240 | 168 | """ | ||
241 | 169 | url = None | ||
242 | 170 | status = None | ||
243 | 74 | 171 | ||
244 | 75 | class BaseQuery(object): | 172 | class BaseQuery(object): |
245 | 76 | 173 | ||
247 | 77 | def __init__(self, action=None, creds=None, endpoint=None, reactor=None): | 174 | def __init__(self, action=None, creds=None, endpoint=None, reactor=None, |
248 | 175 | body_producer=None, receiver_factory=None): | ||
249 | 78 | if not action: | 176 | if not action: |
250 | 79 | raise TypeError("The query requires an action parameter.") | 177 | raise TypeError("The query requires an action parameter.") |
251 | 80 | self.factory = HTTPClientFactory | ||
252 | 81 | self.action = action | 178 | self.action = action |
253 | 82 | self.creds = creds | 179 | self.creds = creds |
254 | 83 | self.endpoint = endpoint | 180 | self.endpoint = endpoint |
255 | 84 | if reactor is None: | 181 | if reactor is None: |
256 | 85 | from twisted.internet import reactor | 182 | from twisted.internet import reactor |
257 | 86 | self.reactor = reactor | 183 | self.reactor = reactor |
259 | 87 | self.client = None | 184 | self._client = None |
260 | 185 | self.request_headers = None | ||
261 | 186 | self.response_headers = None | ||
262 | 187 | self.body_producer = body_producer | ||
263 | 188 | self.receiver_factory = receiver_factory or StreamingBodyReceiver | ||
264 | 189 | |||
265 | 190 | @property | ||
266 | 191 | def client(self): | ||
267 | 192 | if self._client is None: | ||
268 | 193 | self._client_deprecation_warning() | ||
269 | 194 | self._client = FakeClient() | ||
270 | 195 | return self._client | ||
271 | 196 | |||
272 | 197 | @client.setter | ||
273 | 198 | def client(self, value): | ||
274 | 199 | self._client_deprecation_warning() | ||
275 | 200 | self._client = value | ||
276 | 201 | |||
277 | 202 | def _client_deprecation_warning(self): | ||
278 | 203 | warnings.warn('The client attribute on BaseQuery is deprecated and' | ||
279 | 204 | ' will go away in future release.') | ||
280 | 88 | 205 | ||
281 | 89 | def get_page(self, url, *args, **kwds): | 206 | def get_page(self, url, *args, **kwds): |
282 | 90 | """ | 207 | """ |
283 | @@ -95,16 +212,39 @@ | |||
284 | 95 | """ | 212 | """ |
285 | 96 | contextFactory = None | 213 | contextFactory = None |
286 | 97 | scheme, host, port, path = parse(url) | 214 | scheme, host, port, path = parse(url) |
288 | 98 | self.client = self.factory(url, *args, **kwds) | 215 | data = kwds.get('postdata', None) |
289 | 216 | self._method = method = kwds.get('method', 'GET') | ||
290 | 217 | self.request_headers = self._headers(kwds.get('headers', {})) | ||
291 | 218 | if (self.body_producer is None) and (data is not None): | ||
292 | 219 | self.body_producer = FileBodyProducer(StringIO(data)) | ||
293 | 99 | if scheme == "https": | 220 | if scheme == "https": |
294 | 100 | if self.endpoint.ssl_hostname_verification: | 221 | if self.endpoint.ssl_hostname_verification: |
296 | 101 | contextFactory = VerifyingContextFactory(host) | 222 | contextFactory = WebVerifyingContextFactory(host) |
297 | 102 | else: | 223 | else: |
300 | 103 | contextFactory = ClientContextFactory() | 224 | contextFactory = WebClientContextFactory() |
301 | 104 | self.reactor.connectSSL(host, port, self.client, contextFactory) | 225 | agent = Agent(self.reactor, contextFactory) |
302 | 226 | self.client.url = url | ||
303 | 227 | d = agent.request(method, url, self.request_headers, | ||
304 | 228 | self.body_producer) | ||
305 | 105 | else: | 229 | else: |
308 | 106 | self.reactor.connectTCP(host, port, self.client) | 230 | agent = Agent(self.reactor) |
309 | 107 | return self.client.deferred | 231 | d = agent.request(method, url, self.request_headers, |
310 | 232 | self.body_producer) | ||
311 | 233 | d.addCallback(self._handle_response) | ||
312 | 234 | return d | ||
313 | 235 | |||
314 | 236 | def _headers(self, headers_dict): | ||
315 | 237 | """ | ||
316 | 238 | Convert dictionary of headers into twisted.web.client.Headers object. | ||
317 | 239 | """ | ||
318 | 240 | return Headers(dict((k,[v]) for (k,v) in headers_dict.items())) | ||
319 | 241 | |||
320 | 242 | def _unpack_headers(self, headers): | ||
321 | 243 | """ | ||
322 | 244 | Unpack twisted.web.client.Headers object to dict. This is to provide | ||
323 | 245 | backwards compatability. | ||
324 | 246 | """ | ||
325 | 247 | return dict((k,v[0]) for (k,v) in headers.getAllRawHeaders()) | ||
326 | 108 | 248 | ||
327 | 109 | def get_request_headers(self, *args, **kwds): | 249 | def get_request_headers(self, *args, **kwds): |
328 | 110 | """ | 250 | """ |
329 | @@ -114,8 +254,32 @@ | |||
330 | 114 | The AWS S3 API depends upon setting headers. This method is provided as | 254 | The AWS S3 API depends upon setting headers. This method is provided as |
331 | 115 | a convenience for debugging issues with the S3 communications. | 255 | a convenience for debugging issues with the S3 communications. |
332 | 116 | """ | 256 | """ |
335 | 117 | if self.client: | 257 | if self.request_headers: |
336 | 118 | return self.client.headers | 258 | return self._unpack_headers(self.request_headers) |
337 | 259 | |||
338 | 260 | def _handle_response(self, response): | ||
339 | 261 | """ | ||
340 | 262 | Handle the HTTP response by memoing the headers and then delivering | ||
341 | 263 | bytes. | ||
342 | 264 | """ | ||
343 | 265 | self.client.status = response.code | ||
344 | 266 | self.response_headers = headers = response.headers | ||
345 | 267 | # XXX This workaround (which needs to be improved at that) for possible | ||
346 | 268 | # bug in Twisted with new client: | ||
347 | 269 | # http://twistedmatrix.com/trac/ticket/5476 | ||
348 | 270 | if self._method.upper() == 'HEAD' or response.code == NO_CONTENT: | ||
349 | 271 | return succeed('') | ||
350 | 272 | receiver = self.receiver_factory() | ||
351 | 273 | receiver.finished = d = Deferred() | ||
352 | 274 | receiver.content_length = response.length | ||
353 | 275 | response.deliverBody(receiver) | ||
354 | 276 | if response.code >= 400: | ||
355 | 277 | d.addCallback(self._fail_response, response) | ||
356 | 278 | return d | ||
357 | 279 | |||
358 | 280 | def _fail_response(self, data, response): | ||
359 | 281 | return fail(failure.Failure( | ||
360 | 282 | TwistedWebError(response.code, response=data))) | ||
361 | 119 | 283 | ||
362 | 120 | def get_response_headers(self, *args, **kwargs): | 284 | def get_response_headers(self, *args, **kwargs): |
363 | 121 | """ | 285 | """ |
364 | @@ -125,5 +289,6 @@ | |||
365 | 125 | The AWS S3 API depends upon setting headers. This method is used by the | 289 | The AWS S3 API depends upon setting headers. This method is used by the |
366 | 126 | head_object API call for getting a S3 object's metadata. | 290 | head_object API call for getting a S3 object's metadata. |
367 | 127 | """ | 291 | """ |
370 | 128 | if self.client: | 292 | if self.response_headers: |
371 | 129 | return self.client.response_headers | 293 | return self._unpack_headers(self.response_headers) |
372 | 294 | |||
373 | 130 | 295 | ||
374 | === modified file 'txaws/client/tests/test_base.py' | |||
375 | --- txaws/client/tests/test_base.py 2012-01-26 18:43:48 +0000 | |||
376 | +++ txaws/client/tests/test_base.py 2012-03-15 05:30:25 +0000 | |||
377 | @@ -1,6 +1,11 @@ | |||
378 | 1 | import os | 1 | import os |
379 | 2 | 2 | ||
380 | 3 | from StringIO import StringIO | ||
381 | 4 | |||
382 | 5 | from zope.interface import implements | ||
383 | 6 | |||
384 | 3 | from twisted.internet import reactor | 7 | from twisted.internet import reactor |
385 | 8 | from twisted.internet.defer import succeed, Deferred | ||
386 | 4 | from twisted.internet.error import ConnectionRefusedError | 9 | from twisted.internet.error import ConnectionRefusedError |
387 | 5 | from twisted.protocols.policies import WrappingFactory | 10 | from twisted.protocols.policies import WrappingFactory |
388 | 6 | from twisted.python import log | 11 | from twisted.python import log |
389 | @@ -8,14 +13,18 @@ | |||
390 | 8 | from twisted.python.failure import Failure | 13 | from twisted.python.failure import Failure |
391 | 9 | from twisted.test.test_sslverify import makeCertificate | 14 | from twisted.test.test_sslverify import makeCertificate |
392 | 10 | from twisted.web import server, static | 15 | from twisted.web import server, static |
393 | 16 | from twisted.web.iweb import IBodyProducer | ||
394 | 11 | from twisted.web.client import HTTPClientFactory | 17 | from twisted.web.client import HTTPClientFactory |
395 | 18 | from twisted.web.client import ResponseDone | ||
396 | 19 | from twisted.web.resource import Resource | ||
397 | 12 | from twisted.web.error import Error as TwistedWebError | 20 | from twisted.web.error import Error as TwistedWebError |
398 | 13 | 21 | ||
399 | 14 | from txaws.client import ssl | 22 | from txaws.client import ssl |
400 | 15 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper | 23 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper |
401 | 24 | from txaws.client.base import StreamingBodyReceiver | ||
402 | 16 | from txaws.service import AWSServiceEndpoint | 25 | from txaws.service import AWSServiceEndpoint |
403 | 17 | from txaws.testing.base import TXAWSTestCase | 26 | from txaws.testing.base import TXAWSTestCase |
405 | 18 | 27 | from txaws.testing.producers import StringBodyProducer | |
406 | 19 | 28 | ||
407 | 20 | class ErrorWrapperTestCase(TXAWSTestCase): | 29 | class ErrorWrapperTestCase(TXAWSTestCase): |
408 | 21 | 30 | ||
409 | @@ -63,6 +72,12 @@ | |||
410 | 63 | self.assertEquals(client.parser, "parser") | 72 | self.assertEquals(client.parser, "parser") |
411 | 64 | 73 | ||
412 | 65 | 74 | ||
413 | 75 | class PuttableResource(Resource): | ||
414 | 76 | |||
415 | 77 | def render_PUT(self, reuqest): | ||
416 | 78 | return '' | ||
417 | 79 | |||
418 | 80 | |||
419 | 66 | class BaseQueryTestCase(TXAWSTestCase): | 81 | class BaseQueryTestCase(TXAWSTestCase): |
420 | 67 | 82 | ||
421 | 68 | def setUp(self): | 83 | def setUp(self): |
422 | @@ -71,6 +86,7 @@ | |||
423 | 71 | os.mkdir(name) | 86 | os.mkdir(name) |
424 | 72 | FilePath(name).child("file").setContent("0123456789") | 87 | FilePath(name).child("file").setContent("0123456789") |
425 | 73 | r = static.File(name) | 88 | r = static.File(name) |
426 | 89 | r.putChild('thing_to_put', PuttableResource()) | ||
427 | 74 | self.site = server.Site(r, timeout=None) | 90 | self.site = server.Site(r, timeout=None) |
428 | 75 | self.wrapper = WrappingFactory(self.site) | 91 | self.wrapper = WrappingFactory(self.site) |
429 | 76 | self.port = self._listen(self.wrapper) | 92 | self.port = self._listen(self.wrapper) |
430 | @@ -99,7 +115,6 @@ | |||
431 | 99 | 115 | ||
432 | 100 | def test_creation(self): | 116 | def test_creation(self): |
433 | 101 | query = BaseQuery("an action", "creds", "http://endpoint") | 117 | query = BaseQuery("an action", "creds", "http://endpoint") |
434 | 102 | self.assertEquals(query.factory, HTTPClientFactory) | ||
435 | 103 | self.assertEquals(query.action, "an action") | 118 | self.assertEquals(query.action, "an action") |
436 | 104 | self.assertEquals(query.creds, "creds") | 119 | self.assertEquals(query.creds, "creds") |
437 | 105 | self.assertEquals(query.endpoint, "http://endpoint") | 120 | self.assertEquals(query.endpoint, "http://endpoint") |
438 | @@ -142,16 +157,58 @@ | |||
439 | 142 | def test_get_response_headers_with_client(self): | 157 | def test_get_response_headers_with_client(self): |
440 | 143 | 158 | ||
441 | 144 | def check_results(results): | 159 | def check_results(results): |
442 | 160 | #self.assertEquals(sorted(results.keys()), [ | ||
443 | 161 | # "accept-ranges", "content-length", "content-type", "date", | ||
444 | 162 | # "last-modified", "server"]) | ||
445 | 163 | # XXX I think newclient exludes content-length from headers? | ||
446 | 164 | # Also the header names are capitalized ... do we need to worry | ||
447 | 165 | # about backwards compat? | ||
448 | 145 | self.assertEquals(sorted(results.keys()), [ | 166 | self.assertEquals(sorted(results.keys()), [ |
452 | 146 | "accept-ranges", "content-length", "content-type", "date", | 167 | "Accept-Ranges", "Content-Type", "Date", |
453 | 147 | "last-modified", "server"]) | 168 | "Last-Modified", "Server"]) |
454 | 148 | self.assertEquals(len(results.values()), 6) | 169 | self.assertEquals(len(results.values()), 5) |
455 | 149 | 170 | ||
456 | 150 | query = BaseQuery("an action", "creds", "http://endpoint") | 171 | query = BaseQuery("an action", "creds", "http://endpoint") |
457 | 151 | d = query.get_page(self._get_url("file")) | 172 | d = query.get_page(self._get_url("file")) |
458 | 152 | d.addCallback(query.get_response_headers) | 173 | d.addCallback(query.get_response_headers) |
459 | 153 | return d.addCallback(check_results) | 174 | return d.addCallback(check_results) |
460 | 154 | 175 | ||
461 | 176 | def test_errors(self): | ||
462 | 177 | query = BaseQuery("an action", "creds", "http://endpoint") | ||
463 | 178 | d = query.get_page(self._get_url("not_there")) | ||
464 | 179 | self.assertFailure(d, TwistedWebError) | ||
465 | 180 | return d | ||
466 | 181 | |||
467 | 182 | def test_custom_body_producer(self): | ||
468 | 183 | |||
469 | 184 | def check_producer_was_used(ignore): | ||
470 | 185 | self.assertEqual(producer.written, 'test data') | ||
471 | 186 | |||
472 | 187 | producer = StringBodyProducer('test data') | ||
473 | 188 | query = BaseQuery("an action", "creds", "http://endpoint", | ||
474 | 189 | body_producer=producer) | ||
475 | 190 | d = query.get_page(self._get_url("thing_to_put"), method='PUT') | ||
476 | 191 | return d.addCallback(check_producer_was_used) | ||
477 | 192 | |||
478 | 193 | def test_custom_receiver_factory(self): | ||
479 | 194 | |||
480 | 195 | class TestReceiverProtocol(StreamingBodyReceiver): | ||
481 | 196 | used = False | ||
482 | 197 | |||
483 | 198 | def __init__(self): | ||
484 | 199 | StreamingBodyReceiver.__init__(self) | ||
485 | 200 | TestReceiverProtocol.used = True | ||
486 | 201 | |||
487 | 202 | def check_used(ignore): | ||
488 | 203 | self.assert_(TestReceiverProtocol.used) | ||
489 | 204 | |||
490 | 205 | query = BaseQuery("an action", "creds", "http://endpoint", | ||
491 | 206 | receiver_factory=TestReceiverProtocol) | ||
492 | 207 | d = query.get_page(self._get_url("file")) | ||
493 | 208 | d.addCallback(self.assertEquals, "0123456789") | ||
494 | 209 | d.addCallback(check_used) | ||
495 | 210 | return d | ||
496 | 211 | |||
497 | 155 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test | 212 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test |
498 | 156 | # will fail; instead, let's create some certs in a temp directory and set | 213 | # will fail; instead, let's create some certs in a temp directory and set |
499 | 157 | # the DEFAULT_CERT_PATH to point there. | 214 | # the DEFAULT_CERT_PATH to point there. |
500 | @@ -167,8 +224,9 @@ | |||
501 | 167 | def __init__(self): | 224 | def __init__(self): |
502 | 168 | self.connects = [] | 225 | self.connects = [] |
503 | 169 | 226 | ||
506 | 170 | def connectSSL(self, host, port, client, factory): | 227 | def connectSSL(self, host, port, factory, contextFactory, timeout, |
507 | 171 | self.connects.append((host, port, client, factory)) | 228 | bindAddress): |
508 | 229 | self.connects.append((host, port, factory, contextFactory)) | ||
509 | 172 | 230 | ||
510 | 173 | certs = makeCertificate(O="Test Certificate", CN="something")[1] | 231 | certs = makeCertificate(O="Test Certificate", CN="something")[1] |
511 | 174 | self.patch(ssl, "_ca_certs", certs) | 232 | self.patch(ssl, "_ca_certs", certs) |
512 | @@ -176,9 +234,56 @@ | |||
513 | 176 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) | 234 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
514 | 177 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) | 235 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) |
515 | 178 | query.get_page("https://example.com/file") | 236 | query.get_page("https://example.com/file") |
517 | 179 | [(host, port, client, factory)] = fake_reactor.connects | 237 | [(host, port, factory, contextFactory)] = fake_reactor.connects |
518 | 180 | self.assertEqual("example.com", host) | 238 | self.assertEqual("example.com", host) |
519 | 181 | self.assertEqual(443, port) | 239 | self.assertEqual(443, port) |
523 | 182 | self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) | 240 | wrappedFactory = contextFactory._webContext |
524 | 183 | self.assertEqual("example.com", factory.host) | 241 | self.assertTrue(isinstance(wrappedFactory, ssl.VerifyingContextFactory)) |
525 | 184 | self.assertNotEqual([], factory.caCerts) | 242 | self.assertEqual("example.com", wrappedFactory.host) |
526 | 243 | self.assertNotEqual([], wrappedFactory.caCerts) | ||
527 | 244 | |||
528 | 245 | class StreamingBodyReceiverTestCase(TXAWSTestCase): | ||
529 | 246 | |||
530 | 247 | def test_readback_mode_on(self): | ||
531 | 248 | """ | ||
532 | 249 | Test that when readback mode is on inside connectionLost() data will | ||
533 | 250 | be read back from the start of the file we're streaming and results | ||
534 | 251 | passed to finished callback. | ||
535 | 252 | """ | ||
536 | 253 | |||
537 | 254 | receiver = StreamingBodyReceiver() | ||
538 | 255 | d = Deferred() | ||
539 | 256 | receiver.finished = d | ||
540 | 257 | receiver.content_length = 5 | ||
541 | 258 | fd = receiver._fd | ||
542 | 259 | receiver.dataReceived('hello') | ||
543 | 260 | why = Failure(ResponseDone('done')) | ||
544 | 261 | receiver.connectionLost(why) | ||
545 | 262 | self.assertEqual(d.result, 'hello') | ||
546 | 263 | self.assert_(fd.closed) | ||
547 | 264 | |||
548 | 265 | def test_readback_mode_off(self): | ||
549 | 266 | """ | ||
550 | 267 | Test that when readback mode is off connectionLost() will simply | ||
551 | 268 | callback finished with the fd. | ||
552 | 269 | """ | ||
553 | 270 | |||
554 | 271 | receiver = StreamingBodyReceiver(readback=False) | ||
555 | 272 | d = Deferred() | ||
556 | 273 | receiver.finished = d | ||
557 | 274 | receiver.content_length = 5 | ||
558 | 275 | fd = receiver._fd | ||
559 | 276 | receiver.dataReceived('hello') | ||
560 | 277 | why = Failure(ResponseDone('done')) | ||
561 | 278 | receiver.connectionLost(why) | ||
562 | 279 | self.assertIdentical(d.result, fd) | ||
563 | 280 | self.assertIdentical(receiver._fd, fd) | ||
564 | 281 | self.failIf(fd.closed) | ||
565 | 282 | |||
566 | 283 | def test_user_fd(self): | ||
567 | 284 | """ | ||
568 | 285 | Test that user's own file descriptor can be passed to init | ||
569 | 286 | """ | ||
570 | 287 | user_fd = StringIO() | ||
571 | 288 | receiver = StreamingBodyReceiver(user_fd) | ||
572 | 289 | self.assertIdentical(receiver._fd, user_fd) | ||
573 | 185 | 290 | ||
574 | === modified file 'txaws/client/tests/test_ssl.py' | |||
575 | --- txaws/client/tests/test_ssl.py 2012-01-26 22:54:44 +0000 | |||
576 | +++ txaws/client/tests/test_ssl.py 2012-03-15 05:30:25 +0000 | |||
577 | @@ -12,6 +12,10 @@ | |||
578 | 12 | from twisted.python.filepath import FilePath | 12 | from twisted.python.filepath import FilePath |
579 | 13 | from twisted.test.test_sslverify import makeCertificate | 13 | from twisted.test.test_sslverify import makeCertificate |
580 | 14 | from twisted.web import server, static | 14 | from twisted.web import server, static |
581 | 15 | try: | ||
582 | 16 | from twisted.web.client import ResponseFailed | ||
583 | 17 | except ImportError: | ||
584 | 18 | from twisted.web._newclient import ResponseFailed | ||
585 | 15 | 19 | ||
586 | 16 | from txaws import exception | 20 | from txaws import exception |
587 | 17 | from txaws.client import ssl | 21 | from txaws.client import ssl |
588 | @@ -32,6 +36,11 @@ | |||
589 | 32 | PUBSANKEY = sibpath("public_san.ssl") | 36 | PUBSANKEY = sibpath("public_san.ssl") |
590 | 33 | 37 | ||
591 | 34 | 38 | ||
592 | 39 | class WebDefaultOpenSSLContextFactory(DefaultOpenSSLContextFactory): | ||
593 | 40 | def getContext(self, hostname=None, port=None): | ||
594 | 41 | return DefaultOpenSSLContextFactory.getContext(self) | ||
595 | 42 | |||
596 | 43 | |||
597 | 35 | class BaseQuerySSLTestCase(TXAWSTestCase): | 44 | class BaseQuerySSLTestCase(TXAWSTestCase): |
598 | 36 | 45 | ||
599 | 37 | def setUp(self): | 46 | def setUp(self): |
600 | @@ -75,7 +84,7 @@ | |||
601 | 75 | The L{VerifyingContextFactory} properly allows to connect to the | 84 | The L{VerifyingContextFactory} properly allows to connect to the |
602 | 76 | endpoint if the certificates match. | 85 | endpoint if the certificates match. |
603 | 77 | """ | 86 | """ |
605 | 78 | context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) | 87 | context_factory = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
606 | 79 | self.port = reactor.listenSSL( | 88 | self.port = reactor.listenSSL( |
607 | 80 | 0, self.site, context_factory, interface="127.0.0.1") | 89 | 0, self.site, context_factory, interface="127.0.0.1") |
608 | 81 | self.portno = self.port.getHost().port | 90 | self.portno = self.port.getHost().port |
609 | @@ -90,7 +99,7 @@ | |||
610 | 90 | The L{VerifyingContextFactory} fails with a SSL error the certificates | 99 | The L{VerifyingContextFactory} fails with a SSL error the certificates |
611 | 91 | can't be checked. | 100 | can't be checked. |
612 | 92 | """ | 101 | """ |
614 | 93 | context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) | 102 | context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
615 | 94 | self.port = reactor.listenSSL( | 103 | self.port = reactor.listenSSL( |
616 | 95 | 0, self.site, context_factory, interface="127.0.0.1") | 104 | 0, self.site, context_factory, interface="127.0.0.1") |
617 | 96 | self.portno = self.port.getHost().port | 105 | self.portno = self.port.getHost().port |
618 | @@ -98,7 +107,14 @@ | |||
619 | 98 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) | 107 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
620 | 99 | query = BaseQuery("an action", "creds", endpoint) | 108 | query = BaseQuery("an action", "creds", endpoint) |
621 | 100 | d = query.get_page(self._get_url("file")) | 109 | d = query.get_page(self._get_url("file")) |
623 | 101 | return self.assertFailure(d, SSLError) | 110 | def fail(ignore): |
624 | 111 | self.fail('Expected SSLError') | ||
625 | 112 | def check_exception(why): | ||
626 | 113 | # XXX kind of a mess here ... need to unwrap the | ||
627 | 114 | # exception and check | ||
628 | 115 | root_exc = why.value[0][0].value | ||
629 | 116 | self.assert_(isinstance(root_exc, SSLError)) | ||
630 | 117 | return d.addCallbacks(fail, check_exception) | ||
631 | 102 | 118 | ||
632 | 103 | def test_ssl_verification_bypassed(self): | 119 | def test_ssl_verification_bypassed(self): |
633 | 104 | """ | 120 | """ |
634 | @@ -121,7 +137,7 @@ | |||
635 | 121 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the | 137 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the |
636 | 122 | certificate if it's available. | 138 | certificate if it's available. |
637 | 123 | """ | 139 | """ |
639 | 124 | context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) | 140 | context_factory = WebDefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
640 | 125 | self.port = reactor.listenSSL( | 141 | self.port = reactor.listenSSL( |
641 | 126 | 0, self.site, context_factory, interface="127.0.0.1") | 142 | 0, self.site, context_factory, interface="127.0.0.1") |
642 | 127 | self.portno = self.port.getHost().port | 143 | self.portno = self.port.getHost().port |
643 | 128 | 144 | ||
644 | === added file 'txaws/testing/producers.py' | |||
645 | --- txaws/testing/producers.py 1970-01-01 00:00:00 +0000 | |||
646 | +++ txaws/testing/producers.py 2012-03-15 05:30:25 +0000 | |||
647 | @@ -0,0 +1,23 @@ | |||
648 | 1 | from zope.interface import implements | ||
649 | 2 | |||
650 | 3 | from twisted.internet.defer import succeed | ||
651 | 4 | from twisted.web.iweb import IBodyProducer | ||
652 | 5 | |||
653 | 6 | class StringBodyProducer(object): | ||
654 | 7 | implements(IBodyProducer) | ||
655 | 8 | |||
656 | 9 | def __init__(self, data): | ||
657 | 10 | self.data = data | ||
658 | 11 | self.length = len(data) | ||
659 | 12 | self.written = None | ||
660 | 13 | |||
661 | 14 | def startProducing(self, consumer): | ||
662 | 15 | consumer.write(self.data) | ||
663 | 16 | self.written = self.data | ||
664 | 17 | return succeed(None) | ||
665 | 18 | |||
666 | 19 | def pauseProducing(self): | ||
667 | 20 | pass | ||
668 | 21 | |||
669 | 22 | def stopProducing(self): | ||
670 | 23 | pass |
On Thu, Feb 9, 2012 at 10:23 PM, Drew Smathers <email address hidden> wrote:
> Drew Smathers has proposed merging lp:~djfroofy/txaws/modernize-924459 into lp:txaws.
>
> Requested reviews:
> Duncan McGreggor (oubiwann)
> Related bugs:
> Bug #924459 in txAWS: "Modernize txaws.client to use twisted.web.client.Agent"
> https://bugs.launchpad.net/txaws/+bug/924459
>
> For more details, see:
> https://code.launchpad.net/~djfroofy/txaws/modernize-924459/+merge/92404
>
> Need to decide how much we care about compatibility, since: (1) this requires Twisted >= 11.1.0, as far as I know; and (2) some public members on BaseQuery no longer exist, because they are not applicable when using Agent.
I've taken a quick look at the latest code tonight, and first glance
is a happy one :-)
I've branched the code on my laptop so that I can review it on the
plane tomorrow, run the unit tests, etc.
I hope to have more for you soon.