Merge lp:~djfroofy/txaws/921418-initmultipart into lp:txaws
- 921418-initmultipart
- Merge into trunk
Proposed by
Drew Smathers
Status: | Superseded | ||||
---|---|---|---|---|---|
Proposed branch: | lp:~djfroofy/txaws/921418-initmultipart | ||||
Merge into: | lp:txaws | ||||
Diff against target: |
1263 lines (+610/-82) 9 files modified
txaws/client/_producers.py (+122/-0) txaws/client/base.py (+159/-14) txaws/client/tests/test_base.py (+53/-11) txaws/client/tests/test_ssl.py (+20/-4) txaws/s3/client.py (+71/-26) txaws/s3/model.py (+29/-0) txaws/s3/tests/test_client.py (+124/-27) txaws/testing/payload.py (+9/-0) txaws/testing/producers.py (+23/-0) |
||||
To merge this branch: | bzr merge lp:~djfroofy/txaws/921418-initmultipart | ||||
Related bugs: |
|
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Duncan McGreggor | Pending | ||
Review via email: mp+92577@code.launchpad.net |
This proposal has been superseded by a proposal from 2012-02-15.
Commit message
Description of the change
Note this depends on my other branches:
modernize-924459
newagent-767205
To post a comment you must log in.
- 145. By Drew Smathers
-
add amz_headers argument to init_multipart_upload command
Unmerged revisions
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === added file 'txaws/client/_producers.py' |
2 | --- txaws/client/_producers.py 1970-01-01 00:00:00 +0000 |
3 | +++ txaws/client/_producers.py 2012-02-15 22:12:17 +0000 |
4 | @@ -0,0 +1,122 @@ |
5 | +import os |
6 | + |
7 | +from zope.interface import implements |
8 | + |
9 | +from twisted.internet import defer, task |
10 | +from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer |
11 | + |
12 | + |
13 | +# Code below for FileBodyProducer cut-and-paste from twisted source. |
14 | +# Currently this is not released so here temporarily for forward compat. |
15 | + |
16 | + |
17 | +class FileBodyProducer(object): |
18 | + """ |
19 | + L{FileBodyProducer} produces bytes from an input file object incrementally |
20 | + and writes them to a consumer. |
21 | + |
22 | + Since file-like objects cannot be read from in an event-driven manner, |
23 | + L{FileBodyProducer} uses a L{Cooperator} instance to schedule reads from |
24 | + the file. This process is also paused and resumed based on notifications |
25 | + from the L{IConsumer} provider being written to. |
26 | + |
27 | + The file is closed after it has been read, or if the producer is stopped |
28 | + early. |
29 | + |
30 | + @ivar _inputFile: Any file-like object, bytes read from which will be |
31 | + written to a consumer. |
32 | + |
33 | + @ivar _cooperate: A method like L{Cooperator.cooperate} which is used to |
34 | + schedule all reads. |
35 | + |
36 | + @ivar _readSize: The number of bytes to read from C{_inputFile} at a time. |
37 | + """ |
38 | + implements(IBodyProducer) |
39 | + |
40 | + # Python 2.4 doesn't have these symbolic constants |
41 | + _SEEK_SET = getattr(os, 'SEEK_SET', 0) |
42 | + _SEEK_END = getattr(os, 'SEEK_END', 2) |
43 | + |
44 | + def __init__(self, inputFile, cooperator=task, readSize=2 ** 16): |
45 | + self._inputFile = inputFile |
46 | + self._cooperate = cooperator.cooperate |
47 | + self._readSize = readSize |
48 | + self.length = self._determineLength(inputFile) |
49 | + |
50 | + |
51 | + def _determineLength(self, fObj): |
52 | + """ |
53 | + Determine how many bytes can be read out of C{fObj} (assuming it is not |
54 | + modified from this point on). If the determination cannot be made, |
55 | + return C{UNKNOWN_LENGTH}. |
56 | + """ |
57 | + try: |
58 | + seek = fObj.seek |
59 | + tell = fObj.tell |
60 | + except AttributeError: |
61 | + return UNKNOWN_LENGTH |
62 | + originalPosition = tell() |
63 | + seek(0, self._SEEK_END) |
64 | + end = tell() |
65 | + seek(originalPosition, self._SEEK_SET) |
66 | + return end - originalPosition |
67 | + |
68 | + |
69 | + def stopProducing(self): |
70 | + """ |
71 | + Permanently stop writing bytes from the file to the consumer by |
72 | + stopping the underlying L{CooperativeTask}. |
73 | + """ |
74 | + self._inputFile.close() |
75 | + self._task.stop() |
76 | + |
77 | + |
78 | + def startProducing(self, consumer): |
79 | + """ |
80 | + Start a cooperative task which will read bytes from the input file and |
81 | + write them to C{consumer}. Return a L{Deferred} which fires after all |
82 | + bytes have been written. |
83 | + |
84 | + @param consumer: Any L{IConsumer} provider |
85 | + """ |
86 | + self._task = self._cooperate(self._writeloop(consumer)) |
87 | + d = self._task.whenDone() |
88 | + def maybeStopped(reason): |
89 | + # IBodyProducer.startProducing's Deferred isn't supposed to fire if |
90 | + # stopProducing is called. |
91 | + reason.trap(task.TaskStopped) |
92 | + return defer.Deferred() |
93 | + d.addCallbacks(lambda ignored: None, maybeStopped) |
94 | + return d |
95 | + |
96 | + |
97 | + def _writeloop(self, consumer): |
98 | + """ |
99 | + Return an iterator which reads one chunk of bytes from the input file |
100 | + and writes them to the consumer for each time it is iterated. |
101 | + """ |
102 | + while True: |
103 | + bytes = self._inputFile.read(self._readSize) |
104 | + if not bytes: |
105 | + self._inputFile.close() |
106 | + break |
107 | + consumer.write(bytes) |
108 | + yield None |
109 | + |
110 | + |
111 | + def pauseProducing(self): |
112 | + """ |
113 | + Temporarily suspend copying bytes from the input file to the consumer |
114 | + by pausing the L{CooperativeTask} which drives that activity. |
115 | + """ |
116 | + self._task.pause() |
117 | + |
118 | + |
119 | + def resumeProducing(self): |
120 | + """ |
121 | + Undo the effects of a previous C{pauseProducing} and resume copying |
122 | + bytes to the consumer by resuming the L{CooperativeTask} which drives |
123 | + the write activity. |
124 | + """ |
125 | + self._task.resume() |
126 | + |
127 | |
128 | === modified file 'txaws/client/base.py' |
129 | --- txaws/client/base.py 2011-11-29 08:17:54 +0000 |
130 | +++ txaws/client/base.py 2012-02-15 22:12:17 +0000 |
131 | @@ -3,10 +3,25 @@ |
132 | except ImportError: |
133 | from xml.parsers.expat import ExpatError as ParseError |
134 | |
135 | +import warnings |
136 | +from StringIO import StringIO |
137 | + |
138 | from twisted.internet.ssl import ClientContextFactory |
139 | +from twisted.internet.protocol import Protocol |
140 | +from twisted.internet.defer import Deferred, succeed |
141 | +from twisted.python import failure |
142 | from twisted.web import http |
143 | +from twisted.web.iweb import UNKNOWN_LENGTH |
144 | from twisted.web.client import HTTPClientFactory |
145 | +from twisted.web.client import Agent |
146 | +from twisted.web.client import ResponseDone |
147 | +from twisted.web.http import NO_CONTENT |
148 | +from twisted.web.http_headers import Headers |
149 | from twisted.web.error import Error as TwistedWebError |
150 | +try: |
151 | + from twisted.web.client import FileBodyProducer |
152 | +except ImportError: |
153 | + from txaws.client._producers import FileBodyProducer |
154 | |
155 | from txaws.util import parse |
156 | from txaws.credentials import AWSCredentials |
157 | @@ -59,9 +74,10 @@ |
158 | @param query_factory: The class or function that produces a query |
159 | object for making requests to the EC2 service. |
160 | @param parser: A parser object for parsing responses from the EC2 service. |
161 | + @param receiver_factory: Factory for receiving responses from EC2 service. |
162 | """ |
163 | def __init__(self, creds=None, endpoint=None, query_factory=None, |
164 | - parser=None): |
165 | + parser=None, receiver_factory=None): |
166 | if creds is None: |
167 | creds = AWSCredentials() |
168 | if endpoint is None: |
169 | @@ -69,22 +85,109 @@ |
170 | self.creds = creds |
171 | self.endpoint = endpoint |
172 | self.query_factory = query_factory |
173 | + self.receiver_factory = receiver_factory |
174 | self.parser = parser |
175 | |
176 | +class StreamingError(Exception): |
177 | + """ |
178 | + Raised if more data or less data is received than expected. |
179 | + """ |
180 | + |
181 | + |
182 | +class StringIOBodyReceiver(Protocol): |
183 | + """ |
184 | + Simple StringIO-based HTTP response body receiver. |
185 | + |
186 | + TODO: perhaps there should be an interface specifying why |
187 | + finished (Deferred) and content_length are necessary and |
188 | + how to use them; e.g. callback/errback finished on completion. |
189 | + """ |
190 | + finished = None |
191 | + content_length = None |
192 | + |
193 | + def __init__(self): |
194 | + self._buffer = StringIO() |
195 | + self._received = 0 |
196 | + |
197 | + def dataReceived(self, bytes): |
198 | + streaming = self.content_length is UNKNOWN_LENGTH |
199 | + if not streaming and (self._received > self.content_length): |
200 | + self.transport.loseConnection() |
201 | + raise StreamingError( |
202 | + "Buffer overflow - received more data than " |
203 | + "Content-Length dictated: %d" % self.content_length) |
204 | + # TODO should be some limit on how much we receive |
205 | + self._buffer.write(bytes) |
206 | + self._received += len(bytes) |
207 | + |
208 | + def connectionLost(self, reason): |
209 | + reason.trap(ResponseDone) |
210 | + d = self.finished |
211 | + self.finished = None |
212 | + streaming = self.content_length is UNKNOWN_LENGTH |
213 | + if streaming or (self._received == self.content_length): |
214 | + d.callback(self._buffer.getvalue()) |
215 | + else: |
216 | + f = failure.Failure(StreamingError("Connection lost before " |
217 | + "receiving all data")) |
218 | + d.errback(f) |
219 | + |
220 | + |
221 | +class WebClientContextFactory(ClientContextFactory): |
222 | + |
223 | + def getContext(self, hostname, port): |
224 | + return ClientContextFactory.getContext(self) |
225 | + |
226 | + |
227 | +class WebVerifyingContextFactory(VerifyingContextFactory): |
228 | + |
229 | + def getContext(self, hostname, port): |
230 | + return VerifyingContextFactory.getContext(self) |
231 | + |
232 | + |
233 | +class FakeClient(object): |
234 | + """ |
235 | + XXX |
236 | + A fake client object for some degree of backwards compatibility for |
237 | + code using the client attribute on BaseQuery to check url, status |
238 | + etc. |
239 | + """ |
240 | + url = None |
241 | + status = None |
242 | |
243 | class BaseQuery(object): |
244 | |
245 | - def __init__(self, action=None, creds=None, endpoint=None, reactor=None): |
246 | + def __init__(self, action=None, creds=None, endpoint=None, reactor=None, |
247 | + body_producer=None, receiver_factory=None): |
248 | if not action: |
249 | raise TypeError("The query requires an action parameter.") |
250 | - self.factory = HTTPClientFactory |
251 | self.action = action |
252 | self.creds = creds |
253 | self.endpoint = endpoint |
254 | if reactor is None: |
255 | from twisted.internet import reactor |
256 | self.reactor = reactor |
257 | - self.client = None |
258 | + self._client = None |
259 | + self.request_headers = None |
260 | + self.response_headers = None |
261 | + self.body_producer = body_producer |
262 | + self.receiver_factory = receiver_factory or StringIOBodyReceiver |
263 | + |
264 | + @property |
265 | + def client(self): |
266 | + if self._client is None: |
267 | + self._client_deprecation_warning() |
268 | + self._client = FakeClient() |
269 | + return self._client |
270 | + |
271 | + @client.setter |
272 | + def client(self, value): |
273 | + self._client_deprecation_warning() |
274 | + self._client = value |
275 | + |
276 | + def _client_deprecation_warning(self): |
277 | + warnings.warn('The client attribute on BaseQuery is deprecated and' |
278 | + ' will go away in future release.') |
279 | |
280 | def get_page(self, url, *args, **kwds): |
281 | """ |
282 | @@ -95,16 +198,39 @@ |
283 | """ |
284 | contextFactory = None |
285 | scheme, host, port, path = parse(url) |
286 | - self.client = self.factory(url, *args, **kwds) |
287 | + data = kwds.get('postdata', None) |
288 | + self._method = method = kwds.get('method', 'GET') |
289 | + self.request_headers = self._headers(kwds.get('headers', {})) |
290 | + if (self.body_producer is None) and (data is not None): |
291 | + self.body_producer = FileBodyProducer(StringIO(data)) |
292 | if scheme == "https": |
293 | if self.endpoint.ssl_hostname_verification: |
294 | - contextFactory = VerifyingContextFactory(host) |
295 | + contextFactory = WebVerifyingContextFactory(host) |
296 | else: |
297 | - contextFactory = ClientContextFactory() |
298 | - self.reactor.connectSSL(host, port, self.client, contextFactory) |
299 | + contextFactory = WebClientContextFactory() |
300 | + agent = Agent(self.reactor, contextFactory) |
301 | + self.client.url = url |
302 | + d = agent.request(method, url, self.request_headers, |
303 | + self.body_producer) |
304 | else: |
305 | - self.reactor.connectTCP(host, port, self.client) |
306 | - return self.client.deferred |
307 | + agent = Agent(self.reactor) |
308 | + d = agent.request(method, url, self.request_headers, |
309 | + self.body_producer) |
310 | + d.addCallback(self._handle_response) |
311 | + return d |
312 | + |
313 | + def _headers(self, headers_dict): |
314 | + """ |
315 | + Convert dictionary of headers into twisted.web.client.Headers object. |
316 | + """ |
317 | + return Headers(dict((k,[v]) for (k,v) in headers_dict.items())) |
318 | + |
319 | + def _unpack_headers(self, headers): |
320 | + """ |
321 | + Unpack twisted.web.client.Headers object to dict. This is to provide |
322 | + backwards compatibility. |
323 | + """ |
324 | + return dict((k,v[0]) for (k,v) in headers.getAllRawHeaders()) |
325 | |
326 | def get_request_headers(self, *args, **kwds): |
327 | """ |
328 | @@ -114,8 +240,26 @@ |
329 | The AWS S3 API depends upon setting headers. This method is provided as |
330 | a convenience for debugging issues with the S3 communications. |
331 | """ |
332 | - if self.client: |
333 | - return self.client.headers |
334 | + if self.request_headers: |
335 | + return self._unpack_headers(self.request_headers) |
336 | + |
337 | + def _handle_response(self, response): |
338 | + """ |
339 | + Handle the HTTP response by memoing the headers and then delivering |
340 | + bytes. |
341 | + """ |
342 | + self.client.status = response.code |
343 | + self.response_headers = headers = response.headers |
344 | + # XXX This is a workaround (which itself needs improvement) for a |
345 | + # possible bug in Twisted with the new client: |
346 | + # http://twistedmatrix.com/trac/ticket/5476 |
347 | + if self._method.upper() == 'HEAD' or response.code == NO_CONTENT: |
348 | + return succeed('') |
349 | + receiver = self.receiver_factory() |
350 | + receiver.finished = d = Deferred() |
351 | + receiver.content_length = response.length |
352 | + response.deliverBody(receiver) |
353 | + return d |
354 | |
355 | def get_response_headers(self, *args, **kwargs): |
356 | """ |
357 | @@ -125,5 +269,6 @@ |
358 | The AWS S3 API depends upon setting headers. This method is used by the |
359 | head_object API call for getting a S3 object's metadata. |
360 | """ |
361 | - if self.client: |
362 | - return self.client.response_headers |
363 | + if self.response_headers: |
364 | + return self._unpack_headers(self.response_headers) |
365 | + |
366 | |
367 | === modified file 'txaws/client/tests/test_base.py' |
368 | --- txaws/client/tests/test_base.py 2012-01-26 18:43:48 +0000 |
369 | +++ txaws/client/tests/test_base.py 2012-02-15 22:12:17 +0000 |
370 | @@ -1,6 +1,9 @@ |
371 | import os |
372 | |
373 | +from zope.interface import implements |
374 | + |
375 | from twisted.internet import reactor |
376 | +from twisted.internet.defer import succeed |
377 | from twisted.internet.error import ConnectionRefusedError |
378 | from twisted.protocols.policies import WrappingFactory |
379 | from twisted.python import log |
380 | @@ -8,14 +11,16 @@ |
381 | from twisted.python.failure import Failure |
382 | from twisted.test.test_sslverify import makeCertificate |
383 | from twisted.web import server, static |
384 | +from twisted.web.iweb import IBodyProducer |
385 | from twisted.web.client import HTTPClientFactory |
386 | from twisted.web.error import Error as TwistedWebError |
387 | |
388 | from txaws.client import ssl |
389 | from txaws.client.base import BaseClient, BaseQuery, error_wrapper |
390 | +from txaws.client.base import StringIOBodyReceiver |
391 | from txaws.service import AWSServiceEndpoint |
392 | from txaws.testing.base import TXAWSTestCase |
393 | - |
394 | +from txaws.testing.producers import StringBodyProducer |
395 | |
396 | class ErrorWrapperTestCase(TXAWSTestCase): |
397 | |
398 | @@ -99,7 +104,6 @@ |
399 | |
400 | def test_creation(self): |
401 | query = BaseQuery("an action", "creds", "http://endpoint") |
402 | - self.assertEquals(query.factory, HTTPClientFactory) |
403 | self.assertEquals(query.action, "an action") |
404 | self.assertEquals(query.creds, "creds") |
405 | self.assertEquals(query.endpoint, "http://endpoint") |
406 | @@ -142,16 +146,52 @@ |
407 | def test_get_response_headers_with_client(self): |
408 | |
409 | def check_results(results): |
410 | + #self.assertEquals(sorted(results.keys()), [ |
411 | + # "accept-ranges", "content-length", "content-type", "date", |
412 | + # "last-modified", "server"]) |
413 | + # XXX I think newclient excludes content-length from headers? |
414 | + # Also the header names are capitalized ... do we need to worry |
415 | + # about backwards compat? |
416 | self.assertEquals(sorted(results.keys()), [ |
417 | - "accept-ranges", "content-length", "content-type", "date", |
418 | - "last-modified", "server"]) |
419 | - self.assertEquals(len(results.values()), 6) |
420 | + "Accept-Ranges", "Content-Type", "Date", |
421 | + "Last-Modified", "Server"]) |
422 | + self.assertEquals(len(results.values()), 5) |
423 | |
424 | query = BaseQuery("an action", "creds", "http://endpoint") |
425 | d = query.get_page(self._get_url("file")) |
426 | d.addCallback(query.get_response_headers) |
427 | return d.addCallback(check_results) |
428 | |
429 | + def test_custom_body_producer(self): |
430 | + |
431 | + def check_producer_was_used(ignore): |
432 | + self.assertEqual(producer.written, 'test data') |
433 | + |
434 | + producer = StringBodyProducer('test data') |
435 | + query = BaseQuery("an action", "creds", "http://endpoint", |
436 | + body_producer=producer) |
437 | + d = query.get_page(self._get_url("file"), method='PUT') |
438 | + return d.addCallback(check_producer_was_used) |
439 | + |
440 | + def test_custom_receiver_factory(self): |
441 | + |
442 | + class TestReceiverProtocol(StringIOBodyReceiver): |
443 | + used = False |
444 | + |
445 | + def __init__(self): |
446 | + StringIOBodyReceiver.__init__(self) |
447 | + TestReceiverProtocol.used = True |
448 | + |
449 | + def check_used(ignore): |
450 | + self.assert_(TestReceiverProtocol.used) |
451 | + |
452 | + query = BaseQuery("an action", "creds", "http://endpoint", |
453 | + receiver_factory=TestReceiverProtocol) |
454 | + d = query.get_page(self._get_url("file")) |
455 | + d.addCallback(self.assertEquals, "0123456789") |
456 | + d.addCallback(check_used) |
457 | + return d |
458 | + |
459 | # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test |
460 | # will fail; instead, let's create some certs in a temp directory and set |
461 | # the DEFAULT_CERT_PATH to point there. |
462 | @@ -167,8 +207,9 @@ |
463 | def __init__(self): |
464 | self.connects = [] |
465 | |
466 | - def connectSSL(self, host, port, client, factory): |
467 | - self.connects.append((host, port, client, factory)) |
468 | + def connectSSL(self, host, port, factory, contextFactory, timeout, |
469 | + bindAddress): |
470 | + self.connects.append((host, port, factory, contextFactory)) |
471 | |
472 | certs = makeCertificate(O="Test Certificate", CN="something")[1] |
473 | self.patch(ssl, "_ca_certs", certs) |
474 | @@ -176,9 +217,10 @@ |
475 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
476 | query = BaseQuery("an action", "creds", endpoint, fake_reactor) |
477 | query.get_page("https://example.com/file") |
478 | - [(host, port, client, factory)] = fake_reactor.connects |
479 | + [(host, port, factory, contextFactory)] = fake_reactor.connects |
480 | self.assertEqual("example.com", host) |
481 | self.assertEqual(443, port) |
482 | - self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) |
483 | - self.assertEqual("example.com", factory.host) |
484 | - self.assertNotEqual([], factory.caCerts) |
485 | + wrappedFactory = contextFactory._webContext |
486 | + self.assertTrue(isinstance(wrappedFactory, ssl.VerifyingContextFactory)) |
487 | + self.assertEqual("example.com", wrappedFactory.host) |
488 | + self.assertNotEqual([], wrappedFactory.caCerts) |
489 | |
490 | === modified file 'txaws/client/tests/test_ssl.py' |
491 | --- txaws/client/tests/test_ssl.py 2012-01-26 22:54:44 +0000 |
492 | +++ txaws/client/tests/test_ssl.py 2012-02-15 22:12:17 +0000 |
493 | @@ -12,6 +12,10 @@ |
494 | from twisted.python.filepath import FilePath |
495 | from twisted.test.test_sslverify import makeCertificate |
496 | from twisted.web import server, static |
497 | +try: |
498 | + from twisted.web.client import ResponseFailed |
499 | +except ImportError: |
500 | + from twisted.web._newclient import ResponseFailed |
501 | |
502 | from txaws import exception |
503 | from txaws.client import ssl |
504 | @@ -32,6 +36,11 @@ |
505 | PUBSANKEY = sibpath("public_san.ssl") |
506 | |
507 | |
508 | +class WebDefaultOpenSSLContextFactory(DefaultOpenSSLContextFactory): |
509 | + def getContext(self, hostname=None, port=None): |
510 | + return DefaultOpenSSLContextFactory.getContext(self) |
511 | + |
512 | + |
513 | class BaseQuerySSLTestCase(TXAWSTestCase): |
514 | |
515 | def setUp(self): |
516 | @@ -75,7 +84,7 @@ |
517 | The L{VerifyingContextFactory} properly allows to connect to the |
518 | endpoint if the certificates match. |
519 | """ |
520 | - context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
521 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) |
522 | self.port = reactor.listenSSL( |
523 | 0, self.site, context_factory, interface="127.0.0.1") |
524 | self.portno = self.port.getHost().port |
525 | @@ -90,7 +99,7 @@ |
526 | The L{VerifyingContextFactory} fails with a SSL error the certificates |
527 | can't be checked. |
528 | """ |
529 | - context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
530 | + context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) |
531 | self.port = reactor.listenSSL( |
532 | 0, self.site, context_factory, interface="127.0.0.1") |
533 | self.portno = self.port.getHost().port |
534 | @@ -98,7 +107,14 @@ |
535 | endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) |
536 | query = BaseQuery("an action", "creds", endpoint) |
537 | d = query.get_page(self._get_url("file")) |
538 | - return self.assertFailure(d, SSLError) |
539 | + def fail(ignore): |
540 | + self.fail('Expected SSLError') |
541 | + def check_exception(why): |
542 | + # XXX kind of a mess here ... need to unwrap the |
543 | + # exception and check |
544 | + root_exc = why.value[0][0].value |
545 | + self.assert_(isinstance(root_exc, SSLError)) |
546 | + return d.addCallbacks(fail, check_exception) |
547 | |
548 | def test_ssl_verification_bypassed(self): |
549 | """ |
550 | @@ -121,7 +137,7 @@ |
551 | L{VerifyingContextFactory} supports checking C{subjectAltName} in the |
552 | certificate if it's available. |
553 | """ |
554 | - context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
555 | + context_factory = WebDefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) |
556 | self.port = reactor.listenSSL( |
557 | 0, self.site, context_factory, interface="127.0.0.1") |
558 | self.portno = self.port.getHost().port |
559 | |
560 | === modified file 'txaws/s3/client.py' |
561 | --- txaws/s3/client.py 2012-01-28 00:39:00 +0000 |
562 | +++ txaws/s3/client.py 2012-02-15 22:12:17 +0000 |
563 | @@ -23,7 +23,7 @@ |
564 | from txaws.s3.model import ( |
565 | Bucket, BucketItem, BucketListing, ItemOwner, LifecycleConfiguration, |
566 | LifecycleConfigurationRule, NotificationConfiguration, RequestPayment, |
567 | - VersioningConfiguration, WebsiteConfiguration) |
568 | + VersioningConfiguration, WebsiteConfiguration, MultipartInitiationResponse) |
569 | from txaws.s3.exception import S3Error |
570 | from txaws.service import AWSServiceEndpoint, S3_ENDPOINT |
571 | from txaws.util import XML, calculate_md5 |
572 | @@ -74,10 +74,12 @@ |
573 | class S3Client(BaseClient): |
574 | """A client for S3.""" |
575 | |
576 | - def __init__(self, creds=None, endpoint=None, query_factory=None): |
577 | + def __init__(self, creds=None, endpoint=None, query_factory=None, |
578 | + receiver_factory=None): |
579 | if query_factory is None: |
580 | query_factory = Query |
581 | - super(S3Client, self).__init__(creds, endpoint, query_factory) |
582 | + super(S3Client, self).__init__(creds, endpoint, query_factory, |
583 | + receiver_factory=receiver_factory) |
584 | |
585 | def list_buckets(self): |
586 | """ |
587 | @@ -87,7 +89,8 @@ |
588 | the request. |
589 | """ |
590 | query = self.query_factory( |
591 | - action="GET", creds=self.creds, endpoint=self.endpoint) |
592 | + action="GET", creds=self.creds, endpoint=self.endpoint, |
593 | + receiver_factory=self.receiver_factory) |
594 | d = query.submit() |
595 | return d.addCallback(self._parse_list_buckets) |
596 | |
597 | @@ -131,7 +134,7 @@ |
598 | """ |
599 | query = self.query_factory( |
600 | action="GET", creds=self.creds, endpoint=self.endpoint, |
601 | - bucket=bucket) |
602 | + bucket=bucket, receiver_factory=self.receiver_factory) |
603 | d = query.submit() |
604 | return d.addCallback(self._parse_get_bucket) |
605 | |
606 | @@ -174,7 +177,8 @@ |
607 | """ |
608 | query = self.query_factory(action="GET", creds=self.creds, |
609 | endpoint=self.endpoint, bucket=bucket, |
610 | - object_name="?location") |
611 | + object_name="?location", |
612 | + receiver_factory=self.receiver_factory) |
613 | d = query.submit() |
614 | return d.addCallback(self._parse_bucket_location) |
615 | |
616 | @@ -193,7 +197,8 @@ |
617 | """ |
618 | query = self.query_factory( |
619 | action='GET', creds=self.creds, endpoint=self.endpoint, |
620 | - bucket=bucket, object_name='?lifecycle') |
621 | + bucket=bucket, object_name='?lifecycle', |
622 | + receiver_factory=self.receiver_factory) |
623 | return query.submit().addCallback(self._parse_lifecycle_config) |
624 | |
625 | def _parse_lifecycle_config(self, xml_bytes): |
626 | @@ -221,7 +226,8 @@ |
627 | """ |
628 | query = self.query_factory( |
629 | action='GET', creds=self.creds, endpoint=self.endpoint, |
630 | - bucket=bucket, object_name='?website') |
631 | + bucket=bucket, object_name='?website', |
632 | + receiver_factory=self.receiver_factory) |
633 | return query.submit().addCallback(self._parse_website_config) |
634 | |
635 | def _parse_website_config(self, xml_bytes): |
636 | @@ -242,7 +248,8 @@ |
637 | """ |
638 | query = self.query_factory( |
639 | action='GET', creds=self.creds, endpoint=self.endpoint, |
640 | - bucket=bucket, object_name='?notification') |
641 | + bucket=bucket, object_name='?notification', |
642 | + receiver_factory=self.receiver_factory) |
643 | return query.submit().addCallback(self._parse_notification_config) |
644 | |
645 | def _parse_notification_config(self, xml_bytes): |
646 | @@ -262,7 +269,8 @@ |
647 | """ |
648 | query = self.query_factory( |
649 | action='GET', creds=self.creds, endpoint=self.endpoint, |
650 | - bucket=bucket, object_name='?versioning') |
651 | + bucket=bucket, object_name='?versioning', |
652 | + receiver_factory=self.receiver_factory) |
653 | return query.submit().addCallback(self._parse_versioning_config) |
654 | |
655 | def _parse_versioning_config(self, xml_bytes): |
656 | @@ -279,7 +287,8 @@ |
657 | """ |
658 | query = self.query_factory( |
659 | action='GET', creds=self.creds, endpoint=self.endpoint, |
660 | - bucket=bucket, object_name='?acl') |
661 | + bucket=bucket, object_name='?acl', |
662 | + receiver_factory=self.receiver_factory) |
663 | return query.submit().addCallback(self._parse_acl) |
664 | |
665 | def put_bucket_acl(self, bucket, access_control_policy): |
666 | @@ -289,7 +298,8 @@ |
667 | data = access_control_policy.to_xml() |
668 | query = self.query_factory( |
669 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
670 | - bucket=bucket, object_name='?acl', data=data) |
671 | + bucket=bucket, object_name='?acl', data=data, |
672 | + receiver_factory=self.receiver_factory) |
673 | return query.submit().addCallback(self._parse_acl) |
674 | |
675 | def _parse_acl(self, xml_bytes): |
676 | @@ -299,8 +309,8 @@ |
677 | """ |
678 | return AccessControlPolicy.from_xml(xml_bytes) |
679 | |
680 | - def put_object(self, bucket, object_name, data, content_type=None, |
681 | - metadata={}, amz_headers={}): |
682 | + def put_object(self, bucket, object_name, data=None, content_type=None, |
683 | + metadata={}, amz_headers={}, body_producer=None): |
684 | """ |
685 | Put an object in a bucket. |
686 | |
687 | @@ -318,7 +328,8 @@ |
688 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
689 | bucket=bucket, object_name=object_name, data=data, |
690 | content_type=content_type, metadata=metadata, |
691 | - amz_headers=amz_headers) |
692 | + amz_headers=amz_headers, body_producer=body_producer, |
693 | + receiver_factory=self.receiver_factory) |
694 | return query.submit() |
695 | |
696 | def copy_object(self, source_bucket, source_object_name, dest_bucket=None, |
697 | @@ -344,7 +355,8 @@ |
698 | query = self.query_factory( |
699 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
700 | bucket=dest_bucket, object_name=dest_object_name, |
701 | - metadata=metadata, amz_headers=amz_headers) |
702 | + metadata=metadata, amz_headers=amz_headers, |
703 | + receiver_factory=self.receiver_factory) |
704 | return query.submit() |
705 | |
706 | def get_object(self, bucket, object_name): |
707 | @@ -353,7 +365,8 @@ |
708 | """ |
709 | query = self.query_factory( |
710 | action="GET", creds=self.creds, endpoint=self.endpoint, |
711 | - bucket=bucket, object_name=object_name) |
712 | + bucket=bucket, object_name=object_name, |
713 | + receiver_factory=self.receiver_factory) |
714 | return query.submit() |
715 | |
716 | def head_object(self, bucket, object_name): |
717 | @@ -384,7 +397,8 @@ |
718 | data = access_control_policy.to_xml() |
719 | query = self.query_factory( |
720 | action='PUT', creds=self.creds, endpoint=self.endpoint, |
721 | - bucket=bucket, object_name='%s?acl' % object_name, data=data) |
722 | + bucket=bucket, object_name='%s?acl' % object_name, data=data, |
723 | + receiver_factory=self.receiver_factory) |
724 | return query.submit().addCallback(self._parse_acl) |
725 | |
726 | def get_object_acl(self, bucket, object_name): |
727 | @@ -393,7 +407,8 @@ |
728 | """ |
729 | query = self.query_factory( |
730 | action='GET', creds=self.creds, endpoint=self.endpoint, |
731 | - bucket=bucket, object_name='%s?acl' % object_name) |
732 | + bucket=bucket, object_name='%s?acl' % object_name, |
733 | + receiver_factory=self.receiver_factory) |
734 | return query.submit().addCallback(self._parse_acl) |
735 | |
736 | def put_request_payment(self, bucket, payer): |
737 | @@ -407,7 +422,8 @@ |
738 | data = RequestPayment(payer).to_xml() |
739 | query = self.query_factory( |
740 | action="PUT", creds=self.creds, endpoint=self.endpoint, |
741 | - bucket=bucket, object_name="?requestPayment", data=data) |
742 | + bucket=bucket, object_name="?requestPayment", data=data, |
743 | + receiver_factory=self.receiver_factory) |
744 | return query.submit() |
745 | |
746 | def get_request_payment(self, bucket): |
747 | @@ -419,7 +435,8 @@ |
748 | """ |
749 | query = self.query_factory( |
750 | action="GET", creds=self.creds, endpoint=self.endpoint, |
751 | - bucket=bucket, object_name="?requestPayment") |
752 | + bucket=bucket, object_name="?requestPayment", |
753 | + receiver_factory=self.receiver_factory) |
754 | return query.submit().addCallback(self._parse_get_request_payment) |
755 | |
756 | def _parse_get_request_payment(self, xml_bytes): |
757 | @@ -429,17 +446,39 @@ |
758 | """ |
759 | return RequestPayment.from_xml(xml_bytes).payer |
760 | |
761 | + def init_multipart_upload(self, bucket, object_name, content_type=None, |
762 | + amz_headers={}, metadata={}): |
763 | + """ |
764 | + Initiate a multipart upload to a bucket. |
765 | + |
766 | + @param bucket: The name of the bucket |
767 | + @param object_name: The object name |
768 | + @param content_type: The Content-Type for the object |
769 | + @param metadata: C{dict} containing additional metadata |
770 | + @param amz_headers: A C{dict} used to build C{x-amz-*} headers. |
771 | + @return: A C{Deferred} that fires with a L{MultipartInitiationResponse} |
772 | + """ |
773 | + objectname_plus = '%s?uploads' % object_name |
774 | + query = self.query_factory( |
775 | + action="POST", creds=self.creds, endpoint=self.endpoint, |
776 | + bucket=bucket, object_name=objectname_plus, data='', |
777 | + content_type=content_type, amz_headers=amz_headers, |
778 | + metadata=metadata) |
779 | + d = query.submit() |
780 | + return d.addCallback(MultipartInitiationResponse.from_xml) |
781 | + |
782 | |
783 | class Query(BaseQuery): |
784 | """A query for submission to the S3 service.""" |
785 | |
786 | def __init__(self, bucket=None, object_name=None, data="", |
787 | - content_type=None, metadata={}, amz_headers={}, *args, |
788 | - **kwargs): |
789 | + content_type=None, metadata={}, amz_headers={}, |
790 | + body_producer=None, *args, **kwargs): |
791 | super(Query, self).__init__(*args, **kwargs) |
792 | self.bucket = bucket |
793 | self.object_name = object_name |
794 | self.data = data |
795 | + self.body_producer = body_producer |
796 | self.content_type = content_type |
797 | self.metadata = metadata |
798 | self.amz_headers = amz_headers |
799 | @@ -463,9 +502,14 @@ |
800 | """ |
801 | Build the list of headers needed in order to perform S3 operations. |
802 | """ |
803 | - headers = {"Content-Length": len(self.data), |
804 | - "Content-MD5": calculate_md5(self.data), |
805 | + if self.body_producer: |
806 | + content_length = self.body_producer.length |
807 | + else: |
808 | + content_length = len(self.data) |
809 | + headers = {"Content-Length": content_length, |
810 | "Date": self.date} |
811 | + if self.body_producer is None: |
812 | + headers["Content-MD5"] = calculate_md5(self.data) |
813 | for key, value in self.metadata.iteritems(): |
814 | headers["x-amz-meta-" + key] = value |
815 | for key, value in self.amz_headers.iteritems(): |
816 | @@ -529,5 +573,6 @@ |
817 | self.endpoint, self.bucket, self.object_name) |
818 | d = self.get_page( |
819 | url_context.get_url(), method=self.action, postdata=self.data, |
820 | - headers=self.get_headers()) |
821 | + headers=self.get_headers(), body_producer=self.body_producer, |
822 | + receiver_factory=self.receiver_factory) |
823 | return d.addErrback(s3_error_wrapper) |
824 | |
825 | === modified file 'txaws/s3/model.py' |
826 | --- txaws/s3/model.py 2012-01-28 00:42:38 +0000 |
827 | +++ txaws/s3/model.py 2012-02-15 22:12:17 +0000 |
828 | @@ -150,3 +150,32 @@ |
829 | """ |
830 | root = XML(xml_bytes) |
831 | return cls(root.findtext("Payer")) |
832 | + |
833 | + |
834 | +class MultipartInitiationResponse(object): |
835 | + """ |
836 | + A response to Initiate Multipart Upload |
837 | + """ |
838 | + |
839 | + def __init__(self, bucket, object_name, upload_id): |
840 | + """ |
841 | + @param bucket: The bucket name |
842 | + @param object_name: The object name |
843 | + @param upload_id: The upload id |
844 | + """ |
845 | + self.bucket = bucket |
846 | + self.object_name = object_name |
847 | + self.upload_id = upload_id |
848 | + |
849 | + @classmethod |
850 | + def from_xml(cls, xml_bytes): |
851 | + """ |
852 | + Create an instance of this from XML bytes. |
853 | + |
854 | + @param xml_bytes: C{str} bytes of XML to parse |
855 | + @return: an instance of L{MultipartInitiationResponse} |
856 | + """ |
857 | + root = XML(xml_bytes) |
858 | + return cls(root.findtext('Bucket'), |
859 | + root.findtext('Key'), |
860 | + root.findtext('UploadId')) |
861 | |
862 | === modified file 'txaws/s3/tests/test_client.py' |
863 | --- txaws/s3/tests/test_client.py 2012-01-28 00:44:53 +0000 |
864 | +++ txaws/s3/tests/test_client.py 2012-02-15 22:12:17 +0000 |
865 | @@ -9,7 +9,8 @@ |
866 | else: |
867 | s3clientSkip = None |
868 | from txaws.s3.acls import AccessControlPolicy |
869 | -from txaws.s3.model import RequestPayment |
870 | +from txaws.s3.model import RequestPayment, MultipartInitiationResponse |
871 | +from txaws.testing.producers import StringBodyProducer |
872 | from txaws.service import AWSServiceEndpoint |
873 | from txaws.testing import payload |
874 | from txaws.testing.base import TXAWSTestCase |
875 | @@ -100,7 +101,8 @@ |
876 | |
877 | class StubQuery(client.Query): |
878 | |
879 | - def __init__(query, action, creds, endpoint): |
880 | + def __init__(query, action, creds, endpoint, |
881 | + body_producer=None, receiver_factory=None): |
882 | super(StubQuery, query).__init__( |
883 | action=action, creds=creds) |
884 | self.assertEquals(action, "GET") |
885 | @@ -134,7 +136,8 @@ |
886 | |
887 | class StubQuery(client.Query): |
888 | |
889 | - def __init__(query, action, creds, endpoint, bucket=None): |
890 | + def __init__(query, action, creds, endpoint, bucket=None, |
891 | + body_producer=None, receiver_factory=None): |
892 | super(StubQuery, query).__init__( |
893 | action=action, creds=creds, bucket=bucket) |
894 | self.assertEquals(action, "PUT") |
895 | @@ -156,7 +159,8 @@ |
896 | |
897 | class StubQuery(client.Query): |
898 | |
899 | - def __init__(query, action, creds, endpoint, bucket=None): |
900 | + def __init__(query, action, creds, endpoint, bucket=None, |
901 | + body_producer=None, receiver_factory=None): |
902 | super(StubQuery, query).__init__( |
903 | action=action, creds=creds, bucket=bucket) |
904 | self.assertEquals(action, "GET") |
905 | @@ -208,7 +212,8 @@ |
906 | class StubQuery(client.Query): |
907 | |
908 | def __init__(query, action, creds, endpoint, bucket=None, |
909 | - object_name=None): |
910 | + object_name=None, body_producer=None, |
911 | + receiver_factory=None): |
912 | super(StubQuery, query).__init__(action=action, creds=creds, |
913 | bucket=bucket, |
914 | object_name=object_name) |
915 | @@ -243,7 +248,8 @@ |
916 | class StubQuery(client.Query): |
917 | |
918 | def __init__(query, action, creds, endpoint, bucket=None, |
919 | - object_name=None): |
920 | + object_name=None, body_producer=None, |
921 | + receiver_factory=None): |
922 | super(StubQuery, query).__init__(action=action, creds=creds, |
923 | bucket=bucket, |
924 | object_name=object_name) |
925 | @@ -284,7 +290,8 @@ |
926 | class StubQuery(client.Query): |
927 | |
928 | def __init__(query, action, creds, endpoint, bucket=None, |
929 | - object_name=None): |
930 | + object_name=None, body_producer=None, |
931 | + receiver_factory=None): |
932 | super(StubQuery, query).__init__(action=action, creds=creds, |
933 | bucket=bucket, |
934 | object_name=object_name) |
935 | @@ -323,7 +330,8 @@ |
936 | class StubQuery(client.Query): |
937 | |
938 | def __init__(query, action, creds, endpoint, bucket=None, |
939 | - object_name=None): |
940 | + object_name=None, body_producer=None, |
941 | + receiver_factory=None): |
942 | super(StubQuery, query).__init__(action=action, creds=creds, |
943 | bucket=bucket, |
944 | object_name=object_name) |
945 | @@ -360,7 +368,8 @@ |
946 | class StubQuery(client.Query): |
947 | |
948 | def __init__(query, action, creds, endpoint, bucket=None, |
949 | - object_name=None): |
950 | + object_name=None, body_producer=None, |
951 | + receiver_factory=None): |
952 | super(StubQuery, query).__init__(action=action, creds=creds, |
953 | bucket=bucket, |
954 | object_name=object_name) |
955 | @@ -396,7 +405,8 @@ |
956 | class StubQuery(client.Query): |
957 | |
958 | def __init__(query, action, creds, endpoint, bucket=None, |
959 | - object_name=None): |
960 | + object_name=None, body_producer=None, |
961 | + receiver_factory=None): |
962 | super(StubQuery, query).__init__(action=action, creds=creds, |
963 | bucket=bucket, |
964 | object_name=object_name) |
965 | @@ -433,7 +443,8 @@ |
966 | class StubQuery(client.Query): |
967 | |
968 | def __init__(query, action, creds, endpoint, bucket=None, |
969 | - object_name=None): |
970 | + object_name=None, body_producer=None, |
971 | + receiver_factory=None): |
972 | super(StubQuery, query).__init__(action=action, creds=creds, |
973 | bucket=bucket, |
974 | object_name=object_name) |
975 | @@ -473,7 +484,8 @@ |
976 | class StubQuery(client.Query): |
977 | |
978 | def __init__(query, action, creds, endpoint, bucket=None, |
979 | - object_name=None): |
980 | + object_name=None, body_producer=None, |
981 | + receiver_factory=None): |
982 | super(StubQuery, query).__init__(action=action, creds=creds, |
983 | bucket=bucket, |
984 | object_name=object_name) |
985 | @@ -509,7 +521,8 @@ |
986 | class StubQuery(client.Query): |
987 | |
988 | def __init__(query, action, creds, endpoint, bucket=None, |
989 | - object_name=None): |
990 | + object_name=None, body_producer=None, |
991 | + receiver_factory=None): |
992 | super(StubQuery, query).__init__(action=action, creds=creds, |
993 | bucket=bucket, |
994 | object_name=object_name) |
995 | @@ -546,7 +559,8 @@ |
996 | class StubQuery(client.Query): |
997 | |
998 | def __init__(query, action, creds, endpoint, bucket=None, |
999 | - object_name=None): |
1000 | + object_name=None, body_producer=None, |
1001 | + receiver_factory=None): |
1002 | super(StubQuery, query).__init__(action=action, creds=creds, |
1003 | bucket=bucket, |
1004 | object_name=object_name) |
1005 | @@ -576,7 +590,8 @@ |
1006 | |
1007 | class StubQuery(client.Query): |
1008 | |
1009 | - def __init__(query, action, creds, endpoint, bucket=None): |
1010 | + def __init__(query, action, creds, endpoint, bucket=None, |
1011 | + body_producer=None, receiver_factory=None): |
1012 | super(StubQuery, query).__init__( |
1013 | action=action, creds=creds, bucket=bucket) |
1014 | self.assertEquals(action, "DELETE") |
1015 | @@ -599,7 +614,8 @@ |
1016 | class StubQuery(client.Query): |
1017 | |
1018 | def __init__(query, action, creds, endpoint, bucket=None, |
1019 | - object_name=None, data=""): |
1020 | + object_name=None, data="", body_producer=None, |
1021 | + receiver_factory=None): |
1022 | super(StubQuery, query).__init__(action=action, creds=creds, |
1023 | bucket=bucket, |
1024 | object_name=object_name, |
1025 | @@ -630,7 +646,8 @@ |
1026 | class StubQuery(client.Query): |
1027 | |
1028 | def __init__(query, action, creds, endpoint, bucket=None, |
1029 | - object_name=None, data=""): |
1030 | + object_name=None, data="", receiver_factory=None, |
1031 | + body_producer=None): |
1032 | super(StubQuery, query).__init__(action=action, creds=creds, |
1033 | bucket=bucket, |
1034 | object_name=object_name, |
1035 | @@ -665,7 +682,7 @@ |
1036 | |
1037 | def __init__(query, action, creds, endpoint, bucket=None, |
1038 | object_name=None, data=None, content_type=None, |
1039 | - metadata=None): |
1040 | + metadata=None, body_producer=None, receiver_factory=None): |
1041 | super(StubQuery, query).__init__( |
1042 | action=action, creds=creds, bucket=bucket, |
1043 | object_name=object_name, data=data, |
1044 | @@ -701,7 +718,7 @@ |
1045 | |
1046 | def __init__(query, action, creds, endpoint, bucket=None, |
1047 | object_name=None, data=None, content_type=None, |
1048 | - metadata=None): |
1049 | + metadata=None, body_producer=None, receiver_factory=None): |
1050 | super(StubQuery, query).__init__( |
1051 | action=action, creds=creds, bucket=bucket, |
1052 | object_name=object_name, data=data, |
1053 | @@ -730,7 +747,8 @@ |
1054 | |
1055 | def __init__(query, action, creds, endpoint, bucket=None, |
1056 | object_name=None, data=None, content_type=None, |
1057 | - metadata=None, amz_headers=None): |
1058 | + metadata=None, amz_headers=None, body_producer=None, |
1059 | + receiver_factory=None): |
1060 | super(StubQuery, query).__init__( |
1061 | action=action, creds=creds, bucket=bucket, |
1062 | object_name=object_name, data=data, |
1063 | @@ -756,6 +774,42 @@ |
1064 | metadata={"key": "some meta data"}, |
1065 | amz_headers={"acl": "public-read"}) |
1066 | |
1067 | + def test_put_object_with_custom_body_producer(self): |
1068 | + |
1069 | + class StubQuery(client.Query): |
1070 | + |
1071 | + def __init__(query, action, creds, endpoint, bucket=None, |
1072 | + object_name=None, data=None, content_type=None, |
1073 | + metadata=None, amz_headers=None, body_producer=None, |
1074 | + receiver_factory=None): |
1075 | + super(StubQuery, query).__init__( |
1076 | + action=action, creds=creds, bucket=bucket, |
1077 | + object_name=object_name, data=data, |
1078 | + content_type=content_type, metadata=metadata, |
1079 | + amz_headers=amz_headers, body_producer=body_producer) |
1080 | + self.assertEqual(action, "PUT") |
1081 | + self.assertEqual(creds.access_key, "foo") |
1082 | + self.assertEqual(creds.secret_key, "bar") |
1083 | + self.assertEqual(query.bucket, "mybucket") |
1084 | + self.assertEqual(query.object_name, "objectname") |
1085 | + self.assertEqual(query.content_type, "text/plain") |
1086 | + self.assertEqual(query.metadata, {"key": "some meta data"}) |
1087 | + self.assertEqual(query.amz_headers, {"acl": "public-read"}) |
1088 | + self.assertIdentical(body_producer, string_producer) |
1089 | + |
1090 | + def submit(query): |
1091 | + return succeed(None) |
1092 | + |
1093 | + |
1094 | + string_producer = StringBodyProducer("some data") |
1095 | + creds = AWSCredentials("foo", "bar") |
1096 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1097 | + return s3.put_object("mybucket", "objectname", |
1098 | + content_type="text/plain", |
1099 | + metadata={"key": "some meta data"}, |
1100 | + amz_headers={"acl": "public-read"}, |
1101 | + body_producer=string_producer) |
1102 | + |
1103 | def test_copy_object(self): |
1104 | """ |
1105 | L{S3Client.copy_object} creates a L{Query} to copy an object from one |
1106 | @@ -766,7 +820,8 @@ |
1107 | |
1108 | def __init__(query, action, creds, endpoint, bucket=None, |
1109 | object_name=None, data=None, content_type=None, |
1110 | - metadata=None, amz_headers=None): |
1111 | + metadata=None, amz_headers=None, body_producer=None, |
1112 | + receiver_factory=None): |
1113 | super(StubQuery, query).__init__( |
1114 | action=action, creds=creds, bucket=bucket, |
1115 | object_name=object_name, data=data, |
1116 | @@ -798,7 +853,8 @@ |
1117 | |
1118 | def __init__(query, action, creds, endpoint, bucket=None, |
1119 | object_name=None, data=None, content_type=None, |
1120 | - metadata=None, amz_headers=None): |
1121 | + metadata=None, amz_headers=None, body_producer=None, |
1122 | + receiver_factory=None): |
1123 | super(StubQuery, query).__init__( |
1124 | action=action, creds=creds, bucket=bucket, |
1125 | object_name=object_name, data=data, |
1126 | @@ -822,7 +878,7 @@ |
1127 | |
1128 | def __init__(query, action, creds, endpoint, bucket=None, |
1129 | object_name=None, data=None, content_type=None, |
1130 | - metadata=None): |
1131 | + metadata=None, body_producer=None, receiver_factory=None): |
1132 | super(StubQuery, query).__init__( |
1133 | action=action, creds=creds, bucket=bucket, |
1134 | object_name=object_name, data=data, |
1135 | @@ -846,7 +902,7 @@ |
1136 | |
1137 | def __init__(query, action, creds, endpoint, bucket=None, |
1138 | object_name=None, data=None, content_type=None, |
1139 | - metadata=None): |
1140 | + metadata=None, body_producer=None, receiver_factory=None): |
1141 | super(StubQuery, query).__init__( |
1142 | action=action, creds=creds, bucket=bucket, |
1143 | object_name=object_name, data=data, |
1144 | @@ -869,7 +925,8 @@ |
1145 | class StubQuery(client.Query): |
1146 | |
1147 | def __init__(query, action, creds, endpoint, bucket=None, |
1148 | - object_name=None, data=""): |
1149 | + object_name=None, data="", body_producer=None, |
1150 | + receiver_factory=None): |
1151 | super(StubQuery, query).__init__(action=action, creds=creds, |
1152 | bucket=bucket, |
1153 | object_name=object_name, |
1154 | @@ -902,7 +959,8 @@ |
1155 | class StubQuery(client.Query): |
1156 | |
1157 | def __init__(query, action, creds, endpoint, bucket=None, |
1158 | - object_name=None, data=""): |
1159 | + object_name=None, data="", body_producer=None, |
1160 | + receiver_factory=None): |
1161 | super(StubQuery, query).__init__(action=action, creds=creds, |
1162 | bucket=bucket, |
1163 | object_name=object_name, |
1164 | @@ -926,6 +984,44 @@ |
1165 | deferred = s3.get_object_acl("mybucket", "myobject") |
1166 | return deferred.addCallback(check_result) |
1167 | |
1168 | + def test_init_multipart_upload(self): |
1169 | + |
1170 | + class StubQuery(client.Query): |
1171 | + |
1172 | + def __init__(query, action, creds, endpoint, bucket=None, |
1173 | + object_name=None, data="", body_producer=None, |
1174 | + content_type=None, receiver_factory=None, metadata={}, |
1175 | + amz_headers={}): |
1176 | + super(StubQuery, query).__init__(action=action, creds=creds, |
1177 | + bucket=bucket, |
1178 | + amz_headers=amz_headers, |
1179 | + object_name=object_name, |
1180 | + data=data) |
1181 | + self.assertEquals(action, "POST") |
1182 | + self.assertEqual(creds.access_key, "foo") |
1183 | + self.assertEqual(creds.secret_key, "bar") |
1184 | + self.assertEqual(query.bucket, "example-bucket") |
1185 | + self.assertEqual(query.object_name, "example-object?uploads") |
1186 | + self.assertEqual(query.data, "") |
1187 | + self.assertEqual(query.metadata, {}) |
1188 | + self.assertEqual(query.amz_headers, {"acl": "public"}) |
1189 | + |
1190 | + def submit(query, url_context=None): |
1191 | + return succeed(payload.sample_s3_init_multipart_upload_result) |
1192 | + |
1193 | + |
1194 | + def check_result(result): |
1195 | + self.assert_(isinstance(result, MultipartInitiationResponse)) |
1196 | + self.assertEqual(result.bucket, "example-bucket") |
1197 | + self.assertEqual(result.object_name, "example-object") |
1198 | + self.assertEqual(result.upload_id, "deadbeef") |
1199 | + |
1200 | + creds = AWSCredentials("foo", "bar") |
1201 | + s3 = client.S3Client(creds, query_factory=StubQuery) |
1202 | + deferred = s3.init_multipart_upload("example-bucket", "example-object", |
1203 | + amz_headers={"acl": "public"}) |
1204 | + return deferred.addCallback(check_result) |
1205 | + |
1206 | S3ClientTestCase.skip = s3clientSkip |
1207 | |
1208 | |
1209 | @@ -1077,7 +1173,8 @@ |
1210 | """ |
1211 | class StubQuery(client.Query): |
1212 | |
1213 | - def __init__(query, action, creds, endpoint, bucket): |
1214 | + def __init__(query, action, creds, endpoint, bucket, |
1215 | + body_producer=None, receiver_factory=None): |
1216 | super(StubQuery, query).__init__( |
1217 | action=action, creds=creds, bucket=bucket) |
1218 | self.assertEquals(action, "GET") |
1219 | |
1220 | === modified file 'txaws/testing/payload.py' |
1221 | --- txaws/testing/payload.py 2012-01-28 00:39:00 +0000 |
1222 | +++ txaws/testing/payload.py 2012-02-15 22:12:17 +0000 |
1223 | @@ -1085,3 +1085,12 @@ |
1224 | <Status>Enabled</Status> |
1225 | <MfaDelete>Disabled</MfaDelete> |
1226 | </VersioningConfiguration>""" |
1227 | + |
1228 | +sample_s3_init_multipart_upload_result = """\ |
1229 | +<InitiateMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"> |
1230 | + <Bucket>example-bucket</Bucket> |
1231 | + <Key>example-object</Key> |
1232 | + <UploadId>deadbeef</UploadId> |
1233 | +</InitiateMultipartUploadResult>""" |
1234 | + |
1235 | + |
1236 | |
1237 | === added file 'txaws/testing/producers.py' |
1238 | --- txaws/testing/producers.py 1970-01-01 00:00:00 +0000 |
1239 | +++ txaws/testing/producers.py 2012-02-15 22:12:17 +0000 |
1240 | @@ -0,0 +1,23 @@ |
1241 | +from zope.interface import implements |
1242 | + |
1243 | +from twisted.internet.defer import succeed |
1244 | +from twisted.web.iweb import IBodyProducer |
1245 | + |
1246 | +class StringBodyProducer(object): |
1247 | + implements(IBodyProducer) |
1248 | + |
1249 | + def __init__(self, data): |
1250 | + self.data = data |
1251 | + self.length = len(data) |
1252 | + self.written = None |
1253 | + |
1254 | + def startProducing(self, consumer): |
1255 | + consumer.write(self.data) |
1256 | + self.written = self.data |
1257 | + return succeed(None) |
1258 | + |
1259 | + def pauseProducing(self): |
1260 | + pass |
1261 | + |
1262 | + def stopProducing(self): |
1263 | + pass |