Merge lp:~djfroofy/txaws/921419-uploadpart into lp:txaws
- 921419-uploadpart
- Merge into trunk
Proposed by
Drew Smathers
Status: | Merged | ||||
---|---|---|---|---|---|
Approved by: | Duncan McGreggor | ||||
Approved revision: | 145 | ||||
Merged at revision: | 147 | ||||
Proposed branch: | lp:~djfroofy/txaws/921419-uploadpart | ||||
Merge into: | lp:txaws | ||||
Diff against target: |
1301 lines (+648/-82) 9 files modified
txaws/client/_producers.py (+122/-0) txaws/client/base.py (+159/-14) txaws/client/tests/test_base.py (+53/-11) txaws/client/tests/test_ssl.py (+20/-4) txaws/s3/client.py (+85/-26) txaws/s3/model.py (+29/-0) txaws/s3/tests/test_client.py (+148/-27) txaws/testing/payload.py (+9/-0) txaws/testing/producers.py (+23/-0) |
||||
To merge this branch: | bzr merge lp:~djfroofy/txaws/921419-uploadpart | ||||
Related bugs: |
|
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Duncan McGreggor | Approve | ||
Review via email: mp+92583@code.launchpad.net |
Commit message
Description of the change
This depends on initmultipart branch which in turn has other branch deps.
To post a comment you must log in.
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === added file 'txaws/client/_producers.py' |
2 | --- txaws/client/_producers.py 1970-01-01 00:00:00 +0000 |
3 | +++ txaws/client/_producers.py 2012-02-10 21:24:20 +0000 |
4 | @@ -0,0 +1,122 @@ |
5 | +import os |
6 | + |
7 | +from zope.interface import implements |
8 | + |
9 | +from twisted.internet import defer, task |
10 | +from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer |
11 | + |
12 | + |
13 | +# Code below for FileBodyProducer cut-and-paste from twisted source. |
14 | +# Currently this is not released so here temporarily for forward compat. |
15 | + |
16 | + |
17 | +class FileBodyProducer(object): |
18 | + """ |
19 | + L{FileBodyProducer} produces bytes from an input file object incrementally |
20 | + and writes them to a consumer. |
21 | + |
22 | + Since file-like objects cannot be read from in an event-driven manner, |
23 | + L{FileBodyProducer} uses a L{Cooperator} instance to schedule reads from |
24 | + the file. This process is also paused and resumed based on notifications |
25 | + from the L{IConsumer} provider being written to. |
26 | + |
27 | + The file is closed after it has been read, or if the producer is stopped |
28 | + early. |
29 | + |
30 | + @ivar _inputFile: Any file-like object, bytes read from which will be |
31 | + written to a consumer. |
32 | + |
33 | + @ivar _cooperate: A method like L{Cooperator.cooperate} which is used to |
34 | + schedule all reads. |
35 | + |
36 | + @ivar _readSize: The number of bytes to read from C{_inputFile} at a time. |
37 | + """ |
38 | + implements(IBodyProducer) |
39 | + |
40 | + # Python 2.4 doesn't have these symbolic constants |
41 | + _SEEK_SET = getattr(os, 'SEEK_SET', 0) |
42 | + _SEEK_END = getattr(os, 'SEEK_END', 2) |
43 | + |
44 | + def __init__(self, inputFile, cooperator=task, readSize=2 ** 16): |
45 | + self._inputFile = inputFile |
46 | + self._cooperate = cooperator.cooperate |
47 | + self._readSize = readSize |
48 | + self.length = self._determineLength(inputFile) |
49 | + |
50 | + |
51 | + def _determineLength(self, fObj): |
52 | + """ |
53 | + Determine how many bytes can be read out of C{fObj} (assuming it is not |
54 | + modified from this point on). If the determination cannot be made, |
55 | + return C{UNKNOWN_LENGTH}. |
56 | + """ |
57 | + try: |
58 | + seek = fObj.seek |
59 | + tell = fObj.tell |
60 | + except AttributeError: |
61 | + return UNKNOWN_LENGTH |
62 | + originalPosition = tell() |
63 | + seek(0, self._SEEK_END) |
64 | + end = tell() |
65 | + seek(originalPosition, self._SEEK_SET) |
66 | + return end - originalPosition |
67 | + |
68 | + |
69 | + def stopProducing(self): |
70 | + """ |
71 | + Permanently stop writing bytes from the file to the consumer by |
72 | + stopping the underlying L{CooperativeTask}. |
73 | + """ |
74 | + self._inputFile.close() |
75 | + self._task.stop() |
76 | + |
77 | + |
78 | + def startProducing(self, consumer): |
79 | + """ |
80 | + Start a cooperative task which will read bytes from the input file and |
81 | + write them to C{consumer}. Return a L{Deferred} which fires after all |
82 | + bytes have been written. |
83 | + |
84 | + @param consumer: Any L{IConsumer} provider |
85 | + """ |
86 | + self._task = self._cooperate(self._writeloop(consumer)) |
87 | + d = self._task.whenDone() |
88 | + def maybeStopped(reason): |
89 | +        # IBodyProducer.startProducing's Deferred isn't supposed to fire if |
90 | + # stopProducing is called. |
91 | + reason.trap(task.TaskStopped) |
92 | + return defer.Deferred() |
93 | + d.addCallbacks(lambda ignored: None, maybeStopped) |
94 | + return d |
95 | + |
96 | + |
97 | + def _writeloop(self, consumer): |
98 | + """ |
99 | + Return an iterator which reads one chunk of bytes from the input file |
100 | + and writes them to the consumer for each time it is iterated. |
101 | + """ |
102 | + while True: |
103 | + bytes = self._inputFile.read(self._readSize) |
104 | + if not bytes: |
105 | + self._inputFile.close() |
106 | + break |
107 | + consumer.write(bytes) |
108 | + yield None |
109 | + |
110 | + |
111 | + def pauseProducing(self): |
112 | + """ |
113 | + Temporarily suspend copying bytes from the input file to the consumer |
114 | + by pausing the L{CooperativeTask} which drives that activity. |
115 | + """ |
116 | + self._task.pause() |
117 | + |
118 | + |
119 | + def resumeProducing(self): |
120 | + """ |
121 | + Undo the effects of a previous C{pauseProducing} and resume copying |
122 | + bytes to the consumer by resuming the L{CooperativeTask} which drives |
123 | + the write activity. |
124 | + """ |
125 | + self._task.resume() |
126 | + |
127 | |
128 | === modified file 'txaws/client/base.py' |
129 | --- txaws/client/base.py 2011-11-29 08:17:54 +0000 |
130 | +++ txaws/client/base.py 2012-02-10 21:24:20 +0000 |
131 | @@ -3,10 +3,25 @@ |
132 | except ImportError: |
133 | from xml.parsers.expat import ExpatError as ParseError |
134 | |
135 | +import warnings |
136 | +from StringIO import StringIO |
137 | + |
138 | from twisted.internet.ssl import ClientContextFactory |
139 | +from twisted.internet.protocol import Protocol |
140 | +from twisted.internet.defer import Deferred, succeed |
141 | +from twisted.python import failure |
142 | from twisted.web import http |
143 | +from twisted.web.iweb import UNKNOWN_LENGTH |
144 | from twisted.web.client import HTTPClientFactory |
145 | +from twisted.web.client import Agent |
146 | +from twisted.web.client import ResponseDone |
147 | +from twisted.web.http import NO_CONTENT |
148 | +from twisted.web.http_headers import Headers |
149 | from twisted.web.error import Error as TwistedWebError |
150 | +try: |
151 | + from twisted.web.client import FileBodyProducer |
152 | +except ImportError: |
153 | + from txaws.client._producers import FileBodyProducer |
154 | |
155 | from txaws.util import parse |
156 | from txaws.credentials import AWSCredentials |
157 | @@ -59,9 +74,10 @@ |
158 | @param query_factory: The class or function that produces a query |
159 | object for making requests to the EC2 service. |
160 | @param parser: A parser object for parsing responses from the EC2 service. |
161 | + @param receiver_factory: Factory for receiving responses from EC2 service. |
162 | """ |
163 | def __init__(self, creds=None, endpoint=None, query_factory=None, |
164 | - parser=None): |
165 | + parser=None, receiver_factory=None): |
166 | if creds is None: |
167 | creds = AWSCredentials() |
168 | if endpoint is None: |
169 | @@ -69,22 +85,109 @@ |
170 | self.creds = creds |
171 | self.endpoint = endpoint |
172 | self.query_factory = query_factory |
173 | + self.receiver_factory = receiver_factory |
174 | self.parser = parser |
175 | |
176 | +class StreamingError(Exception): |
177 | + """ |
178 | + Raised if more data or less data is received than expected. |
179 | + """ |
180 | + |
181 | + |
182 | +class StringIOBodyReceiver(Protocol): |
183 | + """ |
184 | + Simple StringIO-based HTTP response body receiver. |
185 | + |
186 | + TODO: perhaps there should be an interface specifying why |
187 | + finished (Deferred) and content_length are necessary and |
188 | + how to use them; eg. callback/errback finished on completion. |
189 | + """ |
190 | + finished = None |
191 | + content_length = None |
192 | + |
193 | + def __init__(self): |
194 | + self._buffer = StringIO() |
195 | + self._received = 0 |
196 | + |
197 | + def dataReceived(self, bytes): |
198 | + streaming = self.content_length is UNKNOWN_LENGTH |
199 | + if not streaming and (self._received > self.content_length): |
200 | + self.transport.loseConnection() |
201 | + raise StreamingError( |
202 | + "Buffer overflow - received more data than " |
203 | + "Content-Length dictated: %d" % self.content_length) |
204 | + # TODO should be some limit on how much we receive |
205 | + self._buffer.write(bytes) |
206 | + self._received += len(bytes) |
207 | + |
208 | + def connectionLost(self, reason): |
209 | + reason.trap(ResponseDone) |
210 | + d = self.finished |
211 | + self.finished = None |
212 | + streaming = self.content_length is UNKNOWN_LENGTH |
213 | + if streaming or (self._received == self.content_length): |
214 | + d.callback(self._buffer.getvalue()) |
215 | + else: |
216 | + f = failure.Failure(StreamingError("Connection lost before " |
217 | + "receiving all data")) |
218 | + d.errback(f) |
219 | + |
220 | + |
221 | +class WebClientContextFactory(ClientContextFactory): |
222 | + |
223 | + def getContext(self, hostname, port): |
224 | + return ClientContextFactory.getContext(self) |
225 | + |
226 | + |
227 | +class WebVerifyingContextFactory(VerifyingContextFactory): |
228 | + |
229 | + def getContext(self, hostname, port): |
230 | + return VerifyingContextFactory.getContext(self) |
231 | + |
232 | + |
233 | +class FakeClient(object): |
234 | + """ |
235 | + XXX |
236 | + A fake client object for some degree of backwards compatibility for |
237 | + code using the client attribute on BaseQuery to check url, status |
238 | + etc. |
239 | + """ |
240 | + url = None |
241 | + status = None |
242 | |
243 | class BaseQuery(object): |
244 | |
245 | - def __init__(self, action=None, creds=None, endpoint=None, reactor=None): |
246 | + def __init__(self, action=None, creds=None, endpoint=None, reactor=None, |
247 | + body_producer=None, receiver_factory=None): |
248 | if not action: |
249 | raise TypeError("The query requires an action parameter.") |
250 | - self.factory = HTTPClientFactory |
251 | self.action = action |
252 | self.creds = creds |
253 | self.endpoint = endpoint |
254 | if reactor is None: |
255 | from twisted.internet import reactor |
256 | self.reactor = reactor |
257 | - self.client = None |
258 | + self._client = None |
259 | + self.request_headers = None |
260 | + self.response_headers = None |
261 | + self.body_producer = body_producer |
262 | + self.receiver_factory = receiver_factory or StringIOBodyReceiver |
263 | + |
264 | + @property |
265 | + def client(self): |
266 | + if self._client is None: |
267 | + self._client_deprecation_warning() |
268 | + self._client = FakeClient() |
269 | + return self._client |
270 | + |
271 | + @client.setter |
272 | + def client(self, value): |
273 | + self._client_deprecation_warning() |
274 | + self._client = value |
275 | + |
276 | + def _client_deprecation_warning(self): |
277 | + warnings.warn('The client attribute on BaseQuery is deprecated and' |
278 | + ' will go away in future release.') |
279 | |
280 | def get_page(self, url, *args, **kwds): |
281 | """ |
282 | @@ -95,16 +198,39 @@ |
283 | """ |
284 | contextFactory = None |
285 | scheme, host, port, path = parse(url) |
286 | - self.client = self.factory(url, *args, **kwds) |
287 | + data = kwds.get('postdata', None) |
288 | + self._method = method = kwds.get('method', 'GET') |
289 | + self.request_headers = self._headers(kwds.get('headers', {})) |
290 | + if (self.body_producer is None) and (data is not None): |
291 | + self.body_producer = FileBodyProducer(StringIO(data)) |
292 | if scheme == "https": |
293 | if self.endpoint.ssl_hostname_verification: |
294 | - contextFactory = VerifyingContextFactory(host) |
295 | + contextFactory = WebVerifyingContextFactory(host) |
296 | else: |
297 | - contextFactory = ClientContextFactory() |
298 | - self.reactor.connectSSL(host, port, self.client, contextFactory) |
299 | + contextFactory = WebClientContextFactory() |
300 | + agent = Agent(self.reactor, contextFactory) |
301 | + self.client.url = url |
302 | + d = agent.request(method, url, self.request_headers, |
303 | + self.body_producer) |
304 | else: |
305 | - self.reactor.connectTCP(host, port, self.client) |
306 | - return self.client.deferred |
307 | + agent = Agent(self.reactor) |
308 | + d = agent.request(method, url, self.request_headers, |
309 | + self.body_producer) |
310 | + d.addCallback(self._handle_response) |
311 | + return d |
312 | + |
313 | + def _headers(self, headers_dict): |
314 | + """ |
315 | + Convert dictionary of headers into twisted.web.client.Headers object. |
316 | + """ |
317 | + return Headers(dict((k,[v]) for (k,v) in headers_dict.items())) |
318 | + |
319 | + def _unpack_headers(self, headers): |
320 | + """ |
321 | + Unpack twisted.web.client.Headers object to dict. This is to provide |
322 | + backwards compatability. |
323 | + """ |
324 | + return dict((k,v[0]) for (k,v) in headers.getAllRawHeaders()) |
325 | |
326 | def get_request_headers(self, *args, **kwds): |
327 | """ |
328 | @@ -114,8 +240,26 @@ |
329 | The AWS S3 API depends upon setting headers. This method is provided as |
330 | a convenience for debugging issues with the S3 communications. |
331 | """ |
332 | - if self.client: |
333 | - return self.client.headers |
334 | + if self.request_headers: |
335 | + return self._unpack_headers(self.request_headers) |
336 | + |
337 | + def _handle_response(self, response): |
338 | + """ |
339 | + Handle the HTTP response by memoing the headers and then delivering |
340 | + bytes. |
341 | + """ |
342 | + self.client.status = response.code |
343 | + self.response_headers = headers = response.headers |
344 | + # XXX This workaround (which needs to be improved at that) for possible |
345 | + # bug in Twisted with new client: |
346 | + # http://twistedmatrix.com/trac/ticket/5476 |
347 | + if self._method.upper() == 'HEAD' or response.code == NO_CONTENT: |
348 | + return succeed('') |
349 | + receiver = self.receiver_factory() |
350 | + receiver.finished = d = Deferred() |
351 | + receiver.content_length = response.length |
352 | + response.deliverBody(receiver) |
353 | + return d |
354 | |
355 | def get_response_headers(self, *args, **kwargs): |
356 | """ |
357 | @@ -125,5 +269,6 @@ |
358 | The AWS S3 API depends upon setting headers. This method is used by the |
359 | head_object API call for getting a S3 object's metadata. |
360 | """ |
361 | - if self.client: |
362 | - return self.client.response_headers |
363 | + if self.response_headers: |
364 | + return self._unpack_headers(self.response_headers) |
365 | + |
366 | |
367 | === modified file 'txaws/client/tests/test_base.py' |
368 | --- txaws/client/tests/test_base.py 2012-01-26 18:43:48 +0000 |
369 | +++ txaws/client/tests/test_base.py 2012-02-10 21:24:20 +0000 |
370 | @@ -1,6 +1,9 @@ |
371 | import os |
372 | |
373 | +from zope.interface import implements |
374 | + |
375 | from twisted.internet import reactor |
376 | +from twisted.internet.defer import succeed |
377 | from twisted.internet.error import ConnectionRefusedError |
378 | from twisted.protocols.policies import WrappingFactory |
379 | from twisted.python import log |
380 | @@ -8,14 +11,16 @@ |
381 | from twisted.python.failure import Failure |
382 | from twisted.test.test_sslverify import makeCertificate |
383 | from twisted.web import server, static |
384 | +from twisted.web.iweb import IBodyProducer |
385 | from twisted.web.client import HTTPClientFactory |
386 | from twisted.web.error import Error as TwistedWebError |
387 | |
388 | from txaws.client import ssl |
389 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper |
390 | +from txaws.client.base import StringIOBodyReceiver |
391 | from txaws.service import AWSServiceEndpoint |
392 | from txaws.testing.base import TXAWSTestCase |
393 | - |
394 | +from txaws.testing.producers import StringBodyProducer |
395 | |
396 | class ErrorWrapperTestCase(TXAWSTestCase): |
397 | |
398 | @@ -99,7 +104,6 @@ |
399 | |
400 | def test_creation(self): |
401 | query = BaseQuery("an action", "creds", "http://endpoint") |
402 | - self.assertEquals(query.factory, HTTPClientFactory) |
403 | self.assertEquals(query.action, "an action") |
404 | self.assertEquals(query.creds, "creds") |
405 | self.assertEquals(query.endpoint, "http://endpoint") |
406 | @@ -142,16 +146,52 @@ |
407 | def test_get_response_headers_with_client(self): |
408 | |
409 | def check_results(results): |
410 | + #self.assertEquals(sorted(results.keys()), [ |
411 | + # "accept-ranges", "content-length", "content-type", "date", |
412 | + # "last-modified", "server"]) |
413 | + # XXX I think newclient excludes content-length from headers? |
414 | + # Also the header names are capitalized ... do we need to worry |
415 | + # about backwards compat? |
416 | self.assertEquals(sorted(results.keys()), [ |
417 | - "accept-ranges", "content-length", "content-type", "date", |
418 | - "last-modified", "server"]) |
419 | - self.assertEquals(len(results.values()), 6) |
420 | + "Accept-Ranges", "Content-Type", "Date", |
421 | + "Last-Modified", "Server"]) |
422 | + self.assertEquals(len(results.values()), 5) |
423 | |
424 | query = BaseQuery("an action", "creds", "http://endpoint") |
425 | d = query.get_page(self._get_url("file")) |
426 | d.addCallback(query.get_response_headers) |
427 | return d.addCallback(check_results) |
428 | |
429 | + def test_custom_body_producer(self): |
430 | + |
431 | + def check_producer_was_used(ignore): |
432 | + self.assertEqual(producer.written, 'test data') |
433 | + |
434 | + producer = StringBodyProducer('test data') |
435 | + query = BaseQuery("an action", "creds", "http://endpoint", |
436 | + body_producer=producer) |
437 | + d = query.get_page(self._get_url("file"), method='PUT') |
438 | + return d.addCallback(check_producer_was_used) |
439 | + |
440 | + def test_custom_receiver_factory(self): |
441 | + |
442 | + class TestReceiverProtocol(StringIOBodyReceiver): |
443 | + used = False |
444 | + |
445 | + def __init__(self): |
446 | + StringIOBodyReceiver.__init__(self) |
447 | + TestReceiverProtocol.used = True |
448 | + |
449 | + def check_used(ignore): |
450 | + self.assert_(TestReceiverProtocol.used) |
451 | + |
452 | + query = BaseQuery("an action", "creds", "http://endpoint", |
453 | + receiver_factory=TestReceiverProtocol) |
454 | + d = query.get_page(self._get_url("file")) |
455 | + d.addCallback(self.assertEquals, "0123456789") |
456 | + d.addCallback(check_used) |
457 | + return d |
458 | + |
459 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test |
460 | # will fail; instead, let's create some certs in a temp directory and set |
461 | # the DEFAULT_CERT_PATH to point there. |
462 | @@ -167,8 +207,9 @@ |
463 | def __init__(self): |
464 | self.connects = [] |
465 | |
466 | - def connectSSL(self, host, port, client, factory): |
467 | - self.connects.append((host, port, client, factory)) |
468 | + def connectSSL(self, host, port, factory, contextFactory, timeout, |
469 | + bindAddress): |
470 | + self.connects.append((host, port, factory, contextFactory)) |
471 | |
472 | certs = makeCertificate(O="Test Certificate", CN="something")[1] |
473 | self.patch(ssl, "_ca_certs", certs) |
474 | @@ -176,9 +217,10 @@ |
475 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
476 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) |
477 | query.get_page("https://example.com/file") |
478 | - [(host, port, client, factory)] = fake_reactor.connects |
479 | + [(host, port, factory, contextFactory)] = fake_reactor.connects |
480 | self.assertEqual("example.com", host) |
481 | self.assertEqual(443, port) |
482 | - self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) |
483 | - self.assertEqual("example.com", factory.host) |
484 | - self.assertNotEqual([], factory.caCerts) |
485 | + wrappedFactory = contextFactory._webContext |
486 | + self.assertTrue(isinstance(wrappedFactory, ssl.VerifyingContextFactory)) |
487 | + self.assertEqual("example.com", wrappedFactory.host) |
488 | + self.assertNotEqual([], wrappedFactory.caCerts) |
489 | |
490 | === modified file 'txaws/client/tests/test_ssl.py' |
491 | --- txaws/client/tests/test_ssl.py 2012-01-26 22:54:44 +0000 |
492 | +++ txaws/client/tests/test_ssl.py 2012-02-10 21:24:20 +0000 |
493 | @@ -12,6 +12,10 @@ |
494 | from twisted.python.filepath import FilePath |
495 | from twisted.test.test_sslverify import makeCertificate |
496 | from twisted.web import server, static |
497 | +try: |
498 | + from twisted.web.client import ResponseFailed |
499 | +except ImportError: |
500 | + from twisted.web._newclient import ResponseFailed |
501 | |
502 | from txaws import exception |
503 | from txaws.client import ssl |
504 | @@ -32,6 +36,11 @@ |
505 | PUBSANKEY = sibpath("public_san.ssl") |
506 | |
507 | |
508 | +class WebDefaultOpenSSLContextFactory(DefaultOpenSSLContextFactory): |
509 | + def getContext(self, hostname=None, port=None): |
510 | + return DefaultOpenSSLContextFactory.getContext(self) |
511 | + |
512 | + |
513 | class BaseQuerySSLTestCase(TXAWSTestCase): |
514 | |
515 | def setUp(self): |
516 | @@ -75,7 +84,7 @@ |
517 | The L{VerifyingContextFactory} properly allows to connect to the |
518 | endpoint if the certificates match. |
519 | """ |
520 | - context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
521 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
522 | self.port = reactor.listenSSL( |
523 | 0, self.site, context_factory, interface="127.0.0.1") |
524 | self.portno = self.port.getHost().port |
525 | @@ -90,7 +99,7 @@ |
526 | The L{VerifyingContextFactory} fails with a SSL error the certificates |
527 | can't be checked. |
528 | """ |
529 | - context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
530 | + context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
531 | self.port = reactor.listenSSL( |
532 | 0, self.site, context_factory, interface="127.0.0.1") |
533 | self.portno = self.port.getHost().port |
534 | @@ -98,7 +107,14 @@ |
535 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
536 | query = BaseQuery("an action", "creds", endpoint) |
537 | d = query.get_page(self._get_url("file")) |
538 | - return self.assertFailure(d, SSLError) |
539 | + def fail(ignore): |
540 | + self.fail('Expected SSLError') |
541 | + def check_exception(why): |
542 | + # XXX kind of a mess here ... need to unwrap the |
543 | + # exception and check |
544 | + root_exc = why.value[0][0].value |
545 | + self.assert_(isinstance(root_exc, SSLError)) |
546 | + return d.addCallbacks(fail, check_exception) |
547 | |
548 | def test_ssl_verification_bypassed(self): |
549 | """ |
550 | @@ -121,7 +137,7 @@ |
551 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the |
552 | certificate if it's available. |
553 | """ |
554 | - context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
555 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
556 | self.port = reactor.listenSSL( |
557 | 0, self.site, context_factory, interface="127.0.0.1") |
558 | self.portno = self.port.getHost().port |
559 | |
560 | === modified file 'txaws/s3/client.py' |
561 | --- txaws/s3/client.py 2012-01-28 00:39:00 +0000 |
562 | +++ txaws/s3/client.py 2012-02-10 21:24:20 +0000 |
563 | @@ -23,7 +23,7 @@ |
564 | from txaws.s3.model import ( |
565 | Bucket, BucketItem, BucketListing, ItemOwner, LifecycleConfiguration, |
566 | LifecycleConfigurationRule, NotificationConfiguration, RequestPayment, |
567 | - VersioningConfiguration, WebsiteConfiguration) |
568 | + VersioningConfiguration, WebsiteConfiguration, MultipartInitiationResponse) |
569 | from txaws.s3.exception import S3Error |
570 | from txaws.service import AWSServiceEndpoint, S3_ENDPOINT |
571 | from txaws.util import XML, calculate_md5 |
572 | @@ -74,10 +74,12 @@ |
573 | class S3Client(BaseClient): |
574 | """A client for S3.""" |
575 | |
576 | - def __init__(self, creds=None, endpoint=None, query_factory=None): |
577 | + def __init__(self, creds=None, endpoint=None, query_factory=None, |
578 | + receiver_factory=None): |
579 | if query_factory is None: |
580 | query_factory = Query |
581 | - super(S3Client, self).__init__(creds, endpoint, query_factory) |
582 | + super(S3Client, self).__init__(creds, endpoint, query_factory, |
583 | + receiver_factory=receiver_factory) |
584 | |
585 | def list_buckets(self): |
586 | """ |
587 | @@ -87,7 +89,8 @@ |
588 | the request. |
589 | """ |
590 | query = self.query_factory( |
591 | - action="GET", creds=self.creds, endpoint=self.endpoint) |
592 | + action="GET", creds=self.creds, endpoint=self.endpoint, |
593 | + receiver_factory=self.receiver_factory) |
594 | d = query.submit() |
595 | return d.addCallback(self._parse_list_buckets) |
596 | |
597 | @@ -131,7 +134,7 @@ |
598 | """ |
599 | query = self.query_factory( |
600 | action="GET", creds=self.creds, endpoint=self.endpoint, |
601 | - bucket=bucket) |
602 | + bucket=bucket, receiver_factory=self.receiver_factory) |
603 | d = query.submit() |
604 | return d.addCallback(self._parse_get_bucket) |
605 | |
606 | @@ -174,7 +177,8 @@ |
607 | """ |
608 | query = self.query_factory(action="GET", creds=self.creds, |
609 | endpoint=self.endpoint, bucket=bucket, |
610 | - object_name="?location") |
611 | + object_name="?location", |
612 | + receiver_factory=self.receiver_factory) |
613 | d = query.submit() |
614 | return d.addCallback(self._parse_bucket_location) |
615 | |
616 | @@ -193,7 +197,8 @@ |
617 | """ |
618 | query = self.query_factory( |
619 | action='GET', creds=self.creds, endpoint=self.endpoint, |
620 | - bucket=bucket, object_name='?lifecycle') |
621 | + bucket=bucket, object_name='?lifecycle', |
622 | + receiver_factory=self.receiver_factory) |
623 | return query.submit().addCallback(self._parse_lifecycle_config) |
624 | |
625 | def _parse_lifecycle_config(self, xml_bytes): |
626 | @@ -221,7 +226,8 @@ |
627 | """ |
628 | query = self.query_factory( |
629 | action='GET', creds=self.creds, endpoint=self.endpoint, |
630 | - bucket=bucket, object_name='?website') |
631 | + bucket=bucket, object_name='?website', |
632 | + receiver_factory=self.receiver_factory) |
633 | return query.submit().addCallback(self._parse_website_config) |
634 | |
635 | def _parse_website_config(self, xml_bytes): |
636 | @@ -242,7 +248,8 @@ |
637 | """ |
638 | query = self.query_factory( |
639 | action='GET', creds=self.creds, endpoint=self.endpoint, |
640 | - bucket=bucket, object_name='?notification') |
641 | + bucket=bucket, object_name='?notification', |
642 | + receiver_factory=self.receiver_factory) |
643 | return query.submit().addCallback(self._parse_notification_config) |
644 | |
645 | def _parse_notification_config(self, xml_bytes): |
646 | @@ -262,7 +269,8 @@ |
647 | """ |
648 | query = self.query_factory( |
649 | action='GET', creds=self.creds, endpoint=self.endpoint, |
650 | - bucket=bucket, object_name='?versioning') |
651 | + bucket=bucket, object_name='?versioning', |
652 | + receiver_factory=self.receiver_factory) |
653 | return query.submit().addCallback(self._parse_versioning_config) |
654 | |
655 | def _parse_versioning_config(self, xml_bytes): |
656 | @@ -279,7 +287,8 @@ |
657 | """ |
658 | query = self.query_factory( |
659 | action='GET', creds=self.creds, endpoint=self.endpoint, |
660 | - bucket=bucket, object_name='?acl') |
661 | + bucket=bucket, object_name='?acl', |
662 | + receiver_factory=self.receiver_factory) |
663 | return query.submit().addCallback(self._parse_acl) |
664 | |
665 | def put_bucket_acl(self, bucket, access_control_policy): |
666 | @@ -289,7 +298,8 @@ |
667 | data = access_control_policy.to_xml() |
668 | query = self.query_factory( |
669 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
670 | - bucket=bucket, object_name='?acl', data=data) |
671 | + bucket=bucket, object_name='?acl', data=data, |
672 | + receiver_factory=self.receiver_factory) |
673 | return query.submit().addCallback(self._parse_acl) |
674 | |
675 | def _parse_acl(self, xml_bytes): |
676 | @@ -299,8 +309,8 @@ |
677 | """ |
678 | return AccessControlPolicy.from_xml(xml_bytes) |
679 | |
680 | - def put_object(self, bucket, object_name, data, content_type=None, |
681 | - metadata={}, amz_headers={}): |
682 | + def put_object(self, bucket, object_name, data=None, content_type=None, |
683 | + metadata={}, amz_headers={}, body_producer=None): |
684 | """ |
685 | Put an object in a bucket. |
686 | |
687 | @@ -318,7 +328,8 @@ |
688 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
689 | bucket=bucket, object_name=object_name, data=data, |
690 | content_type=content_type, metadata=metadata, |
691 | - amz_headers=amz_headers) |
692 | + amz_headers=amz_headers, body_producer=body_producer, |
693 | + receiver_factory=self.receiver_factory) |
694 | return query.submit() |
695 | |
696 | def copy_object(self, source_bucket, source_object_name, dest_bucket=None, |
697 | @@ -344,7 +355,8 @@ |
698 | query = self.query_factory( |
699 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
700 | bucket=dest_bucket, object_name=dest_object_name, |
701 | - metadata=metadata, amz_headers=amz_headers) |
702 | + metadata=metadata, amz_headers=amz_headers, |
703 | + receiver_factory=self.receiver_factory) |
704 | return query.submit() |
705 | |
706 | def get_object(self, bucket, object_name): |
707 | @@ -353,7 +365,8 @@ |
708 | """ |
709 | query = self.query_factory( |
710 | action="GET", creds=self.creds, endpoint=self.endpoint, |
711 | - bucket=bucket, object_name=object_name) |
712 | + bucket=bucket, object_name=object_name, |
713 | + receiver_factory=self.receiver_factory) |
714 | return query.submit() |
715 | |
716 | def head_object(self, bucket, object_name): |
717 | @@ -384,7 +397,8 @@ |
718 | data = access_control_policy.to_xml() |
719 | query = self.query_factory( |
720 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
721 | - bucket=bucket, object_name='%s?acl' % object_name, data=data) |
722 | + bucket=bucket, object_name='%s?acl' % object_name, data=data, |
723 | + receiver_factory=self.receiver_factory) |
724 | return query.submit().addCallback(self._parse_acl) |
725 | |
726 | def get_object_acl(self, bucket, object_name): |
727 | @@ -393,7 +407,8 @@ |
728 | """ |
729 | query = self.query_factory( |
730 | action='GET', creds=self.creds, endpoint=self.endpoint, |
731 | - bucket=bucket, object_name='%s?acl' % object_name) |
732 | + bucket=bucket, object_name='%s?acl' % object_name, |
733 | + receiver_factory=self.receiver_factory) |
734 | return query.submit().addCallback(self._parse_acl) |
735 | |
736 | def put_request_payment(self, bucket, payer): |
737 | @@ -407,7 +422,8 @@ |
738 | data = RequestPayment(payer).to_xml() |
739 | query = self.query_factory( |
740 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
741 | - bucket=bucket, object_name="?requestPayment", data=data) |
742 | + bucket=bucket, object_name="?requestPayment", data=data, |
743 | + receiver_factory=self.receiver_factory) |
744 | return query.submit() |
745 | |
746 | def get_request_payment(self, bucket): |
747 | @@ -419,7 +435,8 @@ |
748 | """ |
749 | query = self.query_factory( |
750 | action="GET", creds=self.creds, endpoint=self.endpoint, |
751 | - bucket=bucket, object_name="?requestPayment") |
752 | + bucket=bucket, object_name="?requestPayment", |
753 | + receiver_factory=self.receiver_factory) |
754 | return query.submit().addCallback(self._parse_get_request_payment) |
755 | |
756 | def _parse_get_request_payment(self, xml_bytes): |
757 | @@ -429,17 +446,53 @@ |
758 | """ |
759 | return RequestPayment.from_xml(xml_bytes).payer |
760 | |
761 | + def init_multipart_upload(self, bucket, object_name, content_type=None, |
762 | + metadata={}): |
763 | + """ |
764 | + Initiate a multipart upload to a bucket. |
765 | + |
766 | + @param bucket: The name of the bucket |
767 | + @param object_name: The object name |
768 | + @param content_type: The Content-Type for the object |
769 | + @param metadata: C{dict} containing additional metadata |
770 | + @return: C{str} upload_id |
771 | + """ |
772 | + objectname_plus = '%s?uploads' % object_name |
773 | + query = self.query_factory( |
774 | + action="POST", creds=self.creds, endpoint=self.endpoint, |
775 | + bucket=bucket, object_name=objectname_plus, data='', |
776 | + content_type=content_type, metadata=metadata) |
777 | + d = query.submit() |
778 | + return d.addCallback(MultipartInitiationResponse.from_xml) |
779 | + |
780 | + def upload_part(self, bucket, object_name, upload_id, part_number, data=None, |
781 | + content_type=None, metadata={}, body_producer=None): |
782 | + """ |
783 | + Upload a part of data corresponding to a multipart upload. |
784 | + |
785 | + @return: the C{Deferred} from underlying query.submit() call |
786 | + """ |
787 | + parms = 'partNumber=%s&uploadId=%s' % (str(part_number), upload_id) |
788 | + objectname_plus = '%s?%s' % (object_name, parms) |
789 | + query = self.query_factory( |
790 | + action="PUT", creds=self.creds, endpoint=self.endpoint, |
791 | + bucket=bucket, object_name=objectname_plus, data=data, |
792 | + content_type=content_type, metadata=metadata, body_producer=body_producer, |
793 | + receiver_factory=self.receiver_factory) |
794 | + d = query.submit() |
795 | + return d.addCallback(query.get_response_headers) |
796 | |
797 | class Query(BaseQuery): |
798 | """A query for submission to the S3 service.""" |
799 | |
800 | def __init__(self, bucket=None, object_name=None, data="", |
801 | - content_type=None, metadata={}, amz_headers={}, *args, |
802 | - **kwargs): |
803 | + content_type=None, metadata={}, amz_headers={}, |
804 | + body_producer=None, *args, **kwargs): |
805 | super(Query, self).__init__(*args, **kwargs) |
806 | self.bucket = bucket |
807 | self.object_name = object_name |
808 | self.data = data |
809 | + self.body_producer = body_producer |
810 | self.content_type = content_type |
811 | self.metadata = metadata |
812 | self.amz_headers = amz_headers |
813 | @@ -463,9 +516,14 @@ |
814 | """ |
815 | Build the list of headers needed in order to perform S3 operations. |
816 | """ |
817 | - headers = {"Content-Length": len(self.data), |
818 | - "Content-MD5": calculate_md5(self.data), |
819 | + if self.body_producer: |
820 | + content_length = self.body_producer.length |
821 | + else: |
822 | + content_length = len(self.data) |
823 | + headers = {"Content-Length": content_length, |
824 | "Date": self.date} |
825 | + if self.body_producer is None: |
826 | + headers["Content-MD5"] = calculate_md5(self.data) |
827 | for key, value in self.metadata.iteritems(): |
828 | headers["x-amz-meta-" + key] = value |
829 | for key, value in self.amz_headers.iteritems(): |
830 | @@ -529,5 +587,6 @@ |
831 | self.endpoint, self.bucket, self.object_name) |
832 | d = self.get_page( |
833 | url_context.get_url(), method=self.action, postdata=self.data, |
834 | - headers=self.get_headers()) |
835 | + headers=self.get_headers(), body_producer=self.body_producer, |
836 | + receiver_factory=self.receiver_factory) |
837 | return d.addErrback(s3_error_wrapper) |
838 | |
839 | === modified file 'txaws/s3/model.py' |
840 | --- txaws/s3/model.py 2012-01-28 00:42:38 +0000 |
841 | +++ txaws/s3/model.py 2012-02-10 21:24:20 +0000 |
842 | @@ -150,3 +150,32 @@ |
843 | """ |
844 | root = XML(xml_bytes) |
845 | return cls(root.findtext("Payer")) |
846 | + |
847 | + |
848 | +class MultipartInitiationResponse(object): |
849 | + """ |
850 | + A response to Initiate Multipart Upload |
851 | + """ |
852 | + |
853 | + def __init__(self, bucket, object_name, upload_id): |
854 | + """ |
855 | + @param bucket: The bucket name |
856 | + @param object_name: The object name |
857 | + @param upload_id: The upload id |
858 | + """ |
859 | + self.bucket = bucket |
860 | + self.object_name = object_name |
861 | + self.upload_id = upload_id |
862 | + |
863 | + @classmethod |
864 | + def from_xml(cls, xml_bytes): |
865 | + """ |
866 | + Create an instance of this from XML bytes. |
867 | + |
868 | + @param xml_bytes: C{str} bytes of XML to parse |
869 | + @return: an instance of L{MultipartInitiationResponse} |
870 | + """ |
871 | + root = XML(xml_bytes) |
872 | + return cls(root.findtext('Bucket'), |
873 | + root.findtext('Key'), |
874 | + root.findtext('UploadId')) |
875 | |
876 | === modified file 'txaws/s3/tests/test_client.py' |
877 | --- txaws/s3/tests/test_client.py 2012-01-28 00:44:53 +0000 |
878 | +++ txaws/s3/tests/test_client.py 2012-02-10 21:24:20 +0000 |
879 | @@ -9,7 +9,8 @@ |
880 | else: |
881 | s3clientSkip = None |
882 | from txaws.s3.acls import AccessControlPolicy |
883 | -from txaws.s3.model import RequestPayment |
884 | +from txaws.s3.model import RequestPayment, MultipartInitiationResponse |
885 | +from txaws.testing.producers import StringBodyProducer |
886 | from txaws.service import AWSServiceEndpoint |
887 | from txaws.testing import payload |
888 | from txaws.testing.base import TXAWSTestCase |
889 | @@ -100,7 +101,8 @@ |
890 | |
891 | class StubQuery(client.Query): |
892 | |
893 | - def __init__(query, action, creds, endpoint): |
894 | + def __init__(query, action, creds, endpoint, |
895 | + body_producer=None, receiver_factory=None): |
896 | super(StubQuery, query).__init__( |
897 | action=action, creds=creds) |
898 | self.assertEquals(action, "GET") |
899 | @@ -134,7 +136,8 @@ |
900 | |
901 | class StubQuery(client.Query): |
902 | |
903 | - def __init__(query, action, creds, endpoint, bucket=None): |
904 | + def __init__(query, action, creds, endpoint, bucket=None, |
905 | + body_producer=None, receiver_factory=None): |
906 | super(StubQuery, query).__init__( |
907 | action=action, creds=creds, bucket=bucket) |
908 | self.assertEquals(action, "PUT") |
909 | @@ -156,7 +159,8 @@ |
910 | |
911 | class StubQuery(client.Query): |
912 | |
913 | - def __init__(query, action, creds, endpoint, bucket=None): |
914 | + def __init__(query, action, creds, endpoint, bucket=None, |
915 | + body_producer=None, receiver_factory=None): |
916 | super(StubQuery, query).__init__( |
917 | action=action, creds=creds, bucket=bucket) |
918 | self.assertEquals(action, "GET") |
919 | @@ -208,7 +212,8 @@ |
920 | class StubQuery(client.Query): |
921 | |
922 | def __init__(query, action, creds, endpoint, bucket=None, |
923 | - object_name=None): |
924 | + object_name=None, body_producer=None, |
925 | + receiver_factory=None): |
926 | super(StubQuery, query).__init__(action=action, creds=creds, |
927 | bucket=bucket, |
928 | object_name=object_name) |
929 | @@ -243,7 +248,8 @@ |
930 | class StubQuery(client.Query): |
931 | |
932 | def __init__(query, action, creds, endpoint, bucket=None, |
933 | - object_name=None): |
934 | + object_name=None, body_producer=None, |
935 | + receiver_factory=None): |
936 | super(StubQuery, query).__init__(action=action, creds=creds, |
937 | bucket=bucket, |
938 | object_name=object_name) |
939 | @@ -284,7 +290,8 @@ |
940 | class StubQuery(client.Query): |
941 | |
942 | def __init__(query, action, creds, endpoint, bucket=None, |
943 | - object_name=None): |
944 | + object_name=None, body_producer=None, |
945 | + receiver_factory=None): |
946 | super(StubQuery, query).__init__(action=action, creds=creds, |
947 | bucket=bucket, |
948 | object_name=object_name) |
949 | @@ -323,7 +330,8 @@ |
950 | class StubQuery(client.Query): |
951 | |
952 | def __init__(query, action, creds, endpoint, bucket=None, |
953 | - object_name=None): |
954 | + object_name=None, body_producer=None, |
955 | + receiver_factory=None): |
956 | super(StubQuery, query).__init__(action=action, creds=creds, |
957 | bucket=bucket, |
958 | object_name=object_name) |
959 | @@ -360,7 +368,8 @@ |
960 | class StubQuery(client.Query): |
961 | |
962 | def __init__(query, action, creds, endpoint, bucket=None, |
963 | - object_name=None): |
964 | + object_name=None, body_producer=None, |
965 | + receiver_factory=None): |
966 | super(StubQuery, query).__init__(action=action, creds=creds, |
967 | bucket=bucket, |
968 | object_name=object_name) |
969 | @@ -396,7 +405,8 @@ |
970 | class StubQuery(client.Query): |
971 | |
972 | def __init__(query, action, creds, endpoint, bucket=None, |
973 | - object_name=None): |
974 | + object_name=None, body_producer=None, |
975 | + receiver_factory=None): |
976 | super(StubQuery, query).__init__(action=action, creds=creds, |
977 | bucket=bucket, |
978 | object_name=object_name) |
979 | @@ -433,7 +443,8 @@ |
980 | class StubQuery(client.Query): |
981 | |
982 | def __init__(query, action, creds, endpoint, bucket=None, |
983 | - object_name=None): |
984 | + object_name=None, body_producer=None, |
985 | + receiver_factory=None): |
986 | super(StubQuery, query).__init__(action=action, creds=creds, |
987 | bucket=bucket, |
988 | object_name=object_name) |
989 | @@ -473,7 +484,8 @@ |
990 | class StubQuery(client.Query): |
991 | |
992 | def __init__(query, action, creds, endpoint, bucket=None, |
993 | - object_name=None): |
994 | + object_name=None, body_producer=None, |
995 | + receiver_factory=None): |
996 | super(StubQuery, query).__init__(action=action, creds=creds, |
997 | bucket=bucket, |
998 | object_name=object_name) |
999 | @@ -509,7 +521,8 @@ |
1000 | class StubQuery(client.Query): |
1001 | |
1002 | def __init__(query, action, creds, endpoint, bucket=None, |
1003 | - object_name=None): |
1004 | + object_name=None, body_producer=None, |
1005 | + receiver_factory=None): |
1006 | super(StubQuery, query).__init__(action=action, creds=creds, |
1007 | bucket=bucket, |
1008 | object_name=object_name) |
1009 | @@ -546,7 +559,8 @@ |
1010 | class StubQuery(client.Query): |
1011 | |
1012 | def __init__(query, action, creds, endpoint, bucket=None, |
1013 | - object_name=None): |
1014 | + object_name=None, body_producer=None, |
1015 | + receiver_factory=None): |
1016 | super(StubQuery, query).__init__(action=action, creds=creds, |
1017 | bucket=bucket, |
1018 | object_name=object_name) |
1019 | @@ -576,7 +590,8 @@ |
1020 | |
1021 | class StubQuery(client.Query): |
1022 | |
1023 | - def __init__(query, action, creds, endpoint, bucket=None): |
1024 | + def __init__(query, action, creds, endpoint, bucket=None, |
1025 | + body_producer=None, receiver_factory=None): |
1026 | super(StubQuery, query).__init__( |
1027 | action=action, creds=creds, bucket=bucket) |
1028 | self.assertEquals(action, "DELETE") |
1029 | @@ -599,7 +614,8 @@ |
1030 | class StubQuery(client.Query): |
1031 | |
1032 | def __init__(query, action, creds, endpoint, bucket=None, |
1033 | - object_name=None, data=""): |
1034 | + object_name=None, data="", body_producer=None, |
1035 | + receiver_factory=None): |
1036 | super(StubQuery, query).__init__(action=action, creds=creds, |
1037 | bucket=bucket, |
1038 | object_name=object_name, |
1039 | @@ -630,7 +646,8 @@ |
1040 | class StubQuery(client.Query): |
1041 | |
1042 | def __init__(query, action, creds, endpoint, bucket=None, |
1043 | - object_name=None, data=""): |
1044 | + object_name=None, data="", receiver_factory=None, |
1045 | + body_producer=None): |
1046 | super(StubQuery, query).__init__(action=action, creds=creds, |
1047 | bucket=bucket, |
1048 | object_name=object_name, |
1049 | @@ -665,7 +682,7 @@ |
1050 | |
1051 | def __init__(query, action, creds, endpoint, bucket=None, |
1052 | object_name=None, data=None, content_type=None, |
1053 | - metadata=None): |
1054 | + metadata=None, body_producer=None, receiver_factory=None): |
1055 | super(StubQuery, query).__init__( |
1056 | action=action, creds=creds, bucket=bucket, |
1057 | object_name=object_name, data=data, |
1058 | @@ -701,7 +718,7 @@ |
1059 | |
1060 | def __init__(query, action, creds, endpoint, bucket=None, |
1061 | object_name=None, data=None, content_type=None, |
1062 | - metadata=None): |
1063 | + metadata=None, body_producer=None, receiver_factory=None): |
1064 | super(StubQuery, query).__init__( |
1065 | action=action, creds=creds, bucket=bucket, |
1066 | object_name=object_name, data=data, |
1067 | @@ -730,7 +747,8 @@ |
1068 | |
1069 | def __init__(query, action, creds, endpoint, bucket=None, |
1070 | object_name=None, data=None, content_type=None, |
1071 | - metadata=None, amz_headers=None): |
1072 | + metadata=None, amz_headers=None, body_producer=None, |
1073 | + receiver_factory=None): |
1074 | super(StubQuery, query).__init__( |
1075 | action=action, creds=creds, bucket=bucket, |
1076 | object_name=object_name, data=data, |
1077 | @@ -756,6 +774,42 @@ |
1078 | metadata={"key": "some meta data"}, |
1079 | amz_headers={"acl": "public-read"}) |
1080 | |
1081 | + def test_put_object_with_custom_body_producer(self): |
1082 | + |
1083 | + class StubQuery(client.Query): |
1084 | + |
1085 | + def __init__(query, action, creds, endpoint, bucket=None, |
1086 | + object_name=None, data=None, content_type=None, |
1087 | + metadata=None, amz_headers=None, body_producer=None, |
1088 | + receiver_factory=None): |
1089 | + super(StubQuery, query).__init__( |
1090 | + action=action, creds=creds, bucket=bucket, |
1091 | + object_name=object_name, data=data, |
1092 | + content_type=content_type, metadata=metadata, |
1093 | + amz_headers=amz_headers, body_producer=body_producer) |
1094 | + self.assertEqual(action, "PUT") |
1095 | + self.assertEqual(creds.access_key, "foo") |
1096 | + self.assertEqual(creds.secret_key, "bar") |
1097 | + self.assertEqual(query.bucket, "mybucket") |
1098 | + self.assertEqual(query.object_name, "objectname") |
1099 | + self.assertEqual(query.content_type, "text/plain") |
1100 | + self.assertEqual(query.metadata, {"key": "some meta data"}) |
1101 | + self.assertEqual(query.amz_headers, {"acl": "public-read"}) |
1102 | + self.assertIdentical(body_producer, string_producer) |
1103 | + |
1104 | + def submit(query): |
1105 | + return succeed(None) |
1106 | + |
1107 | + |
1108 | + string_producer = StringBodyProducer("some data") |
1109 | + creds = AWSCredentials("foo", "bar") |
1110 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1111 | + return s3.put_object("mybucket", "objectname", |
1112 | + content_type="text/plain", |
1113 | + metadata={"key": "some meta data"}, |
1114 | + amz_headers={"acl": "public-read"}, |
1115 | + body_producer=string_producer) |
1116 | + |
1117 | def test_copy_object(self): |
1118 | """ |
1119 | L{S3Client.copy_object} creates a L{Query} to copy an object from one |
1120 | @@ -766,7 +820,8 @@ |
1121 | |
1122 | def __init__(query, action, creds, endpoint, bucket=None, |
1123 | object_name=None, data=None, content_type=None, |
1124 | - metadata=None, amz_headers=None): |
1125 | + metadata=None, amz_headers=None, body_producer=None, |
1126 | + receiver_factory=None): |
1127 | super(StubQuery, query).__init__( |
1128 | action=action, creds=creds, bucket=bucket, |
1129 | object_name=object_name, data=data, |
1130 | @@ -798,7 +853,8 @@ |
1131 | |
1132 | def __init__(query, action, creds, endpoint, bucket=None, |
1133 | object_name=None, data=None, content_type=None, |
1134 | - metadata=None, amz_headers=None): |
1135 | + metadata=None, amz_headers=None, body_producer=None, |
1136 | + receiver_factory=None): |
1137 | super(StubQuery, query).__init__( |
1138 | action=action, creds=creds, bucket=bucket, |
1139 | object_name=object_name, data=data, |
1140 | @@ -822,7 +878,7 @@ |
1141 | |
1142 | def __init__(query, action, creds, endpoint, bucket=None, |
1143 | object_name=None, data=None, content_type=None, |
1144 | - metadata=None): |
1145 | + metadata=None, body_producer=None, receiver_factory=None): |
1146 | super(StubQuery, query).__init__( |
1147 | action=action, creds=creds, bucket=bucket, |
1148 | object_name=object_name, data=data, |
1149 | @@ -846,7 +902,7 @@ |
1150 | |
1151 | def __init__(query, action, creds, endpoint, bucket=None, |
1152 | object_name=None, data=None, content_type=None, |
1153 | - metadata=None): |
1154 | + metadata=None, body_producer=None, receiver_factory=None): |
1155 | super(StubQuery, query).__init__( |
1156 | action=action, creds=creds, bucket=bucket, |
1157 | object_name=object_name, data=data, |
1158 | @@ -869,7 +925,8 @@ |
1159 | class StubQuery(client.Query): |
1160 | |
1161 | def __init__(query, action, creds, endpoint, bucket=None, |
1162 | - object_name=None, data=""): |
1163 | + object_name=None, data="", body_producer=None, |
1164 | + receiver_factory=None): |
1165 | super(StubQuery, query).__init__(action=action, creds=creds, |
1166 | bucket=bucket, |
1167 | object_name=object_name, |
1168 | @@ -902,7 +959,8 @@ |
1169 | class StubQuery(client.Query): |
1170 | |
1171 | def __init__(query, action, creds, endpoint, bucket=None, |
1172 | - object_name=None, data=""): |
1173 | + object_name=None, data="", body_producer=None, |
1174 | + receiver_factory=None): |
1175 | super(StubQuery, query).__init__(action=action, creds=creds, |
1176 | bucket=bucket, |
1177 | object_name=object_name, |
1178 | @@ -926,6 +984,68 @@ |
1179 | deferred = s3.get_object_acl("mybucket", "myobject") |
1180 | return deferred.addCallback(check_result) |
1181 | |
1182 | + def test_init_multipart_upload(self): |
1183 | + |
1184 | + class StubQuery(client.Query): |
1185 | + |
1186 | + def __init__(query, action, creds, endpoint, bucket=None, |
1187 | + object_name=None, data="", body_producer=None, |
1188 | + content_type=None, receiver_factory=None, metadata={}): |
1189 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1190 | + bucket=bucket, |
1191 | + object_name=object_name, |
1192 | + data=data) |
1193 | + self.assertEquals(action, "POST") |
1194 | + self.assertEqual(creds.access_key, "foo") |
1195 | + self.assertEqual(creds.secret_key, "bar") |
1196 | + self.assertEqual(query.bucket, "example-bucket") |
1197 | + self.assertEqual(query.object_name, "example-object?uploads") |
1198 | + self.assertEqual(query.data, "") |
1199 | + self.assertEqual(query.metadata, {}) |
1200 | + |
1201 | + def submit(query, url_context=None): |
1202 | + return succeed(payload.sample_s3_init_multipart_upload_result) |
1203 | + |
1204 | + |
1205 | + def check_result(result): |
1206 | + self.assert_(isinstance(result, MultipartInitiationResponse)) |
1207 | + self.assertEqual(result.bucket, "example-bucket") |
1208 | + self.assertEqual(result.object_name, "example-object") |
1209 | + self.assertEqual(result.upload_id, "deadbeef") |
1210 | + |
1211 | + creds = AWSCredentials("foo", "bar") |
1212 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1213 | + deferred = s3.init_multipart_upload("example-bucket", "example-object") |
1214 | + return deferred.addCallback(check_result) |
1215 | + |
1216 | + def test_upload_part(self): |
1217 | + |
1218 | + class StubQuery(client.Query): |
1219 | + |
1220 | + def __init__(query, action, creds, endpoint, bucket=None, |
1221 | + object_name=None, data="", body_producer=None, |
1222 | + content_type=None, receiver_factory=None, metadata={}): |
1223 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1224 | + bucket=bucket, |
1225 | + object_name=object_name, |
1226 | + data=data) |
1227 | + self.assertEquals(action, "PUT") |
1228 | + self.assertEqual(creds.access_key, "foo") |
1229 | + self.assertEqual(creds.secret_key, "bar") |
1230 | + self.assertEqual(query.bucket, "example-bucket") |
1231 | + self.assertEqual(query.object_name, |
1232 | + "example-object?partNumber=3&uploadId=testid") |
1233 | + self.assertEqual(query.data, "some data") |
1234 | + self.assertEqual(query.metadata, {}) |
1235 | + |
1236 | + def submit(query, url_context=None): |
1237 | + return succeed(None) |
1238 | + |
1239 | + creds = AWSCredentials("foo", "bar") |
1240 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1241 | + return s3.upload_part("example-bucket", "example-object", "testid", 3, |
1242 | + "some data") |
1243 | + |
1244 | S3ClientTestCase.skip = s3clientSkip |
1245 | |
1246 | |
1247 | @@ -1077,7 +1197,8 @@ |
1248 | """ |
1249 | class StubQuery(client.Query): |
1250 | |
1251 | - def __init__(query, action, creds, endpoint, bucket): |
1252 | + def __init__(query, action, creds, endpoint, bucket, |
1253 | + body_producer=None, receiver_factory=None): |
1254 | super(StubQuery, query).__init__( |
1255 | action=action, creds=creds, bucket=bucket) |
1256 | self.assertEquals(action, "GET") |
1257 | |
1258 | === modified file 'txaws/testing/payload.py' |
1259 | --- txaws/testing/payload.py 2012-01-28 00:39:00 +0000 |
1260 | +++ txaws/testing/payload.py 2012-02-10 21:24:20 +0000 |
1261 | @@ -1085,3 +1085,12 @@ |
1262 | <Status>Enabled</Status> |
1263 | <MfaDelete>Disabled</MfaDelete> |
1264 | </VersioningConfiguration>""" |
1265 | + |
1266 | +sample_s3_init_multipart_upload_result = """\ |
1267 | +<InitiateMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"> |
1268 | + <Bucket>example-bucket</Bucket> |
1269 | + <Key>example-object</Key> |
1270 | + <UploadId>deadbeef</UploadId> |
1271 | +</InitiateMultipartUploadResult>""" |
1272 | + |
1273 | + |
1274 | |
1275 | === added file 'txaws/testing/producers.py' |
1276 | --- txaws/testing/producers.py 1970-01-01 00:00:00 +0000 |
1277 | +++ txaws/testing/producers.py 2012-02-10 21:24:20 +0000 |
1278 | @@ -0,0 +1,23 @@ |
1279 | +from zope.interface import implements |
1280 | + |
1281 | +from twisted.internet.defer import succeed |
1282 | +from twisted.web.iweb import IBodyProducer |
1283 | + |
1284 | +class StringBodyProducer(object): |
1285 | + implements(IBodyProducer) |
1286 | + |
1287 | + def __init__(self, data): |
1288 | + self.data = data |
1289 | + self.length = len(data) |
1290 | + self.written = None |
1291 | + |
1292 | + def startProducing(self, consumer): |
1293 | + consumer.write(self.data) |
1294 | + self.written = self.data |
1295 | + return succeed(None) |
1296 | + |
1297 | + def pauseProducing(self): |
1298 | + pass |
1299 | + |
1300 | + def stopProducing(self): |
1301 | + pass |
Merge away!