Merge lp:~djfroofy/txaws/921421-completemultipart into lp:txaws
- 921421-completemultipart
- Merge into trunk
Proposed by
Drew Smathers
Status: | Merged | ||||
---|---|---|---|---|---|
Approved by: | Duncan McGreggor | ||||
Approved revision: | 146 | ||||
Merged at revision: | 148 | ||||
Proposed branch: | lp:~djfroofy/txaws/921421-completemultipart | ||||
Merge into: | lp:txaws | ||||
Diff against target: |
1437 lines (+784/-82) 9 files modified
txaws/client/_producers.py (+122/-0) txaws/client/base.py (+159/-14) txaws/client/tests/test_base.py (+53/-11) txaws/client/tests/test_ssl.py (+20/-4) txaws/s3/client.py (+135/-26) txaws/s3/model.py (+62/-0) txaws/s3/tests/test_client.py (+194/-27) txaws/testing/payload.py (+16/-0) txaws/testing/producers.py (+23/-0) |
||||
To merge this branch: | bzr merge lp:~djfroofy/txaws/921421-completemultipart | ||||
Related bugs: |
|
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Duncan McGreggor | Approve | ||
Review via email: mp+92593@code.launchpad.net |
Commit message
Description of the change
Some details need to be worked out for full support, but I think this is a decent happy path to start with. Depends on 921419-uploadpart which depends on ... hehe.
To post a comment you must log in.
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === added file 'txaws/client/_producers.py' |
2 | --- txaws/client/_producers.py 1970-01-01 00:00:00 +0000 |
3 | +++ txaws/client/_producers.py 2012-02-10 22:39:30 +0000 |
4 | @@ -0,0 +1,122 @@ |
5 | +import os |
6 | + |
7 | +from zope.interface import implements |
8 | + |
9 | +from twisted.internet import defer, task |
10 | +from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer |
11 | + |
12 | + |
13 | +# Code below for FileBodyProducer cut-and-paste from twisted source. |
14 | +# Currently this is not released so here temporarily for forward compat. |
15 | + |
16 | + |
17 | +class FileBodyProducer(object): |
18 | + """ |
19 | + L{FileBodyProducer} produces bytes from an input file object incrementally |
20 | + and writes them to a consumer. |
21 | + |
22 | + Since file-like objects cannot be read from in an event-driven manner, |
23 | + L{FileBodyProducer} uses a L{Cooperator} instance to schedule reads from |
24 | + the file. This process is also paused and resumed based on notifications |
25 | + from the L{IConsumer} provider being written to. |
26 | + |
27 | + The file is closed after it has been read, or if the producer is stopped |
28 | + early. |
29 | + |
30 | + @ivar _inputFile: Any file-like object, bytes read from which will be |
31 | + written to a consumer. |
32 | + |
33 | + @ivar _cooperate: A method like L{Cooperator.cooperate} which is used to |
34 | + schedule all reads. |
35 | + |
36 | + @ivar _readSize: The number of bytes to read from C{_inputFile} at a time. |
37 | + """ |
38 | + implements(IBodyProducer) |
39 | + |
40 | + # Python 2.4 doesn't have these symbolic constants |
41 | + _SEEK_SET = getattr(os, 'SEEK_SET', 0) |
42 | + _SEEK_END = getattr(os, 'SEEK_END', 2) |
43 | + |
44 | + def __init__(self, inputFile, cooperator=task, readSize=2 ** 16): |
45 | + self._inputFile = inputFile |
46 | + self._cooperate = cooperator.cooperate |
47 | + self._readSize = readSize |
48 | + self.length = self._determineLength(inputFile) |
49 | + |
50 | + |
51 | + def _determineLength(self, fObj): |
52 | + """ |
53 | + Determine how many bytes can be read out of C{fObj} (assuming it is not |
54 | + modified from this point on). If the determination cannot be made, |
55 | + return C{UNKNOWN_LENGTH}. |
56 | + """ |
57 | + try: |
58 | + seek = fObj.seek |
59 | + tell = fObj.tell |
60 | + except AttributeError: |
61 | + return UNKNOWN_LENGTH |
62 | + originalPosition = tell() |
63 | + seek(0, self._SEEK_END) |
64 | + end = tell() |
65 | + seek(originalPosition, self._SEEK_SET) |
66 | + return end - originalPosition |
67 | + |
68 | + |
69 | + def stopProducing(self): |
70 | + """ |
71 | + Permanently stop writing bytes from the file to the consumer by |
72 | + stopping the underlying L{CooperativeTask}. |
73 | + """ |
74 | + self._inputFile.close() |
75 | + self._task.stop() |
76 | + |
77 | + |
78 | + def startProducing(self, consumer): |
79 | + """ |
80 | + Start a cooperative task which will read bytes from the input file and |
81 | + write them to C{consumer}. Return a L{Deferred} which fires after all |
82 | + bytes have been written. |
83 | + |
84 | + @param consumer: Any L{IConsumer} provider |
85 | + """ |
86 | + self._task = self._cooperate(self._writeloop(consumer)) |
87 | + d = self._task.whenDone() |
88 | + def maybeStopped(reason): |
89 | + # IBodyProducer.startProducing's Deferred isn't supposed to fire if |
90 | + # stopProducing is called. |
91 | + reason.trap(task.TaskStopped) |
92 | + return defer.Deferred() |
93 | + d.addCallbacks(lambda ignored: None, maybeStopped) |
94 | + return d |
95 | + |
96 | + |
97 | + def _writeloop(self, consumer): |
98 | + """ |
99 | + Return an iterator which reads one chunk of bytes from the input file |
100 | + and writes them to the consumer for each time it is iterated. |
101 | + """ |
102 | + while True: |
103 | + bytes = self._inputFile.read(self._readSize) |
104 | + if not bytes: |
105 | + self._inputFile.close() |
106 | + break |
107 | + consumer.write(bytes) |
108 | + yield None |
109 | + |
110 | + |
111 | + def pauseProducing(self): |
112 | + """ |
113 | + Temporarily suspend copying bytes from the input file to the consumer |
114 | + by pausing the L{CooperativeTask} which drives that activity. |
115 | + """ |
116 | + self._task.pause() |
117 | + |
118 | + |
119 | + def resumeProducing(self): |
120 | + """ |
121 | + Undo the effects of a previous C{pauseProducing} and resume copying |
122 | + bytes to the consumer by resuming the L{CooperativeTask} which drives |
123 | + the write activity. |
124 | + """ |
125 | + self._task.resume() |
126 | + |
127 | |
128 | === modified file 'txaws/client/base.py' |
129 | --- txaws/client/base.py 2011-11-29 08:17:54 +0000 |
130 | +++ txaws/client/base.py 2012-02-10 22:39:30 +0000 |
131 | @@ -3,10 +3,25 @@ |
132 | except ImportError: |
133 | from xml.parsers.expat import ExpatError as ParseError |
134 | |
135 | +import warnings |
136 | +from StringIO import StringIO |
137 | + |
138 | from twisted.internet.ssl import ClientContextFactory |
139 | +from twisted.internet.protocol import Protocol |
140 | +from twisted.internet.defer import Deferred, succeed |
141 | +from twisted.python import failure |
142 | from twisted.web import http |
143 | +from twisted.web.iweb import UNKNOWN_LENGTH |
144 | from twisted.web.client import HTTPClientFactory |
145 | +from twisted.web.client import Agent |
146 | +from twisted.web.client import ResponseDone |
147 | +from twisted.web.http import NO_CONTENT |
148 | +from twisted.web.http_headers import Headers |
149 | from twisted.web.error import Error as TwistedWebError |
150 | +try: |
151 | + from twisted.web.client import FileBodyProducer |
152 | +except ImportError: |
153 | + from txaws.client._producers import FileBodyProducer |
154 | |
155 | from txaws.util import parse |
156 | from txaws.credentials import AWSCredentials |
157 | @@ -59,9 +74,10 @@ |
158 | @param query_factory: The class or function that produces a query |
159 | object for making requests to the EC2 service. |
160 | @param parser: A parser object for parsing responses from the EC2 service. |
161 | + @param receiver_factory: Factory for receiving responses from EC2 service. |
162 | """ |
163 | def __init__(self, creds=None, endpoint=None, query_factory=None, |
164 | - parser=None): |
165 | + parser=None, receiver_factory=None): |
166 | if creds is None: |
167 | creds = AWSCredentials() |
168 | if endpoint is None: |
169 | @@ -69,22 +85,109 @@ |
170 | self.creds = creds |
171 | self.endpoint = endpoint |
172 | self.query_factory = query_factory |
173 | + self.receiver_factory = receiver_factory |
174 | self.parser = parser |
175 | |
176 | +class StreamingError(Exception): |
177 | + """ |
178 | + Raised if more data or less data is received than expected. |
179 | + """ |
180 | + |
181 | + |
182 | +class StringIOBodyReceiver(Protocol): |
183 | + """ |
184 | + Simple StringIO-based HTTP response body receiver. |
185 | + |
186 | + TODO: perhaps there should be an interface specifying why |
187 | + finished (Deferred) and content_length are necessary and |
188 | + how to use them; e.g. callback/errback finished on completion. |
189 | + """ |
190 | + finished = None |
191 | + content_length = None |
192 | + |
193 | + def __init__(self): |
194 | + self._buffer = StringIO() |
195 | + self._received = 0 |
196 | + |
197 | + def dataReceived(self, bytes): |
198 | + streaming = self.content_length is UNKNOWN_LENGTH |
199 | + if not streaming and (self._received > self.content_length): |
200 | + self.transport.loseConnection() |
201 | + raise StreamingError( |
202 | + "Buffer overflow - received more data than " |
203 | + "Content-Length dictated: %d" % self.content_length) |
204 | + # TODO should be some limit on how much we receive |
205 | + self._buffer.write(bytes) |
206 | + self._received += len(bytes) |
207 | + |
208 | + def connectionLost(self, reason): |
209 | + reason.trap(ResponseDone) |
210 | + d = self.finished |
211 | + self.finished = None |
212 | + streaming = self.content_length is UNKNOWN_LENGTH |
213 | + if streaming or (self._received == self.content_length): |
214 | + d.callback(self._buffer.getvalue()) |
215 | + else: |
216 | + f = failure.Failure(StreamingError("Connection lost before " |
217 | + "receiving all data")) |
218 | + d.errback(f) |
219 | + |
220 | + |
221 | +class WebClientContextFactory(ClientContextFactory): |
222 | + |
223 | + def getContext(self, hostname, port): |
224 | + return ClientContextFactory.getContext(self) |
225 | + |
226 | + |
227 | +class WebVerifyingContextFactory(VerifyingContextFactory): |
228 | + |
229 | + def getContext(self, hostname, port): |
230 | + return VerifyingContextFactory.getContext(self) |
231 | + |
232 | + |
233 | +class FakeClient(object): |
234 | + """ |
235 | + XXX |
236 | + A fake client object for some degree of backwards compatibility for |
237 | + code using the client attribute on BaseQuery to check url, status |
238 | + etc. |
239 | + """ |
240 | + url = None |
241 | + status = None |
242 | |
243 | class BaseQuery(object): |
244 | |
245 | - def __init__(self, action=None, creds=None, endpoint=None, reactor=None): |
246 | + def __init__(self, action=None, creds=None, endpoint=None, reactor=None, |
247 | + body_producer=None, receiver_factory=None): |
248 | if not action: |
249 | raise TypeError("The query requires an action parameter.") |
250 | - self.factory = HTTPClientFactory |
251 | self.action = action |
252 | self.creds = creds |
253 | self.endpoint = endpoint |
254 | if reactor is None: |
255 | from twisted.internet import reactor |
256 | self.reactor = reactor |
257 | - self.client = None |
258 | + self._client = None |
259 | + self.request_headers = None |
260 | + self.response_headers = None |
261 | + self.body_producer = body_producer |
262 | + self.receiver_factory = receiver_factory or StringIOBodyReceiver |
263 | + |
264 | + @property |
265 | + def client(self): |
266 | + if self._client is None: |
267 | + self._client_deprecation_warning() |
268 | + self._client = FakeClient() |
269 | + return self._client |
270 | + |
271 | + @client.setter |
272 | + def client(self, value): |
273 | + self._client_deprecation_warning() |
274 | + self._client = value |
275 | + |
276 | + def _client_deprecation_warning(self): |
277 | + warnings.warn('The client attribute on BaseQuery is deprecated and' |
278 | + ' will go away in future release.') |
279 | |
280 | def get_page(self, url, *args, **kwds): |
281 | """ |
282 | @@ -95,16 +198,39 @@ |
283 | """ |
284 | contextFactory = None |
285 | scheme, host, port, path = parse(url) |
286 | - self.client = self.factory(url, *args, **kwds) |
287 | + data = kwds.get('postdata', None) |
288 | + self._method = method = kwds.get('method', 'GET') |
289 | + self.request_headers = self._headers(kwds.get('headers', {})) |
290 | + if (self.body_producer is None) and (data is not None): |
291 | + self.body_producer = FileBodyProducer(StringIO(data)) |
292 | if scheme == "https": |
293 | if self.endpoint.ssl_hostname_verification: |
294 | - contextFactory = VerifyingContextFactory(host) |
295 | + contextFactory = WebVerifyingContextFactory(host) |
296 | else: |
297 | - contextFactory = ClientContextFactory() |
298 | - self.reactor.connectSSL(host, port, self.client, contextFactory) |
299 | + contextFactory = WebClientContextFactory() |
300 | + agent = Agent(self.reactor, contextFactory) |
301 | + self.client.url = url |
302 | + d = agent.request(method, url, self.request_headers, |
303 | + self.body_producer) |
304 | else: |
305 | - self.reactor.connectTCP(host, port, self.client) |
306 | - return self.client.deferred |
307 | + agent = Agent(self.reactor) |
308 | + d = agent.request(method, url, self.request_headers, |
309 | + self.body_producer) |
310 | + d.addCallback(self._handle_response) |
311 | + return d |
312 | + |
313 | + def _headers(self, headers_dict): |
314 | + """ |
315 | + Convert dictionary of headers into twisted.web.client.Headers object. |
316 | + """ |
317 | + return Headers(dict((k,[v]) for (k,v) in headers_dict.items())) |
318 | + |
319 | + def _unpack_headers(self, headers): |
320 | + """ |
321 | + Unpack twisted.web.client.Headers object to dict. This is to provide |
322 | + backwards compatibility. |
323 | + """ |
324 | + return dict((k,v[0]) for (k,v) in headers.getAllRawHeaders()) |
325 | |
326 | def get_request_headers(self, *args, **kwds): |
327 | """ |
328 | @@ -114,8 +240,26 @@ |
329 | The AWS S3 API depends upon setting headers. This method is provided as |
330 | a convenience for debugging issues with the S3 communications. |
331 | """ |
332 | - if self.client: |
333 | - return self.client.headers |
334 | + if self.request_headers: |
335 | + return self._unpack_headers(self.request_headers) |
336 | + |
337 | + def _handle_response(self, response): |
338 | + """ |
339 | + Handle the HTTP response by memoing the headers and then delivering |
340 | + bytes. |
341 | + """ |
342 | + self.client.status = response.code |
343 | + self.response_headers = headers = response.headers |
344 | + # XXX This workaround (which needs to be improved at that) for possible |
345 | + # bug in Twisted with new client: |
346 | + # http://twistedmatrix.com/trac/ticket/5476 |
347 | + if self._method.upper() == 'HEAD' or response.code == NO_CONTENT: |
348 | + return succeed('') |
349 | + receiver = self.receiver_factory() |
350 | + receiver.finished = d = Deferred() |
351 | + receiver.content_length = response.length |
352 | + response.deliverBody(receiver) |
353 | + return d |
354 | |
355 | def get_response_headers(self, *args, **kwargs): |
356 | """ |
357 | @@ -125,5 +269,6 @@ |
358 | The AWS S3 API depends upon setting headers. This method is used by the |
359 | head_object API call for getting a S3 object's metadata. |
360 | """ |
361 | - if self.client: |
362 | - return self.client.response_headers |
363 | + if self.response_headers: |
364 | + return self._unpack_headers(self.response_headers) |
365 | + |
366 | |
367 | === modified file 'txaws/client/tests/test_base.py' |
368 | --- txaws/client/tests/test_base.py 2012-01-26 18:43:48 +0000 |
369 | +++ txaws/client/tests/test_base.py 2012-02-10 22:39:30 +0000 |
370 | @@ -1,6 +1,9 @@ |
371 | import os |
372 | |
373 | +from zope.interface import implements |
374 | + |
375 | from twisted.internet import reactor |
376 | +from twisted.internet.defer import succeed |
377 | from twisted.internet.error import ConnectionRefusedError |
378 | from twisted.protocols.policies import WrappingFactory |
379 | from twisted.python import log |
380 | @@ -8,14 +11,16 @@ |
381 | from twisted.python.failure import Failure |
382 | from twisted.test.test_sslverify import makeCertificate |
383 | from twisted.web import server, static |
384 | +from twisted.web.iweb import IBodyProducer |
385 | from twisted.web.client import HTTPClientFactory |
386 | from twisted.web.error import Error as TwistedWebError |
387 | |
388 | from txaws.client import ssl |
389 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper |
390 | +from txaws.client.base import StringIOBodyReceiver |
391 | from txaws.service import AWSServiceEndpoint |
392 | from txaws.testing.base import TXAWSTestCase |
393 | - |
394 | +from txaws.testing.producers import StringBodyProducer |
395 | |
396 | class ErrorWrapperTestCase(TXAWSTestCase): |
397 | |
398 | @@ -99,7 +104,6 @@ |
399 | |
400 | def test_creation(self): |
401 | query = BaseQuery("an action", "creds", "http://endpoint") |
402 | - self.assertEquals(query.factory, HTTPClientFactory) |
403 | self.assertEquals(query.action, "an action") |
404 | self.assertEquals(query.creds, "creds") |
405 | self.assertEquals(query.endpoint, "http://endpoint") |
406 | @@ -142,16 +146,52 @@ |
407 | def test_get_response_headers_with_client(self): |
408 | |
409 | def check_results(results): |
410 | + #self.assertEquals(sorted(results.keys()), [ |
411 | + # "accept-ranges", "content-length", "content-type", "date", |
412 | + # "last-modified", "server"]) |
413 | + # XXX I think newclient excludes content-length from headers? |
414 | + # Also the header names are capitalized ... do we need to worry |
415 | + # about backwards compat? |
416 | self.assertEquals(sorted(results.keys()), [ |
417 | - "accept-ranges", "content-length", "content-type", "date", |
418 | - "last-modified", "server"]) |
419 | - self.assertEquals(len(results.values()), 6) |
420 | + "Accept-Ranges", "Content-Type", "Date", |
421 | + "Last-Modified", "Server"]) |
422 | + self.assertEquals(len(results.values()), 5) |
423 | |
424 | query = BaseQuery("an action", "creds", "http://endpoint") |
425 | d = query.get_page(self._get_url("file")) |
426 | d.addCallback(query.get_response_headers) |
427 | return d.addCallback(check_results) |
428 | |
429 | + def test_custom_body_producer(self): |
430 | + |
431 | + def check_producer_was_used(ignore): |
432 | + self.assertEqual(producer.written, 'test data') |
433 | + |
434 | + producer = StringBodyProducer('test data') |
435 | + query = BaseQuery("an action", "creds", "http://endpoint", |
436 | + body_producer=producer) |
437 | + d = query.get_page(self._get_url("file"), method='PUT') |
438 | + return d.addCallback(check_producer_was_used) |
439 | + |
440 | + def test_custom_receiver_factory(self): |
441 | + |
442 | + class TestReceiverProtocol(StringIOBodyReceiver): |
443 | + used = False |
444 | + |
445 | + def __init__(self): |
446 | + StringIOBodyReceiver.__init__(self) |
447 | + TestReceiverProtocol.used = True |
448 | + |
449 | + def check_used(ignore): |
450 | + self.assert_(TestReceiverProtocol.used) |
451 | + |
452 | + query = BaseQuery("an action", "creds", "http://endpoint", |
453 | + receiver_factory=TestReceiverProtocol) |
454 | + d = query.get_page(self._get_url("file")) |
455 | + d.addCallback(self.assertEquals, "0123456789") |
456 | + d.addCallback(check_used) |
457 | + return d |
458 | + |
459 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test |
460 | # will fail; instead, let's create some certs in a temp directory and set |
461 | # the DEFAULT_CERT_PATH to point there. |
462 | @@ -167,8 +207,9 @@ |
463 | def __init__(self): |
464 | self.connects = [] |
465 | |
466 | - def connectSSL(self, host, port, client, factory): |
467 | - self.connects.append((host, port, client, factory)) |
468 | + def connectSSL(self, host, port, factory, contextFactory, timeout, |
469 | + bindAddress): |
470 | + self.connects.append((host, port, factory, contextFactory)) |
471 | |
472 | certs = makeCertificate(O="Test Certificate", CN="something")[1] |
473 | self.patch(ssl, "_ca_certs", certs) |
474 | @@ -176,9 +217,10 @@ |
475 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
476 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) |
477 | query.get_page("https://example.com/file") |
478 | - [(host, port, client, factory)] = fake_reactor.connects |
479 | + [(host, port, factory, contextFactory)] = fake_reactor.connects |
480 | self.assertEqual("example.com", host) |
481 | self.assertEqual(443, port) |
482 | - self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) |
483 | - self.assertEqual("example.com", factory.host) |
484 | - self.assertNotEqual([], factory.caCerts) |
485 | + wrappedFactory = contextFactory._webContext |
486 | + self.assertTrue(isinstance(wrappedFactory, ssl.VerifyingContextFactory)) |
487 | + self.assertEqual("example.com", wrappedFactory.host) |
488 | + self.assertNotEqual([], wrappedFactory.caCerts) |
489 | |
490 | === modified file 'txaws/client/tests/test_ssl.py' |
491 | --- txaws/client/tests/test_ssl.py 2012-01-26 22:54:44 +0000 |
492 | +++ txaws/client/tests/test_ssl.py 2012-02-10 22:39:30 +0000 |
493 | @@ -12,6 +12,10 @@ |
494 | from twisted.python.filepath import FilePath |
495 | from twisted.test.test_sslverify import makeCertificate |
496 | from twisted.web import server, static |
497 | +try: |
498 | + from twisted.web.client import ResponseFailed |
499 | +except ImportError: |
500 | + from twisted.web._newclient import ResponseFailed |
501 | |
502 | from txaws import exception |
503 | from txaws.client import ssl |
504 | @@ -32,6 +36,11 @@ |
505 | PUBSANKEY = sibpath("public_san.ssl") |
506 | |
507 | |
508 | +class WebDefaultOpenSSLContextFactory(DefaultOpenSSLContextFactory): |
509 | + def getContext(self, hostname=None, port=None): |
510 | + return DefaultOpenSSLContextFactory.getContext(self) |
511 | + |
512 | + |
513 | class BaseQuerySSLTestCase(TXAWSTestCase): |
514 | |
515 | def setUp(self): |
516 | @@ -75,7 +84,7 @@ |
517 | The L{VerifyingContextFactory} properly allows to connect to the |
518 | endpoint if the certificates match. |
519 | """ |
520 | - context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
521 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
522 | self.port = reactor.listenSSL( |
523 | 0, self.site, context_factory, interface="127.0.0.1") |
524 | self.portno = self.port.getHost().port |
525 | @@ -90,7 +99,7 @@ |
526 | The L{VerifyingContextFactory} fails with a SSL error the certificates |
527 | can't be checked. |
528 | """ |
529 | - context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
530 | + context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
531 | self.port = reactor.listenSSL( |
532 | 0, self.site, context_factory, interface="127.0.0.1") |
533 | self.portno = self.port.getHost().port |
534 | @@ -98,7 +107,14 @@ |
535 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
536 | query = BaseQuery("an action", "creds", endpoint) |
537 | d = query.get_page(self._get_url("file")) |
538 | - return self.assertFailure(d, SSLError) |
539 | + def fail(ignore): |
540 | + self.fail('Expected SSLError') |
541 | + def check_exception(why): |
542 | + # XXX kind of a mess here ... need to unwrap the |
543 | + # exception and check |
544 | + root_exc = why.value[0][0].value |
545 | + self.assert_(isinstance(root_exc, SSLError)) |
546 | + return d.addCallbacks(fail, check_exception) |
547 | |
548 | def test_ssl_verification_bypassed(self): |
549 | """ |
550 | @@ -121,7 +137,7 @@ |
551 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the |
552 | certificate if it's available. |
553 | """ |
554 | - context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
555 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
556 | self.port = reactor.listenSSL( |
557 | 0, self.site, context_factory, interface="127.0.0.1") |
558 | self.portno = self.port.getHost().port |
559 | |
560 | === modified file 'txaws/s3/client.py' |
561 | --- txaws/s3/client.py 2012-01-28 00:39:00 +0000 |
562 | +++ txaws/s3/client.py 2012-02-10 22:39:30 +0000 |
563 | @@ -23,7 +23,8 @@ |
564 | from txaws.s3.model import ( |
565 | Bucket, BucketItem, BucketListing, ItemOwner, LifecycleConfiguration, |
566 | LifecycleConfigurationRule, NotificationConfiguration, RequestPayment, |
567 | - VersioningConfiguration, WebsiteConfiguration) |
568 | + VersioningConfiguration, WebsiteConfiguration, MultipartInitiationResponse, |
569 | + MultipartCompletionResponse) |
570 | from txaws.s3.exception import S3Error |
571 | from txaws.service import AWSServiceEndpoint, S3_ENDPOINT |
572 | from txaws.util import XML, calculate_md5 |
573 | @@ -74,10 +75,12 @@ |
574 | class S3Client(BaseClient): |
575 | """A client for S3.""" |
576 | |
577 | - def __init__(self, creds=None, endpoint=None, query_factory=None): |
578 | + def __init__(self, creds=None, endpoint=None, query_factory=None, |
579 | + receiver_factory=None): |
580 | if query_factory is None: |
581 | query_factory = Query |
582 | - super(S3Client, self).__init__(creds, endpoint, query_factory) |
583 | + super(S3Client, self).__init__(creds, endpoint, query_factory, |
584 | + receiver_factory=receiver_factory) |
585 | |
586 | def list_buckets(self): |
587 | """ |
588 | @@ -87,7 +90,8 @@ |
589 | the request. |
590 | """ |
591 | query = self.query_factory( |
592 | - action="GET", creds=self.creds, endpoint=self.endpoint) |
593 | + action="GET", creds=self.creds, endpoint=self.endpoint, |
594 | + receiver_factory=self.receiver_factory) |
595 | d = query.submit() |
596 | return d.addCallback(self._parse_list_buckets) |
597 | |
598 | @@ -131,7 +135,7 @@ |
599 | """ |
600 | query = self.query_factory( |
601 | action="GET", creds=self.creds, endpoint=self.endpoint, |
602 | - bucket=bucket) |
603 | + bucket=bucket, receiver_factory=self.receiver_factory) |
604 | d = query.submit() |
605 | return d.addCallback(self._parse_get_bucket) |
606 | |
607 | @@ -174,7 +178,8 @@ |
608 | """ |
609 | query = self.query_factory(action="GET", creds=self.creds, |
610 | endpoint=self.endpoint, bucket=bucket, |
611 | - object_name="?location") |
612 | + object_name="?location", |
613 | + receiver_factory=self.receiver_factory) |
614 | d = query.submit() |
615 | return d.addCallback(self._parse_bucket_location) |
616 | |
617 | @@ -193,7 +198,8 @@ |
618 | """ |
619 | query = self.query_factory( |
620 | action='GET', creds=self.creds, endpoint=self.endpoint, |
621 | - bucket=bucket, object_name='?lifecycle') |
622 | + bucket=bucket, object_name='?lifecycle', |
623 | + receiver_factory=self.receiver_factory) |
624 | return query.submit().addCallback(self._parse_lifecycle_config) |
625 | |
626 | def _parse_lifecycle_config(self, xml_bytes): |
627 | @@ -221,7 +227,8 @@ |
628 | """ |
629 | query = self.query_factory( |
630 | action='GET', creds=self.creds, endpoint=self.endpoint, |
631 | - bucket=bucket, object_name='?website') |
632 | + bucket=bucket, object_name='?website', |
633 | + receiver_factory=self.receiver_factory) |
634 | return query.submit().addCallback(self._parse_website_config) |
635 | |
636 | def _parse_website_config(self, xml_bytes): |
637 | @@ -242,7 +249,8 @@ |
638 | """ |
639 | query = self.query_factory( |
640 | action='GET', creds=self.creds, endpoint=self.endpoint, |
641 | - bucket=bucket, object_name='?notification') |
642 | + bucket=bucket, object_name='?notification', |
643 | + receiver_factory=self.receiver_factory) |
644 | return query.submit().addCallback(self._parse_notification_config) |
645 | |
646 | def _parse_notification_config(self, xml_bytes): |
647 | @@ -262,7 +270,8 @@ |
648 | """ |
649 | query = self.query_factory( |
650 | action='GET', creds=self.creds, endpoint=self.endpoint, |
651 | - bucket=bucket, object_name='?versioning') |
652 | + bucket=bucket, object_name='?versioning', |
653 | + receiver_factory=self.receiver_factory) |
654 | return query.submit().addCallback(self._parse_versioning_config) |
655 | |
656 | def _parse_versioning_config(self, xml_bytes): |
657 | @@ -279,7 +288,8 @@ |
658 | """ |
659 | query = self.query_factory( |
660 | action='GET', creds=self.creds, endpoint=self.endpoint, |
661 | - bucket=bucket, object_name='?acl') |
662 | + bucket=bucket, object_name='?acl', |
663 | + receiver_factory=self.receiver_factory) |
664 | return query.submit().addCallback(self._parse_acl) |
665 | |
666 | def put_bucket_acl(self, bucket, access_control_policy): |
667 | @@ -289,7 +299,8 @@ |
668 | data = access_control_policy.to_xml() |
669 | query = self.query_factory( |
670 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
671 | - bucket=bucket, object_name='?acl', data=data) |
672 | + bucket=bucket, object_name='?acl', data=data, |
673 | + receiver_factory=self.receiver_factory) |
674 | return query.submit().addCallback(self._parse_acl) |
675 | |
676 | def _parse_acl(self, xml_bytes): |
677 | @@ -299,8 +310,8 @@ |
678 | """ |
679 | return AccessControlPolicy.from_xml(xml_bytes) |
680 | |
681 | - def put_object(self, bucket, object_name, data, content_type=None, |
682 | - metadata={}, amz_headers={}): |
683 | + def put_object(self, bucket, object_name, data=None, content_type=None, |
684 | + metadata={}, amz_headers={}, body_producer=None): |
685 | """ |
686 | Put an object in a bucket. |
687 | |
688 | @@ -318,7 +329,8 @@ |
689 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
690 | bucket=bucket, object_name=object_name, data=data, |
691 | content_type=content_type, metadata=metadata, |
692 | - amz_headers=amz_headers) |
693 | + amz_headers=amz_headers, body_producer=body_producer, |
694 | + receiver_factory=self.receiver_factory) |
695 | return query.submit() |
696 | |
697 | def copy_object(self, source_bucket, source_object_name, dest_bucket=None, |
698 | @@ -344,7 +356,8 @@ |
699 | query = self.query_factory( |
700 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
701 | bucket=dest_bucket, object_name=dest_object_name, |
702 | - metadata=metadata, amz_headers=amz_headers) |
703 | + metadata=metadata, amz_headers=amz_headers, |
704 | + receiver_factory=self.receiver_factory) |
705 | return query.submit() |
706 | |
707 | def get_object(self, bucket, object_name): |
708 | @@ -353,7 +366,8 @@ |
709 | """ |
710 | query = self.query_factory( |
711 | action="GET", creds=self.creds, endpoint=self.endpoint, |
712 | - bucket=bucket, object_name=object_name) |
713 | + bucket=bucket, object_name=object_name, |
714 | + receiver_factory=self.receiver_factory) |
715 | return query.submit() |
716 | |
717 | def head_object(self, bucket, object_name): |
718 | @@ -384,7 +398,8 @@ |
719 | data = access_control_policy.to_xml() |
720 | query = self.query_factory( |
721 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
722 | - bucket=bucket, object_name='%s?acl' % object_name, data=data) |
723 | + bucket=bucket, object_name='%s?acl' % object_name, data=data, |
724 | + receiver_factory=self.receiver_factory) |
725 | return query.submit().addCallback(self._parse_acl) |
726 | |
727 | def get_object_acl(self, bucket, object_name): |
728 | @@ -393,7 +408,8 @@ |
729 | """ |
730 | query = self.query_factory( |
731 | action='GET', creds=self.creds, endpoint=self.endpoint, |
732 | - bucket=bucket, object_name='%s?acl' % object_name) |
733 | + bucket=bucket, object_name='%s?acl' % object_name, |
734 | + receiver_factory=self.receiver_factory) |
735 | return query.submit().addCallback(self._parse_acl) |
736 | |
737 | def put_request_payment(self, bucket, payer): |
738 | @@ -407,7 +423,8 @@ |
739 | data = RequestPayment(payer).to_xml() |
740 | query = self.query_factory( |
741 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
742 | - bucket=bucket, object_name="?requestPayment", data=data) |
743 | + bucket=bucket, object_name="?requestPayment", data=data, |
744 | + receiver_factory=self.receiver_factory) |
745 | return query.submit() |
746 | |
747 | def get_request_payment(self, bucket): |
748 | @@ -419,7 +436,8 @@ |
749 | """ |
750 | query = self.query_factory( |
751 | action="GET", creds=self.creds, endpoint=self.endpoint, |
752 | - bucket=bucket, object_name="?requestPayment") |
753 | + bucket=bucket, object_name="?requestPayment", |
754 | + receiver_factory=self.receiver_factory) |
755 | return query.submit().addCallback(self._parse_get_request_payment) |
756 | |
757 | def _parse_get_request_payment(self, xml_bytes): |
758 | @@ -429,17 +447,102 @@ |
759 | """ |
760 | return RequestPayment.from_xml(xml_bytes).payer |
761 | |
762 | + def init_multipart_upload(self, bucket, object_name, content_type=None, |
763 | + metadata={}): |
764 | + """ |
765 | + Initiate a multipart upload to a bucket. |
766 | + |
767 | + @param bucket: The name of the bucket |
768 | + @param object_name: The object name |
769 | + @param content_type: The Content-Type for the object |
770 | + @param metadata: C{dict} containing additional metadata |
771 | + @return: C{str} upload_id |
772 | + """ |
773 | + objectname_plus = '%s?uploads' % object_name |
774 | + query = self.query_factory( |
775 | + action="POST", creds=self.creds, endpoint=self.endpoint, |
776 | + bucket=bucket, object_name=objectname_plus, data='', |
777 | + content_type=content_type, metadata=metadata) |
778 | + d = query.submit() |
779 | + return d.addCallback(MultipartInitiationResponse.from_xml) |
780 | + |
781 | + def upload_part(self, bucket, object_name, upload_id, part_number, |
782 | + data=None, content_type=None, metadata={}, |
783 | + body_producer=None): |
784 | + """ |
785 | + Upload a part of data corresponding to a multipart upload. |
786 | + |
787 | + @param bucket: The bucket name |
788 | + @param object_name: The object name |
789 | + @param upload_id: The multipart upload id |
790 | + @param part_number: The part number |
791 | + @param data: Data (optional, requires body_producer if not specified) |
792 | + @param content_type: The Content-Type |
793 | + @param metadata: Additional metadata |
794 | + @param body_producer: an C{IBodyProducer} (optional, requires data if |
795 | + not specified) |
796 | + @return: the C{Deferred} from underlying query.submit() call |
797 | + """ |
798 | + parms = 'partNumber=%s&uploadId=%s' % (str(part_number), upload_id) |
799 | + objectname_plus = '%s?%s' % (object_name, parms) |
800 | + query = self.query_factory( |
801 | + action="PUT", creds=self.creds, endpoint=self.endpoint, |
802 | + bucket=bucket, object_name=objectname_plus, data=data, |
803 | + content_type=content_type, metadata=metadata, |
804 | + body_producer=body_producer, receiver_factory=self.receiver_factory) |
805 | + d = query.submit() |
806 | + return d.addCallback(query.get_response_headers) |
807 | + |
808 | + def complete_multipart_upload(self, bucket, object_name, upload_id, |
809 | + parts_list, content_type=None, metadata={}): |
810 | + """ |
811 | + Complete a multipart upload. |
812 | + |
813 | +        N.B. This can possibly be a slow operation.
814 | + |
815 | + @param bucket: The bucket name |
816 | + @param object_name: The object name |
817 | + @param upload_id: The multipart upload id |
818 | + @param parts_list: A List of all the parts |
819 | + (2-tuples of part sequence number and etag) |
820 | + @param content_type: The Content-Type of the object |
821 | + @param metadata: C{dict} containing additional metadata |
822 | + @return: a C{Deferred} that fires after request is complete |
823 | + """ |
824 | + data = self._build_complete_multipart_upload_xml(parts_list) |
825 | + objectname_plus = '%s?uploadId=%s' % (object_name, upload_id) |
826 | + query = self.query_factory( |
827 | + action="POST", creds=self.creds, endpoint=self.endpoint, |
828 | + bucket=bucket, object_name=objectname_plus, data=data, |
829 | + content_type=content_type, metadata=metadata) |
830 | + d = query.submit() |
831 | + # TODO - handle error responses |
832 | + return d.addCallback(MultipartCompletionResponse.from_xml) |
833 | + |
834 | + def _build_complete_multipart_upload_xml(self, parts_list): |
835 | + xml = [] |
836 | + parts_list.sort(key=lambda p: int(p[0])) |
837 | + xml.append('<CompleteMultipartUpload>') |
838 | + for pt in parts_list: |
839 | + xml.append('<Part>') |
840 | + xml.append('<PartNumber>%s</PartNumber>' % pt[0]) |
841 | + xml.append('<ETag>%s</ETag>' % pt[1]) |
842 | + xml.append('</Part>') |
843 | + xml.append('</CompleteMultipartUpload>') |
844 | + return '\n'.join(xml) |
845 | + |
846 | |
847 | class Query(BaseQuery): |
848 | """A query for submission to the S3 service.""" |
849 | |
850 | def __init__(self, bucket=None, object_name=None, data="", |
851 | - content_type=None, metadata={}, amz_headers={}, *args, |
852 | - **kwargs): |
853 | + content_type=None, metadata={}, amz_headers={}, |
854 | + body_producer=None, *args, **kwargs): |
855 | super(Query, self).__init__(*args, **kwargs) |
856 | self.bucket = bucket |
857 | self.object_name = object_name |
858 | self.data = data |
859 | + self.body_producer = body_producer |
860 | self.content_type = content_type |
861 | self.metadata = metadata |
862 | self.amz_headers = amz_headers |
863 | @@ -463,9 +566,14 @@ |
864 | """ |
865 | Build the list of headers needed in order to perform S3 operations. |
866 | """ |
867 | - headers = {"Content-Length": len(self.data), |
868 | - "Content-MD5": calculate_md5(self.data), |
869 | + if self.body_producer: |
870 | + content_length = self.body_producer.length |
871 | + else: |
872 | + content_length = len(self.data) |
873 | + headers = {"Content-Length": content_length, |
874 | "Date": self.date} |
875 | + if self.body_producer is None: |
876 | + headers["Content-MD5"] = calculate_md5(self.data) |
877 | for key, value in self.metadata.iteritems(): |
878 | headers["x-amz-meta-" + key] = value |
879 | for key, value in self.amz_headers.iteritems(): |
880 | @@ -529,5 +637,6 @@ |
881 | self.endpoint, self.bucket, self.object_name) |
882 | d = self.get_page( |
883 | url_context.get_url(), method=self.action, postdata=self.data, |
884 | - headers=self.get_headers()) |
885 | + headers=self.get_headers(), body_producer=self.body_producer, |
886 | + receiver_factory=self.receiver_factory) |
887 | return d.addErrback(s3_error_wrapper) |
888 | |
889 | === modified file 'txaws/s3/model.py' |
890 | --- txaws/s3/model.py 2012-01-28 00:42:38 +0000 |
891 | +++ txaws/s3/model.py 2012-02-10 22:39:30 +0000 |
892 | @@ -150,3 +150,65 @@ |
893 | """ |
894 | root = XML(xml_bytes) |
895 | return cls(root.findtext("Payer")) |
896 | + |
897 | + |
898 | +class MultipartInitiationResponse(object): |
899 | + """ |
900 | + A response to Initiate Multipart Upload |
901 | + """ |
902 | + |
903 | + def __init__(self, bucket, object_name, upload_id): |
904 | + """ |
905 | + @param bucket: The bucket name |
906 | + @param object_name: The object name |
907 | + @param upload_id: The upload id |
908 | + """ |
909 | + self.bucket = bucket |
910 | + self.object_name = object_name |
911 | + self.upload_id = upload_id |
912 | + |
913 | + @classmethod |
914 | + def from_xml(cls, xml_bytes): |
915 | + """ |
916 | + Create an instance of this from XML bytes. |
917 | + |
918 | + @param xml_bytes: C{str} bytes of XML to parse |
919 | + @return: an instance of L{MultipartInitiationResponse} |
920 | + """ |
921 | + root = XML(xml_bytes) |
922 | + return cls(root.findtext('Bucket'), |
923 | + root.findtext('Key'), |
924 | + root.findtext('UploadId')) |
925 | + |
926 | + |
927 | +class MultipartCompletionResponse(object): |
928 | + """ |
929 | + Represents a response to Complete Multipart Upload |
930 | + """ |
931 | + |
932 | + def __init__(self, location, bucket, object_name, etag): |
933 | + """ |
934 | +        @param location: The URI identifying the newly created object
935 | + @param bucket: The bucket name |
936 | + @param object_name: The object name / key |
937 | + @param etag: The entity tag |
938 | + """ |
939 | + self.location = location |
940 | + self.bucket = bucket |
941 | + self.object_name = object_name |
942 | + self.etag = etag |
943 | + |
944 | + @classmethod |
945 | + def from_xml(cls, xml_bytes): |
946 | + """ |
947 | + Create an instance of this class from XML bytes. |
948 | + |
949 | + @param xml_bytes: C{str} bytes of XML to parse |
950 | + @return: an instance of L{MultipartCompletionResponse} |
951 | + """ |
952 | + root = XML(xml_bytes) |
953 | + return cls(root.findtext('Location'), |
954 | + root.findtext('Bucket'), |
955 | + root.findtext('Key'), |
956 | + root.findtext('ETag')) |
957 | + |
958 | |
959 | === modified file 'txaws/s3/tests/test_client.py' |
960 | --- txaws/s3/tests/test_client.py 2012-01-28 00:44:53 +0000 |
961 | +++ txaws/s3/tests/test_client.py 2012-02-10 22:39:30 +0000 |
962 | @@ -9,7 +9,9 @@ |
963 | else: |
964 | s3clientSkip = None |
965 | from txaws.s3.acls import AccessControlPolicy |
966 | -from txaws.s3.model import RequestPayment |
967 | +from txaws.s3.model import (RequestPayment, MultipartInitiationResponse, |
968 | + MultipartCompletionResponse) |
969 | +from txaws.testing.producers import StringBodyProducer |
970 | from txaws.service import AWSServiceEndpoint |
971 | from txaws.testing import payload |
972 | from txaws.testing.base import TXAWSTestCase |
973 | @@ -100,7 +102,8 @@ |
974 | |
975 | class StubQuery(client.Query): |
976 | |
977 | - def __init__(query, action, creds, endpoint): |
978 | + def __init__(query, action, creds, endpoint, |
979 | + body_producer=None, receiver_factory=None): |
980 | super(StubQuery, query).__init__( |
981 | action=action, creds=creds) |
982 | self.assertEquals(action, "GET") |
983 | @@ -134,7 +137,8 @@ |
984 | |
985 | class StubQuery(client.Query): |
986 | |
987 | - def __init__(query, action, creds, endpoint, bucket=None): |
988 | + def __init__(query, action, creds, endpoint, bucket=None, |
989 | + body_producer=None, receiver_factory=None): |
990 | super(StubQuery, query).__init__( |
991 | action=action, creds=creds, bucket=bucket) |
992 | self.assertEquals(action, "PUT") |
993 | @@ -156,7 +160,8 @@ |
994 | |
995 | class StubQuery(client.Query): |
996 | |
997 | - def __init__(query, action, creds, endpoint, bucket=None): |
998 | + def __init__(query, action, creds, endpoint, bucket=None, |
999 | + body_producer=None, receiver_factory=None): |
1000 | super(StubQuery, query).__init__( |
1001 | action=action, creds=creds, bucket=bucket) |
1002 | self.assertEquals(action, "GET") |
1003 | @@ -208,7 +213,8 @@ |
1004 | class StubQuery(client.Query): |
1005 | |
1006 | def __init__(query, action, creds, endpoint, bucket=None, |
1007 | - object_name=None): |
1008 | + object_name=None, body_producer=None, |
1009 | + receiver_factory=None): |
1010 | super(StubQuery, query).__init__(action=action, creds=creds, |
1011 | bucket=bucket, |
1012 | object_name=object_name) |
1013 | @@ -243,7 +249,8 @@ |
1014 | class StubQuery(client.Query): |
1015 | |
1016 | def __init__(query, action, creds, endpoint, bucket=None, |
1017 | - object_name=None): |
1018 | + object_name=None, body_producer=None, |
1019 | + receiver_factory=None): |
1020 | super(StubQuery, query).__init__(action=action, creds=creds, |
1021 | bucket=bucket, |
1022 | object_name=object_name) |
1023 | @@ -284,7 +291,8 @@ |
1024 | class StubQuery(client.Query): |
1025 | |
1026 | def __init__(query, action, creds, endpoint, bucket=None, |
1027 | - object_name=None): |
1028 | + object_name=None, body_producer=None, |
1029 | + receiver_factory=None): |
1030 | super(StubQuery, query).__init__(action=action, creds=creds, |
1031 | bucket=bucket, |
1032 | object_name=object_name) |
1033 | @@ -323,7 +331,8 @@ |
1034 | class StubQuery(client.Query): |
1035 | |
1036 | def __init__(query, action, creds, endpoint, bucket=None, |
1037 | - object_name=None): |
1038 | + object_name=None, body_producer=None, |
1039 | + receiver_factory=None): |
1040 | super(StubQuery, query).__init__(action=action, creds=creds, |
1041 | bucket=bucket, |
1042 | object_name=object_name) |
1043 | @@ -360,7 +369,8 @@ |
1044 | class StubQuery(client.Query): |
1045 | |
1046 | def __init__(query, action, creds, endpoint, bucket=None, |
1047 | - object_name=None): |
1048 | + object_name=None, body_producer=None, |
1049 | + receiver_factory=None): |
1050 | super(StubQuery, query).__init__(action=action, creds=creds, |
1051 | bucket=bucket, |
1052 | object_name=object_name) |
1053 | @@ -396,7 +406,8 @@ |
1054 | class StubQuery(client.Query): |
1055 | |
1056 | def __init__(query, action, creds, endpoint, bucket=None, |
1057 | - object_name=None): |
1058 | + object_name=None, body_producer=None, |
1059 | + receiver_factory=None): |
1060 | super(StubQuery, query).__init__(action=action, creds=creds, |
1061 | bucket=bucket, |
1062 | object_name=object_name) |
1063 | @@ -433,7 +444,8 @@ |
1064 | class StubQuery(client.Query): |
1065 | |
1066 | def __init__(query, action, creds, endpoint, bucket=None, |
1067 | - object_name=None): |
1068 | + object_name=None, body_producer=None, |
1069 | + receiver_factory=None): |
1070 | super(StubQuery, query).__init__(action=action, creds=creds, |
1071 | bucket=bucket, |
1072 | object_name=object_name) |
1073 | @@ -473,7 +485,8 @@ |
1074 | class StubQuery(client.Query): |
1075 | |
1076 | def __init__(query, action, creds, endpoint, bucket=None, |
1077 | - object_name=None): |
1078 | + object_name=None, body_producer=None, |
1079 | + receiver_factory=None): |
1080 | super(StubQuery, query).__init__(action=action, creds=creds, |
1081 | bucket=bucket, |
1082 | object_name=object_name) |
1083 | @@ -509,7 +522,8 @@ |
1084 | class StubQuery(client.Query): |
1085 | |
1086 | def __init__(query, action, creds, endpoint, bucket=None, |
1087 | - object_name=None): |
1088 | + object_name=None, body_producer=None, |
1089 | + receiver_factory=None): |
1090 | super(StubQuery, query).__init__(action=action, creds=creds, |
1091 | bucket=bucket, |
1092 | object_name=object_name) |
1093 | @@ -546,7 +560,8 @@ |
1094 | class StubQuery(client.Query): |
1095 | |
1096 | def __init__(query, action, creds, endpoint, bucket=None, |
1097 | - object_name=None): |
1098 | + object_name=None, body_producer=None, |
1099 | + receiver_factory=None): |
1100 | super(StubQuery, query).__init__(action=action, creds=creds, |
1101 | bucket=bucket, |
1102 | object_name=object_name) |
1103 | @@ -576,7 +591,8 @@ |
1104 | |
1105 | class StubQuery(client.Query): |
1106 | |
1107 | - def __init__(query, action, creds, endpoint, bucket=None): |
1108 | + def __init__(query, action, creds, endpoint, bucket=None, |
1109 | + body_producer=None, receiver_factory=None): |
1110 | super(StubQuery, query).__init__( |
1111 | action=action, creds=creds, bucket=bucket) |
1112 | self.assertEquals(action, "DELETE") |
1113 | @@ -599,7 +615,8 @@ |
1114 | class StubQuery(client.Query): |
1115 | |
1116 | def __init__(query, action, creds, endpoint, bucket=None, |
1117 | - object_name=None, data=""): |
1118 | + object_name=None, data="", body_producer=None, |
1119 | + receiver_factory=None): |
1120 | super(StubQuery, query).__init__(action=action, creds=creds, |
1121 | bucket=bucket, |
1122 | object_name=object_name, |
1123 | @@ -630,7 +647,8 @@ |
1124 | class StubQuery(client.Query): |
1125 | |
1126 | def __init__(query, action, creds, endpoint, bucket=None, |
1127 | - object_name=None, data=""): |
1128 | + object_name=None, data="", receiver_factory=None, |
1129 | + body_producer=None): |
1130 | super(StubQuery, query).__init__(action=action, creds=creds, |
1131 | bucket=bucket, |
1132 | object_name=object_name, |
1133 | @@ -665,7 +683,7 @@ |
1134 | |
1135 | def __init__(query, action, creds, endpoint, bucket=None, |
1136 | object_name=None, data=None, content_type=None, |
1137 | - metadata=None): |
1138 | + metadata=None, body_producer=None, receiver_factory=None): |
1139 | super(StubQuery, query).__init__( |
1140 | action=action, creds=creds, bucket=bucket, |
1141 | object_name=object_name, data=data, |
1142 | @@ -701,7 +719,7 @@ |
1143 | |
1144 | def __init__(query, action, creds, endpoint, bucket=None, |
1145 | object_name=None, data=None, content_type=None, |
1146 | - metadata=None): |
1147 | + metadata=None, body_producer=None, receiver_factory=None): |
1148 | super(StubQuery, query).__init__( |
1149 | action=action, creds=creds, bucket=bucket, |
1150 | object_name=object_name, data=data, |
1151 | @@ -730,7 +748,8 @@ |
1152 | |
1153 | def __init__(query, action, creds, endpoint, bucket=None, |
1154 | object_name=None, data=None, content_type=None, |
1155 | - metadata=None, amz_headers=None): |
1156 | + metadata=None, amz_headers=None, body_producer=None, |
1157 | + receiver_factory=None): |
1158 | super(StubQuery, query).__init__( |
1159 | action=action, creds=creds, bucket=bucket, |
1160 | object_name=object_name, data=data, |
1161 | @@ -756,6 +775,42 @@ |
1162 | metadata={"key": "some meta data"}, |
1163 | amz_headers={"acl": "public-read"}) |
1164 | |
1165 | + def test_put_object_with_custom_body_producer(self): |
1166 | + |
1167 | + class StubQuery(client.Query): |
1168 | + |
1169 | + def __init__(query, action, creds, endpoint, bucket=None, |
1170 | + object_name=None, data=None, content_type=None, |
1171 | + metadata=None, amz_headers=None, body_producer=None, |
1172 | + receiver_factory=None): |
1173 | + super(StubQuery, query).__init__( |
1174 | + action=action, creds=creds, bucket=bucket, |
1175 | + object_name=object_name, data=data, |
1176 | + content_type=content_type, metadata=metadata, |
1177 | + amz_headers=amz_headers, body_producer=body_producer) |
1178 | + self.assertEqual(action, "PUT") |
1179 | + self.assertEqual(creds.access_key, "foo") |
1180 | + self.assertEqual(creds.secret_key, "bar") |
1181 | + self.assertEqual(query.bucket, "mybucket") |
1182 | + self.assertEqual(query.object_name, "objectname") |
1183 | + self.assertEqual(query.content_type, "text/plain") |
1184 | + self.assertEqual(query.metadata, {"key": "some meta data"}) |
1185 | + self.assertEqual(query.amz_headers, {"acl": "public-read"}) |
1186 | + self.assertIdentical(body_producer, string_producer) |
1187 | + |
1188 | + def submit(query): |
1189 | + return succeed(None) |
1190 | + |
1191 | + |
1192 | + string_producer = StringBodyProducer("some data") |
1193 | + creds = AWSCredentials("foo", "bar") |
1194 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1195 | + return s3.put_object("mybucket", "objectname", |
1196 | + content_type="text/plain", |
1197 | + metadata={"key": "some meta data"}, |
1198 | + amz_headers={"acl": "public-read"}, |
1199 | + body_producer=string_producer) |
1200 | + |
1201 | def test_copy_object(self): |
1202 | """ |
1203 | L{S3Client.copy_object} creates a L{Query} to copy an object from one |
1204 | @@ -766,7 +821,8 @@ |
1205 | |
1206 | def __init__(query, action, creds, endpoint, bucket=None, |
1207 | object_name=None, data=None, content_type=None, |
1208 | - metadata=None, amz_headers=None): |
1209 | + metadata=None, amz_headers=None, body_producer=None, |
1210 | + receiver_factory=None): |
1211 | super(StubQuery, query).__init__( |
1212 | action=action, creds=creds, bucket=bucket, |
1213 | object_name=object_name, data=data, |
1214 | @@ -798,7 +854,8 @@ |
1215 | |
1216 | def __init__(query, action, creds, endpoint, bucket=None, |
1217 | object_name=None, data=None, content_type=None, |
1218 | - metadata=None, amz_headers=None): |
1219 | + metadata=None, amz_headers=None, body_producer=None, |
1220 | + receiver_factory=None): |
1221 | super(StubQuery, query).__init__( |
1222 | action=action, creds=creds, bucket=bucket, |
1223 | object_name=object_name, data=data, |
1224 | @@ -822,7 +879,7 @@ |
1225 | |
1226 | def __init__(query, action, creds, endpoint, bucket=None, |
1227 | object_name=None, data=None, content_type=None, |
1228 | - metadata=None): |
1229 | + metadata=None, body_producer=None, receiver_factory=None): |
1230 | super(StubQuery, query).__init__( |
1231 | action=action, creds=creds, bucket=bucket, |
1232 | object_name=object_name, data=data, |
1233 | @@ -846,7 +903,7 @@ |
1234 | |
1235 | def __init__(query, action, creds, endpoint, bucket=None, |
1236 | object_name=None, data=None, content_type=None, |
1237 | - metadata=None): |
1238 | + metadata=None, body_producer=None, receiver_factory=None): |
1239 | super(StubQuery, query).__init__( |
1240 | action=action, creds=creds, bucket=bucket, |
1241 | object_name=object_name, data=data, |
1242 | @@ -869,7 +926,8 @@ |
1243 | class StubQuery(client.Query): |
1244 | |
1245 | def __init__(query, action, creds, endpoint, bucket=None, |
1246 | - object_name=None, data=""): |
1247 | + object_name=None, data="", body_producer=None, |
1248 | + receiver_factory=None): |
1249 | super(StubQuery, query).__init__(action=action, creds=creds, |
1250 | bucket=bucket, |
1251 | object_name=object_name, |
1252 | @@ -902,7 +960,8 @@ |
1253 | class StubQuery(client.Query): |
1254 | |
1255 | def __init__(query, action, creds, endpoint, bucket=None, |
1256 | - object_name=None, data=""): |
1257 | + object_name=None, data="", body_producer=None, |
1258 | + receiver_factory=None): |
1259 | super(StubQuery, query).__init__(action=action, creds=creds, |
1260 | bucket=bucket, |
1261 | object_name=object_name, |
1262 | @@ -926,6 +985,113 @@ |
1263 | deferred = s3.get_object_acl("mybucket", "myobject") |
1264 | return deferred.addCallback(check_result) |
1265 | |
1266 | + def test_init_multipart_upload(self): |
1267 | + |
1268 | + class StubQuery(client.Query): |
1269 | + |
1270 | + def __init__(query, action, creds, endpoint, bucket=None, |
1271 | + object_name=None, data="", body_producer=None, |
1272 | + content_type=None, receiver_factory=None, metadata={}): |
1273 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1274 | + bucket=bucket, |
1275 | + object_name=object_name, |
1276 | + data=data) |
1277 | + self.assertEquals(action, "POST") |
1278 | + self.assertEqual(creds.access_key, "foo") |
1279 | + self.assertEqual(creds.secret_key, "bar") |
1280 | + self.assertEqual(query.bucket, "example-bucket") |
1281 | + self.assertEqual(query.object_name, "example-object?uploads") |
1282 | + self.assertEqual(query.data, "") |
1283 | + self.assertEqual(query.metadata, {}) |
1284 | + |
1285 | + def submit(query, url_context=None): |
1286 | + return succeed(payload.sample_s3_init_multipart_upload_result) |
1287 | + |
1288 | + |
1289 | + def check_result(result): |
1290 | + self.assert_(isinstance(result, MultipartInitiationResponse)) |
1291 | + self.assertEqual(result.bucket, "example-bucket") |
1292 | + self.assertEqual(result.object_name, "example-object") |
1293 | + self.assertEqual(result.upload_id, "deadbeef") |
1294 | + |
1295 | + creds = AWSCredentials("foo", "bar") |
1296 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1297 | + deferred = s3.init_multipart_upload("example-bucket", "example-object") |
1298 | + return deferred.addCallback(check_result) |
1299 | + |
1300 | + def test_upload_part(self): |
1301 | + |
1302 | + class StubQuery(client.Query): |
1303 | + |
1304 | + def __init__(query, action, creds, endpoint, bucket=None, |
1305 | + object_name=None, data="", body_producer=None, |
1306 | + content_type=None, receiver_factory=None, metadata={}): |
1307 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1308 | + bucket=bucket, |
1309 | + object_name=object_name, |
1310 | + data=data) |
1311 | + self.assertEquals(action, "PUT") |
1312 | + self.assertEqual(creds.access_key, "foo") |
1313 | + self.assertEqual(creds.secret_key, "bar") |
1314 | + self.assertEqual(query.bucket, "example-bucket") |
1315 | + self.assertEqual(query.object_name, |
1316 | + "example-object?partNumber=3&uploadId=testid") |
1317 | + self.assertEqual(query.data, "some data") |
1318 | + self.assertEqual(query.metadata, {}) |
1319 | + |
1320 | + def submit(query, url_context=None): |
1321 | + return succeed(None) |
1322 | + |
1323 | + creds = AWSCredentials("foo", "bar") |
1324 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1325 | + return s3.upload_part("example-bucket", "example-object", "testid", 3, |
1326 | + "some data") |
1327 | + |
1328 | + def test_complete_multipart_upload(self): |
1329 | + |
1330 | + class StubQuery(client.Query): |
1331 | + |
1332 | + def __init__(query, action, creds, endpoint, bucket=None, |
1333 | + object_name=None, data="", body_producer=None, |
1334 | + content_type=None, receiver_factory=None, metadata={}): |
1335 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1336 | + bucket=bucket, |
1337 | + object_name=object_name, |
1338 | + data=data) |
1339 | + self.assertEquals(action, "POST") |
1340 | + self.assertEqual(creds.access_key, "foo") |
1341 | + self.assertEqual(creds.secret_key, "bar") |
1342 | + self.assertEqual(query.bucket, "example-bucket") |
1343 | + self.assertEqual(query.object_name, |
1344 | + "example-object?uploadId=testid") |
1345 | + self.assertEqual(query.data, "<CompleteMultipartUpload>\n" |
1346 | + "<Part>\n<PartNumber>1</PartNumber>\n<ETag>a</ETag>\n" |
1347 | + "</Part>\n<Part>\n<PartNumber>2</PartNumber>\n" |
1348 | + "<ETag>b</ETag>\n</Part>\n</CompleteMultipartUpload>") |
1349 | + self.assertEqual(query.metadata, {}) |
1350 | + |
1351 | + def submit(query, url_context=None): |
1352 | + return succeed( |
1353 | + payload.sample_s3_complete_multipart_upload_result) |
1354 | + |
1355 | + |
1356 | + def check_result(result): |
1357 | + self.assert_(isinstance(result, MultipartCompletionResponse)) |
1358 | + self.assertEqual(result.bucket, "example-bucket") |
1359 | + self.assertEqual(result.object_name, "example-object") |
1360 | + self.assertEqual(result.location, |
1361 | + "http://example-bucket.s3.amazonaws.com/example-object") |
1362 | + self.assertEqual(result.etag, |
1363 | + '"3858f62230ac3c915f300c664312c11f-9"') |
1364 | + |
1365 | + creds = AWSCredentials("foo", "bar") |
1366 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1367 | + deferred = s3.complete_multipart_upload("example-bucket", |
1368 | + "example-object", |
1369 | + "testid", [(1, "a"), (2, "b")]) |
1370 | + return deferred.addCallback(check_result) |
1371 | + |
1372 | + |
1373 | S3ClientTestCase.skip = s3clientSkip |
1374 | |
1375 | |
1376 | @@ -1077,7 +1243,8 @@ |
1377 | """ |
1378 | class StubQuery(client.Query): |
1379 | |
1380 | - def __init__(query, action, creds, endpoint, bucket): |
1381 | + def __init__(query, action, creds, endpoint, bucket, |
1382 | + body_producer=None, receiver_factory=None): |
1383 | super(StubQuery, query).__init__( |
1384 | action=action, creds=creds, bucket=bucket) |
1385 | self.assertEquals(action, "GET") |
1386 | |
1387 | === modified file 'txaws/testing/payload.py' |
1388 | --- txaws/testing/payload.py 2012-01-28 00:39:00 +0000 |
1389 | +++ txaws/testing/payload.py 2012-02-10 22:39:30 +0000 |
1390 | @@ -1085,3 +1085,19 @@ |
1391 | <Status>Enabled</Status> |
1392 | <MfaDelete>Disabled</MfaDelete> |
1393 | </VersioningConfiguration>""" |
1394 | + |
1395 | +sample_s3_init_multipart_upload_result = """\ |
1396 | +<InitiateMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"> |
1397 | + <Bucket>example-bucket</Bucket> |
1398 | + <Key>example-object</Key> |
1399 | + <UploadId>deadbeef</UploadId> |
1400 | +</InitiateMultipartUploadResult>""" |
1401 | + |
1402 | +sample_s3_complete_multipart_upload_result = """\ |
1403 | +<?xml version="1.0" encoding="UTF-8"?> |
1404 | +<CompleteMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"> |
1405 | + <Location>http://example-bucket.s3.amazonaws.com/example-object</Location> |
1406 | + <Bucket>example-bucket</Bucket> |
1407 | + <Key>example-object</Key> |
1408 | + <ETag>"3858f62230ac3c915f300c664312c11f-9"</ETag> |
1409 | +</CompleteMultipartUploadResult>""" |
1410 | |
1411 | === added file 'txaws/testing/producers.py' |
1412 | --- txaws/testing/producers.py 1970-01-01 00:00:00 +0000 |
1413 | +++ txaws/testing/producers.py 2012-02-10 22:39:30 +0000 |
1414 | @@ -0,0 +1,23 @@ |
1415 | +from zope.interface import implements |
1416 | + |
1417 | +from twisted.internet.defer import succeed |
1418 | +from twisted.web.iweb import IBodyProducer |
1419 | + |
1420 | +class StringBodyProducer(object): |
1421 | + implements(IBodyProducer) |
1422 | + |
1423 | + def __init__(self, data): |
1424 | + self.data = data |
1425 | + self.length = len(data) |
1426 | + self.written = None |
1427 | + |
1428 | + def startProducing(self, consumer): |
1429 | + consumer.write(self.data) |
1430 | + self.written = self.data |
1431 | + return succeed(None) |
1432 | + |
1433 | + def pauseProducing(self): |
1434 | + pass |
1435 | + |
1436 | + def stopProducing(self): |
1437 | + pass |
Merge away!