Merge lp:~therve/storm/twisted-integration into lp:storm
Status: | Work in progress |
---|---|
Proposed branch: | lp:~therve/storm/twisted-integration |
Merge into: | lp:storm |
Diff against target: | 2287 lines (+2201/-1), 9 files modified: storm/database.py (+18/-1), storm/exceptions.py (+4/-0), storm/twisted/__init__.py (+23/-0), storm/twisted/store.py (+480/-0), storm/twisted/wrapper.py (+206/-0), tests/twisted/base.py (+1286/-0), tests/twisted/mysql.py (+62/-0), tests/twisted/postgres.py (+57/-0), tests/twisted/sqlite.py (+65/-0) |
To merge this branch: | bzr merge lp:~therve/storm/twisted-integration |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Christopher Armstrong (community) | Abstain | ||
Review via email: mp+3733@code.launchpad.net |
Commit message
Description of the change
Gabriel (gabriel-rossetti) wrote:
Christopher Armstrong wrote:
> Christopher Armstrong has proposed merging lp:~therve/storm/twisted-integration into lp:storm.
>
> Requested reviews:
> Storm Developers (storm)
>
I second that :-)
Gabriel
Christopher Armstrong (radix) wrote:
[7] Expose reload to DeferredStore
[8] Maybe add deferredStore.
Drew Smathers (djfroofy) wrote:
> [4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_
That sounds like a great idea to me. Otherwise, having to refetch objects just to avoid lazy value errors is less than ideal.
- 215. By Thomas Herve
  Merge from trunk.
- 216. By Thomas Herve
  Move the tests to be able to run Postgres/MySQL tests as well
- 217. By Thomas Herve
  Add postgres tests
- 218. By Thomas Herve
  Add mysql, some tests failing.
- 219. By Thomas Herve
  Use InnoDB to be able to use transactions.
- 220. By Thomas Herve
  Merge from trunk.
- 221. By Thomas Herve
  Force respect of the max_stores value, by incrementing the number of stores
  once start_store is called, instead of doing it in the callback
- 222. By Thomas Herve
  Use TestHelper/MakePath to manage sqlite file.
- 223. By Thomas Herve
  Fix doc typos
- 224. By Thomas Herve
  Merge from trunk
- 225. By Thomas Herve
  Some tests cleanup
- 226. By Thomas Herve
  Merge from trunk
- 227. By Thomas Herve
  Cleanups
- 228. By Thomas Herve
  Check thread in Connection to catch problems.
- 229. By Thomas Herve
  Handle errors in rollback
- 230. By Thomas Herve
  Typo, cleanup
- 231. By Thomas Herve
  Make start/stop methods private
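Revision 221 addresses a race in `StorePool._start_store`: the `_stores_created` counter must be incremented before the asynchronous store start, otherwise several concurrent `get()` calls can all see room under `max_stores`. The following is a minimal sketch of that effect, assuming asyncio in place of Twisted so that it runs with the standard library alone. `Pool`, `_connect` and `start_many` are hypothetical names, and `asyncio.sleep(0)` merely stands in for the threaded `Store` creation.

```python
import asyncio


class Pool(object):
    """Hypothetical pool that should never create more than max_stores."""

    def __init__(self, max_stores, reserve_early=True):
        self.max_stores = max_stores
        self.reserve_early = reserve_early
        self.created = 0

    async def _connect(self):
        # Yield to the event loop, standing in for the threaded Store() call.
        await asyncio.sleep(0)
        return object()

    async def start_store(self):
        if self.created >= self.max_stores:
            return None
        if self.reserve_early:
            # Count the store before yielding, as revision 221 does.
            self.created += 1
            try:
                return await self._connect()
            except Exception:
                # Give the slot back on failure, like _eb_start_store.
                self.created -= 1
                raise
        store = await self._connect()
        # Counting only here, in the "callback", is too late: other
        # coroutines already passed the max_stores check meanwhile.
        self.created += 1
        return store


async def start_many(pool, n):
    """Start n stores concurrently and return how many were created."""
    results = await asyncio.gather(*(pool.start_store() for _ in range(n)))
    return sum(1 for r in results if r is not None)
```

With `reserve_early=True`, only two of five concurrent starts against `Pool(2)` succeed; when the counter is bumped after the await, all five get through and the limit is silently exceeded.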
Unmerged revisions
- 231. By Thomas Herve
  Make start/stop methods private
- 230. By Thomas Herve
  Typo, cleanup
- 229. By Thomas Herve
  Handle errors in rollback
- 228. By Thomas Herve
  Check thread in Connection to catch problems.
- 227. By Thomas Herve
  Cleanups
- 226. By Thomas Herve
  Merge from trunk
- 225. By Thomas Herve
  Some tests cleanup
- 224. By Thomas Herve
  Merge from trunk
- 223. By Thomas Herve
  Fix doc typos
- 222. By Thomas Herve
  Use TestHelper/MakePath to manage sqlite file.
Preview Diff
1 | === modified file 'storm/database.py' |
2 | --- storm/database.py 2009-07-30 06:19:27 +0000 |
3 | +++ storm/database.py 2009-12-27 11:31:14 +0000 |
4 | @@ -24,12 +24,13 @@ |
5 | This is the common code for database support; specific databases are |
6 | supported in modules in L{storm.databases}. |
7 | """ |
8 | +import threading |
9 | |
10 | from storm.expr import Expr, State, compile |
11 | from storm.tracer import trace |
12 | from storm.variables import Variable |
13 | from storm.exceptions import ( |
14 | - ClosedError, DatabaseError, DisconnectionError, Error) |
15 | + ClosedError, DatabaseError, DisconnectionError, Error, ThreadSafetyError) |
16 | from storm.uri import URI |
17 | import storm |
18 | |
19 | @@ -180,6 +181,20 @@ |
20 | self._database = database # Ensures deallocation order. |
21 | self._event = event |
22 | self._raw_connection = self._database.raw_connect() |
23 | + self._thread = threading.currentThread() |
24 | + |
25 | + def _check_thread(self): |
26 | + """ |
27 | + Check if the current thread is the same thread that created the |
28 | + connection. This is a safety check that prevents breaking thread |
29 | + boundaries which can create weird bugs. C{execute}, C{commit} and |
30 | + C{rollback} should call it, directly or indirectly. |
31 | + """ |
32 | + thread = threading.currentThread() |
33 | + if thread is not self._thread: |
34 | + raise ThreadSafetyError( |
35 | + "'%s' is not the connection thread '%s'" % |
36 | + (thread, self._thread)) |
37 | |
38 | def __del__(self): |
39 | """Close the connection.""" |
40 | @@ -240,6 +255,7 @@ |
41 | |
42 | def rollback(self): |
43 | """Rollback the connection.""" |
44 | + self._check_thread() |
45 | if self._state == STATE_CONNECTED: |
46 | try: |
47 | self._raw_connection.rollback() |
48 | @@ -314,6 +330,7 @@ |
49 | If the connection is marked as dead, or if we can't reconnect, |
50 | then raise DisconnectionError. |
51 | """ |
52 | + self._check_thread() |
53 | if self._state == STATE_CONNECTED: |
54 | return |
55 | elif self._state == STATE_DISCONNECTED: |
56 | |
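The `_check_thread` hunk in `storm/database.py` records the thread that created the connection and refuses calls from any other thread. A minimal standalone sketch of the same check follows, assuming a hypothetical `ThreadBoundConnection` class rather than Storm's real `Connection`, and using `threading.current_thread()` (the modern spelling of `currentThread()`).

```python
import threading


class ThreadSafetyError(Exception):
    """Raised when an object is used outside the thread that created it."""


class ThreadBoundConnection(object):
    """Hypothetical stand-in that mimics Connection._check_thread."""

    def __init__(self):
        # Remember which thread created this connection.
        self._thread = threading.current_thread()

    def _check_thread(self):
        thread = threading.current_thread()
        if thread is not self._thread:
            raise ThreadSafetyError(
                "%r is not the connection thread %r" % (thread, self._thread))

    def execute(self, statement):
        # Every public operation checks the thread first, as in the diff.
        self._check_thread()
        return "executed: " + statement


def call_from_other_thread(conn, statement):
    """Call conn.execute from a fresh thread; report result or error."""
    outcome = {}

    def target():
        try:
            outcome["result"] = conn.execute(statement)
        except ThreadSafetyError as exc:
            outcome["error"] = exc

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return outcome
```

Calling `execute` from the creating thread works, while the same call made through `call_from_other_thread` ends up in `outcome["error"]` instead of silently sharing the raw DB-API connection across threads.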
57 | === modified file 'storm/exceptions.py' |
58 | --- storm/exceptions.py 2009-05-15 08:43:21 +0000 |
59 | +++ storm/exceptions.py 2009-12-27 11:31:14 +0000 |
60 | @@ -82,6 +82,10 @@ |
61 | pass |
62 | |
63 | |
64 | +class ThreadSafetyError(StormError): |
65 | + """Exception raised when cross-thread operations are attempted.""" |
66 | + |
67 | + |
68 | class Error(StormError): |
69 | pass |
70 | |
71 | |
72 | === added directory 'storm/twisted' |
73 | === added file 'storm/twisted/__init__.py' |
74 | --- storm/twisted/__init__.py 1970-01-01 00:00:00 +0000 |
75 | +++ storm/twisted/__init__.py 2009-12-27 11:31:14 +0000 |
76 | @@ -0,0 +1,23 @@ |
77 | +# |
78 | +# Copyright (c) 2007 Canonical |
79 | +# Copyright (c) 2007 Thomas Herve <thomas@nimail.org> |
80 | +# |
81 | +# This file is part of Storm Object Relational Mapper. |
82 | +# |
83 | +# Storm is free software; you can redistribute it and/or modify |
84 | +# it under the terms of the GNU Lesser General Public License as |
85 | +# published by the Free Software Foundation; either version 2.1 of |
86 | +# the License, or (at your option) any later version. |
87 | +# |
88 | +# Storm is distributed in the hope that it will be useful, |
89 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
90 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
91 | +# GNU Lesser General Public License for more details. |
92 | +# |
93 | +# You should have received a copy of the GNU Lesser General Public License |
94 | +# along with this program. If not, see <http://www.gnu.org/licenses/>. |
95 | +# |
96 | + |
97 | +""" |
98 | +Asynchronous wrapper around storm to be used within a Twisted application. |
99 | +""" |
100 | |
101 | === added file 'storm/twisted/store.py' |
102 | --- storm/twisted/store.py 1970-01-01 00:00:00 +0000 |
103 | +++ storm/twisted/store.py 2009-12-27 11:31:14 +0000 |
104 | @@ -0,0 +1,480 @@ |
105 | +# |
106 | +# Copyright (c) 2007 Canonical |
107 | +# Copyright (c) 2007 Thomas Herve <thomas@nimail.org> |
108 | +# |
109 | +# This file is part of Storm Object Relational Mapper. |
110 | +# |
111 | +# Storm is free software; you can redistribute it and/or modify |
112 | +# it under the terms of the GNU Lesser General Public License as |
113 | +# published by the Free Software Foundation; either version 2.1 of |
114 | +# the License, or (at your option) any later version. |
115 | +# |
116 | +# Storm is distributed in the hope that it will be useful, |
117 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
118 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
119 | +# GNU Lesser General Public License for more details. |
120 | +# |
121 | +# You should have received a copy of the GNU Lesser General Public License |
122 | +# along with this program. If not, see <http://www.gnu.org/licenses/>. |
123 | +# |
124 | + |
125 | +""" |
126 | +Store wrapper and custom thread runner. |
127 | +""" |
128 | + |
129 | +from threading import Thread |
130 | +from Queue import Queue |
131 | + |
132 | +from storm.store import Store, AutoReload |
133 | +from storm.info import get_obj_info |
134 | +from storm.twisted.wrapper import partial, DeferredResult, DeferredResultSet |
135 | + |
136 | +from twisted.internet.defer import Deferred, deferredGenerator, maybeDeferred |
137 | +from twisted.internet.defer import waitForDeferred, succeed, fail |
138 | +from twisted.python.failure import Failure |
139 | + |
140 | + |
141 | + |
142 | +class AlreadyStopped(Exception): |
143 | + """ |
144 | + Exception raised when a store is stopped multiple times. |
145 | + """ |
146 | + |
147 | + |
148 | + |
149 | +class StoreThread(Thread): |
150 | + """ |
151 | + A thread class that wraps method calls and fires deferreds in the reactor |
152 | + thread. |
153 | + """ |
154 | + STOP = object() |
155 | + |
156 | + def __init__(self): |
157 | + """ |
158 | + Initialize the thread, and create a L{Queue} to stack jobs. |
159 | + """ |
160 | + Thread.__init__(self) |
161 | + self.setDaemon(True) |
162 | + self._queue = Queue() |
163 | + self._stop_deferred = None |
164 | + self.stopped = False |
165 | + |
166 | + |
167 | + def defer_to_thread(self, f, *args, **kwargs): |
168 | + """ |
169 | + Run the given function in the thread, wrapping the result with a |
170 | + L{Deferred}. |
171 | + |
172 | + @return: a deferred whose result will be the result of the function. |
173 | + @rtype: C{Deferred} |
174 | + """ |
175 | + if self.stopped: |
176 | + # to prevent having pending calls after the thread got stopped |
177 | + return fail(AlreadyStopped(f)) |
178 | + d = Deferred() |
179 | + self._queue.put((d, f, args, kwargs)) |
180 | + return d |
181 | + |
182 | + |
183 | + def run(self): |
184 | + """ |
185 | + Main execution loop: retrieve jobs from the queue and run them. |
186 | + """ |
187 | + from twisted.internet import reactor |
188 | + o = self._queue.get() |
189 | + while o is not self.STOP: |
190 | + d, f, args, kwargs = o |
191 | + try: |
192 | + result = f(*args, **kwargs) |
193 | + except: |
194 | + f = Failure() |
195 | + reactor.callFromThread(d.errback, f) |
196 | + else: |
197 | + reactor.callFromThread(d.callback, result) |
198 | + o = self._queue.get() |
199 | + reactor.callFromThread(self._stop_deferred.callback, None) |
200 | + |
201 | + |
202 | + def stop(self): |
203 | + """ |
204 | + Stop the thread. |
205 | + """ |
206 | + if self.stopped: |
207 | + return self._stop_deferred |
208 | + self._stop_deferred = Deferred() |
209 | + self._queue.put(self.STOP) |
210 | + self.stopped = True |
211 | + return self._stop_deferred |
212 | + |
213 | + |
214 | + |
215 | +class DeferredStore(object): |
216 | + """ |
217 | + A wrapper around L{Store} to have async operations. |
218 | + """ |
219 | + store = None |
220 | + |
221 | + def __init__(self, database): |
222 | + """ |
223 | + @param database: instance of database providing connection, used to |
224 | + instantiate the store later. |
225 | + @type database: L{storm.database.Database} |
226 | + """ |
227 | + self.thread = StoreThread() |
228 | + self.database = database |
229 | + self.started = False |
230 | + |
231 | + |
232 | + def start(self): |
233 | + """ |
234 | + Start the store. |
235 | + |
236 | + @return: a deferred that will fire once the store is started. |
237 | + """ |
238 | + if not self.started: |
239 | + self.started = True |
240 | + self.thread.start() |
241 | + # Add an event trigger to be sure that the thread is stopped |
242 | + from twisted.internet import reactor |
243 | + reactor.addSystemEventTrigger( |
244 | + "before", "shutdown", self.stop) |
245 | + return self.thread.defer_to_thread(Store, self.database |
246 | + ).addCallback(self._got_store) |
247 | + else: |
248 | + raise RuntimeError("Already started") |
249 | + |
250 | + |
251 | + def _got_store(self, store): |
252 | + """ |
253 | + Internal method called when the store is created, initializing most of |
254 | + the API methods. |
255 | + """ |
256 | + self.store = store |
257 | + # Maybe not ? |
258 | + self.store._deferredStore = self |
259 | + for methodName in ("commit", "flush", "remove", "reload", |
260 | + "rollback"): |
261 | + method = partial(self.thread.defer_to_thread, |
262 | + getattr(self.store, methodName)) |
263 | + setattr(self, methodName, method) |
264 | + |
265 | + self._do_resolve_lazy_value = self.store._resolve_lazy_value |
266 | + self.store._resolve_lazy_value = self._resolve_lazy_value |
267 | + |
268 | + |
269 | + def get(self, cls, key): |
270 | + def _get(): |
271 | + obj = self.store.get(cls, key) |
272 | + if obj is not None: |
273 | + obj_info = get_obj_info(obj) |
274 | + self._do_resolve_lazy_value(obj_info, None, AutoReload) |
275 | + return obj |
276 | + return self.thread.defer_to_thread(_get) |
277 | + |
278 | + |
279 | + def add(self, obj): |
280 | + """ |
281 | + Specific add method that doesn't return any result, so that callers |
282 | + don't mistake its result for something usable. |
283 | + """ |
284 | + def _add(): |
285 | + self.store.add(obj) |
286 | + return self.thread.defer_to_thread(_add) |
287 | + |
288 | + |
289 | + def execute(self, *args, **kwargs): |
290 | + """ |
291 | + Wrapper around C{execute} to have a C{DeferredResult} instead of the |
292 | + standard L{storm.database.Result} object. |
293 | + """ |
294 | + if self.store is None: |
295 | + raise RuntimeError("Store not started") |
296 | + return self.thread.defer_to_thread( |
297 | + self.store.execute, *args, **kwargs |
298 | + ).addCallback(self._cb_execute) |
299 | + |
300 | + |
301 | + def _cb_execute(self, result): |
302 | + """ |
303 | + Wrap the result with a C{DeferredResult}. |
304 | + """ |
305 | + if result is not None: |
306 | + return DeferredResult(self.thread, result) |
307 | + |
308 | + |
309 | + def find(self, *args, **kwargs): |
310 | + """ |
311 | + Wrapper around C{find}. |
312 | + """ |
313 | + if self.store is None: |
314 | + raise RuntimeError("Store not started") |
315 | + return self.thread.defer_to_thread( |
316 | + self.store.find, *args, **kwargs |
317 | + ).addCallback(self._cb_find) |
318 | + |
319 | + |
320 | + def _cb_find(self, resultSet): |
321 | + """ |
322 | + Wrap the result set with a C{DeferredResultSet}. |
323 | + """ |
324 | + return DeferredResultSet(self.thread, resultSet) |
325 | + |
326 | + |
327 | + def stop(self): |
328 | + """ |
329 | + Stop the store. |
330 | + """ |
331 | + if self.thread.stopped: |
332 | + return succeed(None) |
333 | + def close(): |
334 | + self.store.rollback() |
335 | + self.store.close() |
336 | + return self.thread.defer_to_thread(close |
337 | + ).addCallback(lambda ign: self.thread.stop()) |
338 | + |
339 | + |
340 | + def _resolve_lazy_value(self, *args): |
341 | + raise RuntimeError( |
342 | + "Resolving lazy values with the Twisted wrapper is not possible " |
343 | + "right now! Please refetch your object using " |
344 | + "store.get/store.find") |
345 | + |
346 | + |
347 | + @staticmethod |
348 | + def of(obj): |
349 | + """ |
350 | + Get the DeferredStore that the object is associated with. |
351 | + |
352 | + If the given object has not been associated with a DeferredStore, |
353 | + return None. |
354 | + """ |
355 | + store = Store.of(obj) |
356 | + if not store: |
357 | + return |
358 | + return getattr(store, '_deferredStore', None) |
359 | + |
360 | + |
361 | + |
362 | +class StorePool(object): |
363 | + """ |
364 | + A pool of started stores, maintaining persistent connections. |
365 | + """ |
366 | + started = False |
367 | + store_factory = DeferredStore |
368 | + |
369 | + def __init__(self, database, min_stores=0, max_stores=10): |
370 | + """ |
371 | + @param database: instance of database providing connection, used to |
372 | + instantiate the store later. |
373 | + @type database: L{storm.database.Database} |
374 | + |
375 | + @param min_stores: initial number of stores. |
376 | + @type min_stores: C{int} |
377 | + |
378 | + @param max_stores: maximum number of stores. |
379 | + @type max_stores: C{int} |
380 | + """ |
381 | + self.database = database |
382 | + self.min_stores = min_stores |
383 | + self.max_stores = max_stores |
384 | + self._stores = [] |
385 | + self._stores_created = 0 |
386 | + self._pending_get = [] |
387 | + self._store_refs = [] |
388 | + |
389 | + |
390 | + def start(self): |
391 | + """ |
392 | + Start the pool. |
393 | + """ |
394 | + if self.started: |
395 | + raise RuntimeError("Already started") |
396 | + self.started = True |
397 | + return self.adjust_size() |
398 | + |
399 | + |
400 | + def stop(self): |
401 | + """ |
402 | + Stop the pool: this is not a total stop, it just tries to kill the |
403 | + currently available stores. |
404 | + """ |
405 | + return self.adjust_size(0, 0, self._store_refs) |
406 | + |
407 | + |
408 | + def _start_store(self): |
409 | + """ |
410 | + Create a new store. |
411 | + """ |
412 | + store = self.store_factory(self.database) |
413 | + # Increment here, so that other simultaneous calls don't make the |
414 | + # number of connections pass the maximum |
415 | + self._stores_created += 1 |
416 | + return store.start( |
417 | + ).addCallback(self._cb_start_store, store |
418 | + ).addErrback(self._eb_start_store) |
419 | + |
420 | + |
421 | + def _cb_start_store(self, ign, store): |
422 | + """ |
423 | + Add the created store to the list of available stores. |
424 | + """ |
425 | + self._stores.append(store) |
426 | + self._store_refs.append(store) |
427 | + |
428 | + |
429 | + def _eb_start_store(self, failure): |
430 | + """ |
431 | + Reduce the amount of created stores, and let the failure propagate. |
432 | + """ |
433 | + self._stores_created -= 1 |
434 | + return failure |
435 | + |
436 | + |
437 | + def _stop_store(self, stores=None): |
438 | + """ |
439 | + Stop a store and remove it from the available stores. |
440 | + """ |
441 | + if stores is None: |
442 | + stores = self._stores |
443 | + self._stores_created -= 1 |
444 | + store = stores.pop() |
445 | + return store.stop() |
446 | + |
447 | + |
448 | + @deferredGenerator |
449 | + def adjust_size(self, min_stores=None, max_stores=None, stores=None): |
450 | + """ |
451 | + Change the number of available stores, shrinking or raising as |
452 | + necessary. |
453 | + """ |
454 | + if min_stores is None: |
455 | + min_stores = self.min_stores |
456 | + if max_stores is None: |
457 | + max_stores = self.max_stores |
458 | + if stores is None: |
459 | + stores = self._stores |
460 | + |
461 | + if min_stores < 0: |
462 | + raise ValueError('minimum is negative') |
463 | + if min_stores > max_stores: |
464 | + raise ValueError('minimum is greater than maximum') |
465 | + |
466 | + self.min_stores = min_stores |
467 | + self.max_stores = max_stores |
468 | + if not self.started: |
469 | + return |
470 | + |
471 | + # Kill off some stores if we have too many. |
472 | + while self._stores_created > self.max_stores and stores: |
473 | + wfd = waitForDeferred(self._stop_store(stores)) |
474 | + yield wfd |
475 | + wfd.getResult() |
476 | + # Start some stores if we have too few. |
477 | + while self._stores_created < self.min_stores: |
478 | + wfd = waitForDeferred(self._start_store()) |
479 | + yield wfd |
480 | + wfd.getResult() |
481 | + |
482 | + |
483 | + def get(self): |
484 | + """ |
485 | + Return a started store from the pool, or start a new one if necessary. |
486 | + A store retrieved this way should be put back using the put |
487 | + method, or it won't be used anymore. |
488 | + """ |
489 | + if not self.started: |
490 | + raise RuntimeError("Not started") |
491 | + if self._stores: |
492 | + store = self._stores.pop() |
493 | + return succeed(store) |
494 | + elif self._stores_created < self.max_stores: |
495 | + return self._start_store().addCallback(self._cb_get) |
496 | + else: |
497 | + # Maybe all stores are consumed? |
498 | + return self.adjust_size().addCallback(self._cb_get) |
499 | + |
500 | + |
501 | + def _cb_get(self, ign): |
502 | + """ |
503 | + If the previous operation added a store, return it, or return a pending |
504 | + C{Deferred}. |
505 | + """ |
506 | + if self._stores: |
507 | + store = self._stores.pop() |
508 | + return store |
509 | + else: |
510 | + # All stores are in use, wait |
511 | + d = Deferred() |
512 | + self._pending_get.append(d) |
513 | + return d |
514 | + |
515 | + |
516 | + def put(self, store): |
517 | + """ |
518 | + Make a store available again. |
519 | + |
520 | + This should be done explicitly to have the store back in the pool. |
521 | + The recommended way to use the pool is: |
522 | + |
523 | + >>> d1 = pool.get() |
524 | + |
525 | + >>> # d1 callback with a store |
526 | + >>> d2 = store.add(foo) |
527 | + >>> d2.addCallback(doSomething).addErrback(manageErrors) |
528 | + >>> d2.addBoth(lambda x: pool.put(store)) |
529 | + """ |
530 | + return store.rollback().addBoth(self._cb_put, store) |
531 | + |
532 | + |
533 | + def _cb_put(self, passthrough, store): |
534 | + """ |
535 | + Once the rollback has finished, the store is really available. |
536 | + """ |
537 | + if self._pending_get: |
538 | + # People are waiting, fire with the store |
539 | + d = self._pending_get.pop(0) |
540 | + d.callback(store) |
541 | + else: |
542 | + self._stores.append(store) |
543 | + return passthrough |
544 | + |
545 | + |
546 | + def transact(self, f, *args, **kwargs): |
547 | + """ |
548 | + Call function C{f} with a L{Store} instance and arguments C{args} and |
549 | + C{kwargs} in a transaction bound to the acquired store. If the transaction |
550 | + succeeds, the store will be committed. The store is returned to this pool after |
551 | + the call to C{f} completes. |
552 | + |
553 | + Note that the function C{f} must return an instance of L{Deferred}. |
554 | + |
555 | + @param f: function to call in transaction |
556 | + @param args: positional arguments to function C{f} |
557 | + @param kwargs: keyword arguments to function C{f} |
558 | + """ |
559 | + return self.get( |
560 | + ).addCallback(self._cb_transact_start, f, args, kwargs) |
561 | + |
562 | + |
563 | + def _cb_transact_start(self, store, f, args, kwargs): |
564 | + """ |
565 | + Call transacted function with acquired store. |
566 | + """ |
567 | + result = maybeDeferred(f, store, *args, **kwargs) |
568 | + result.addCallback(self._cb_transact_success, store) |
569 | + result.addBoth(self._cb_transact_stop, store) |
570 | + return result |
571 | + |
572 | + |
573 | + def _cb_transact_success(self, result, store): |
574 | + """ |
575 | + Commit and pass through function result. |
576 | + """ |
577 | + return store.commit().addCallback(lambda ignore: result) |
578 | + |
579 | + |
580 | + def _cb_transact_stop(self, result, store): |
581 | + """ |
582 | + Return the store back to the pool and pass through the result again. |
583 | + """ |
584 | + return self.put(store).addCallback(lambda ignore: result) |
585 | |
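`StoreThread` pins every `Store` operation to one dedicated thread by pushing `(deferred, f, args, kwargs)` tuples onto a `Queue` and ending its loop on a `STOP` sentinel. The sketch below reproduces that shape with the standard library only: it swaps Twisted's `Deferred` and `reactor.callFromThread` for `concurrent.futures.Future`, whose `set_result`/`set_exception` may safely be called from the worker thread. `WorkerThread` is a hypothetical name, not part of the branch.

```python
import threading
from concurrent.futures import Future
from queue import Queue


class WorkerThread(threading.Thread):
    """Run queued calls one at a time on a single dedicated thread."""

    STOP = object()  # sentinel that ends the run loop

    def __init__(self):
        super().__init__(daemon=True)
        self._queue = Queue()
        self.stopped = False

    def defer_to_thread(self, f, *args, **kwargs):
        """Queue f(*args, **kwargs) and return a Future for its result."""
        future = Future()
        if self.stopped:
            # Mirror StoreThread: refuse new work after stop().
            future.set_exception(RuntimeError("worker already stopped"))
            return future
        self._queue.put((future, f, args, kwargs))
        return future

    def run(self):
        while True:
            job = self._queue.get()
            if job is self.STOP:
                break
            future, f, args, kwargs = job
            try:
                result = f(*args, **kwargs)
            except Exception as exc:
                future.set_exception(exc)
            else:
                future.set_result(result)

    def stop(self):
        """Ask the loop to exit once the jobs already queued have run."""
        if not self.stopped:
            self.stopped = True
            self._queue.put(self.STOP)
```

Because every job runs on the same worker thread, a connection created by the first job would pass a `_check_thread`-style test in all later jobs, which is the property the Twisted wrapper relies on.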
586 | === added file 'storm/twisted/wrapper.py' |
587 | --- storm/twisted/wrapper.py 1970-01-01 00:00:00 +0000 |
588 | +++ storm/twisted/wrapper.py 2009-12-27 11:31:14 +0000 |
589 | @@ -0,0 +1,206 @@ |
590 | +# |
591 | +# Copyright (c) 2007 Canonical |
592 | +# Copyright (c) 2007 Thomas Herve <thomas@nimail.org> |
593 | +# |
594 | +# This file is part of Storm Object Relational Mapper. |
595 | +# |
596 | +# Storm is free software; you can redistribute it and/or modify |
597 | +# it under the terms of the GNU Lesser General Public License as |
598 | +# published by the Free Software Foundation; either version 2.1 of |
599 | +# the License, or (at your option) any later version. |
600 | +# |
601 | +# Storm is distributed in the hope that it will be useful, |
602 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
603 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
604 | +# GNU Lesser General Public License for more details. |
605 | +# |
606 | +# You should have received a copy of the GNU Lesser General Public License |
607 | +# along with this program. If not, see <http://www.gnu.org/licenses/>. |
608 | +# |
609 | + |
610 | +""" |
611 | +Asynchronous wrapper around storm. |
612 | +""" |
613 | + |
614 | +from storm.store import Store |
615 | +from storm.references import Reference, ReferenceSet |
616 | + |
617 | + |
618 | +try: |
619 | + from functools import partial |
620 | +except ImportError: |
621 | + # For Python < 2.5 |
622 | + class partial(object): |
623 | + def __init__(self, fn, *args, **kw): |
624 | + self.fn = fn |
625 | + self.args = args |
626 | + self.kw = kw |
627 | + |
628 | + def __call__(self, *args, **kw): |
629 | + if kw and self.kw: |
630 | + d = self.kw.copy() |
631 | + d.update(kw) |
632 | + else: |
633 | + d = kw or self.kw |
634 | + return self.fn(*(self.args + args), **d) |
635 | + |
636 | + |
637 | + |
638 | +class DeferredResult(object): |
639 | + """ |
640 | + Proxy for a storm result, running the blocking methods in a thread and |
641 | + returning C{Deferred}s. |
642 | + """ |
643 | + |
644 | + def __init__(self, thread, result): |
645 | + """ |
646 | + @param thread: the running thread of the store |
647 | + @type thread: C{StoreThread} |
648 | + |
649 | + @param result: the result instance to be wrapped. |
650 | + @type result: C{storm.database.Result} |
651 | + """ |
652 | + self.result = result |
653 | + for methodName in ("get_one", "get_all"): |
654 | + method = partial(thread.defer_to_thread, |
655 | + getattr(result, methodName)) |
656 | + setattr(self, methodName, method) |
657 | + |
658 | + |
659 | + |
660 | +class DeferredResultSet(object): |
661 | + """ |
662 | + Wrapper for a L{storm.store.ResultSet}. |
663 | + """ |
664 | + |
665 | + def __init__(self, thread, resultSet): |
666 | + """ |
667 | + Create the results with given C{StoreThread} and the set to wrap. |
668 | + """ |
669 | + self._thread = thread |
670 | + self._resultSet = resultSet |
671 | + for methodName in ("any", "one", "first", "last", "remove", "count", |
672 | + "max", "min", "avg", "sum", "set", "is_empty"): |
673 | + method = partial(thread.defer_to_thread, |
674 | + getattr(resultSet, methodName)) |
675 | + setattr(self, methodName, method) |
676 | + for methodName in ("union", "difference", "intersection"): |
677 | + method = partial(self._set_expr, |
678 | + getattr(resultSet, methodName)) |
679 | + setattr(self, methodName, method) |
680 | + for methodName in ("order_by", "config", "group_by", "having"): |
681 | + setattr(self, methodName, getattr(self._resultSet, methodName)) |
682 | + |
683 | + |
684 | + def all(self): |
685 | + """ |
686 | + Specific method to emulate C{__iter__}. |
687 | + """ |
688 | + return self._thread.defer_to_thread(list, self._resultSet) |
689 | + |
690 | + |
691 | + def values(self, *columns): |
692 | + """ |
693 | + Wrapper around values that removes the iterator feature to return a list |
694 | + instead. |
695 | + """ |
696 | + def _get_values(): |
697 | + return list(self._resultSet.values(*columns)) |
698 | + return self._thread.defer_to_thread(_get_values) |
699 | + |
700 | + |
701 | + def _set_expr(self, method, other, all=False): |
702 | + """ |
703 | + Wrap a set expression with a C{DeferredResultSet}. |
704 | + """ |
705 | + return DeferredResultSet(self._thread, method(other, all)) |
706 | + |
707 | + |
708 | + |
709 | +class DeferredReference(Reference): |
710 | + """ |
711 | + A reference property but within a C{Deferred}. |
712 | + """ |
713 | + |
714 | + def __get__(self, local, cls=None): |
715 | + """ |
716 | + Wrapper around C{Reference.__get__}. |
717 | + """ |
718 | + store = Store.of(local) |
719 | + if store is None: |
720 | + return None |
721 | + _thread = store._deferredStore.thread |
722 | + return _thread.defer_to_thread(Reference.__get__, self, local, cls) |
723 | + |
724 | + |
725 | + def __set__(self, local, remote): |
726 | + """ |
727 | + Wrapper around C{Reference.__set__}. |
728 | + """ |
729 | + raise RuntimeError("Can't set a DeferredReference") |
730 | + |
731 | + |
732 | + |
733 | +class DeferredReferenceSet(ReferenceSet): |
734 | + """ |
735 | + A C{ReferenceSet} but within a C{Deferred}. |
736 | + """ |
737 | + |
738 | + def __get__(self, local, cls=None): |
739 | + """ |
740 | + Wrapper around C{ReferenceSet.__get__}. |
741 | + """ |
742 | + store = Store.of(local) |
743 | + if store is None: |
744 | + return None |
745 | + _thread = store._deferredStore.thread |
746 | + boundReference = ReferenceSet.__get__(self, local, cls) |
747 | + return DeferredBoundReference(_thread, boundReference) |
748 | + |
749 | + |
750 | + |
751 | +class DeferredBoundReference(object): |
752 | + """ |
753 | + Wrapper around C{BoundReferenceSet} and C{BoundIndirectReferenceSet}. |
754 | + """ |
755 | + |
756 | + def __init__(self, thread, boundReference): |
757 | + """ |
758 | + Create the reference with given C{StoreThread} and the reference to |
759 | + wrap. |
760 | + """ |
761 | + self._thread = thread |
762 | + self._boundReference = boundReference |
763 | + for methodName in ("clear", "add", "remove", "any", "count", "one", |
764 | + "first", "last"): |
765 | + method = partial(thread.defer_to_thread, |
766 | + getattr(boundReference, methodName)) |
767 | + setattr(self, methodName, method) |
768 | + for methodName in ("order_by", "find"): |
769 | + method = partial(self._defer_and_wrap_result, |
770 | + getattr(boundReference, methodName)) |
771 | + setattr(self, methodName, method) |
772 | + |
773 | + |
774 | + def all(self): |
775 | + """ |
776 | + Specific method to emulate C{__iter__}. |
777 | + """ |
778 | + return self._thread.defer_to_thread(list, self._boundReference) |
779 | + |
780 | + |
781 | + def values(self, *columns): |
782 | + """ |
783 | + Emulate the values method. |
784 | + """ |
785 | + def _get_values(): |
786 | + return list(self._boundReference.values(*columns)) |
787 | + return self._thread.defer_to_thread(_get_values) |
788 | + |
789 | + |
790 | + def _defer_and_wrap_result(self, method, *args, **kwargs): |
791 | + """ |
792 | + Helper for methods returning another C{ResultSet}. |
793 | + """ |
794 | + return self._thread.defer_to_thread(method, *args, **kwargs |
795 | + ).addCallback(lambda x: DeferredResultSet(self._thread, x)) |
796 | |
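`DeferredResult`, `DeferredResultSet` and `DeferredBoundReference` all build their API by looping over method names and binding each blocking method with `partial(thread.defer_to_thread, method)`. A small sketch of that pattern follows, assuming a single-worker `ThreadPoolExecutor` in place of `StoreThread` and a hypothetical `Rows` class in place of a Storm result set.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class Rows(object):
    """Hypothetical blocking result object standing in for a ResultSet."""

    def __init__(self, rows):
        self._rows = list(rows)

    def count(self):
        return len(self._rows)

    def first(self):
        return self._rows[0] if self._rows else None

    def __iter__(self):
        return iter(self._rows)


class DeferredProxy(object):
    """Expose selected blocking methods of `wrapped` as Future-returning calls."""

    def __init__(self, submit, wrapped, method_names):
        self._submit = submit
        self._wrapped = wrapped
        for name in method_names:
            # Bind each blocking method to the submit function, the same
            # way DeferredResultSet uses partial(thread.defer_to_thread, ...).
            setattr(self, name, partial(submit, getattr(wrapped, name)))

    def all(self):
        # Iteration blocks too, so the list is built in the worker thread.
        return self._submit(list, self._wrapped)
```

The explicit `all()` method exists because `__iter__` cannot hand back a deferred result; materializing the list in the worker keeps every database round trip off the calling thread.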
797 | === added directory 'tests/twisted' |
798 | === added file 'tests/twisted/__init__.py' |
799 | === added file 'tests/twisted/base.py' |
800 | --- tests/twisted/base.py 1970-01-01 00:00:00 +0000 |
801 | +++ tests/twisted/base.py 2009-12-27 11:31:15 +0000 |
802 | @@ -0,0 +1,1286 @@ |
803 | +# |
804 | +# Copyright (c) 2007 Canonical |
805 | +# Copyright (c) 2007 Thomas Herve <thomas@nimail.org> |
806 | +# |
807 | +# This file is part of Storm Object Relational Mapper. |
808 | +# |
809 | +# Storm is free software; you can redistribute it and/or modify |
810 | +# it under the terms of the GNU Lesser General Public License as |
811 | +# published by the Free Software Foundation; either version 2.1 of |
812 | +# the License, or (at your option) any later version. |
813 | +# |
814 | +# Storm is distributed in the hope that it will be useful, |
815 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
816 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
817 | +# GNU Lesser General Public License for more details. |
818 | +# |
819 | +# You should have received a copy of the GNU Lesser General Public License |
820 | +# along with this program. If not, see <http://www.gnu.org/licenses/>. |
821 | +# |
822 | + |
823 | +""" |
824 | +Tests for twistorm. |
825 | +""" |
826 | + |
827 | +from storm.properties import Int, Unicode |
828 | +from storm.expr import Count |
829 | +from storm.references import Reference |
830 | +from storm.exceptions import OperationalError, ThreadSafetyError |
831 | + |
832 | +from storm.twisted.store import ( |
833 | + DeferredStore, StoreThread, StorePool, AlreadyStopped ) |
834 | +from storm.twisted.wrapper import DeferredReference, DeferredReferenceSet |
835 | + |
836 | +from twisted.trial.unittest import TestCase |
837 | +from twisted.internet.defer import gatherResults, deferredGenerator |
838 | +from twisted.internet.defer import waitForDeferred, DeferredList |
839 | +from twisted.internet.defer import succeed, fail |
840 | + |
841 | + |
842 | +class Foo(object): |
843 | + """ |
844 | + Test table. |
845 | + """ |
846 | + __storm_table__ = "foo" |
847 | + id = Int(primary=True) |
848 | + title = Unicode() |
849 | + |
850 | + |
851 | + |
852 | +class Bar(object): |
853 | + """ |
854 | + Test table referencing C{Foo}. |
855 | + """ |
856 | + __storm_table__ = "bar" |
857 | + id = Int(primary=True) |
858 | + title = Unicode() |
859 | + foo_id = Int() |
860 | + foo = DeferredReference(foo_id, Foo.id) |
861 | + |
862 | + |
863 | + |
864 | +class FooRefSet(Foo): |
865 | + """ |
866 | + A C{Foo} class with a C{DeferredReferenceSet} to get all the bars related. |
867 | + """ |
868 | + bars = DeferredReferenceSet(Foo.id, Bar.foo_id) |
869 | + |
870 | + |
871 | + |
872 | +class FooRefSetOrderID(Foo): |
873 | + """ |
874 | + A C{Foo} class with an ordered C{DeferredReferenceSet} to C{Bar}. |
875 | + """ |
876 | + bars = DeferredReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id) |
877 | + |
878 | + |
879 | + |
880 | +class Egg(object): |
881 | + """ |
882 | + Test table. |
883 | + """ |
884 | + __storm_table__ = "egg" |
885 | + id = Int(primary=True) |
886 | + value = Int() |
887 | + |
888 | + |
889 | + |
890 | +class DeferredStoreTest(object): |
891 | + """ |
892 | + Tests for L{DeferredStore}. |
893 | + """ |
894 | + |
895 | + def setUp(self): |
896 | + """ |
897 | + Create the test database and insert some data. |
898 | + """ |
899 | + self.create_database() |
900 | + connection = self.connection = self.database.connect() |
901 | + self.drop_tables() |
902 | + self.create_tables() |
903 | + connection.execute("INSERT INTO foo VALUES (10, 'Title 30')") |
904 | + connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 50')") |
905 | + connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 40')") |
906 | + connection.execute("INSERT INTO egg VALUES (1, 4)") |
907 | + connection.execute("INSERT INTO egg VALUES (2, 3)") |
908 | + connection.execute("INSERT INTO egg VALUES (3, 7)") |
909 | + connection.execute("INSERT INTO egg VALUES (4, 5)") |
910 | + connection.commit() |
911 | + self.store = DeferredStore(self.database) |
912 | + return self.store.start() |
913 | + |
914 | + |
915 | + def tearDown(self): |
916 | + """ |
917 | + Kill the store (and its underlying thread). |
918 | + """ |
919 | + def _stop(ign): |
920 | + return self.store.stop().addCallback(_drop) |
921 | + def _drop(ign): |
922 | + self.drop_tables() |
923 | + return self.store.rollback().addCallback(_stop) |
924 | + |
925 | + |
926 | + def create_database(self): |
927 | + raise NotImplementedError() |
928 | + |
929 | + |
930 | + def create_tables(self): |
931 | + raise NotImplementedError() |
932 | + |
933 | + |
934 | + def drop_tables(self): |
935 | + for table in ["foo", "bar", "egg"]: |
936 | + try: |
937 | + self.connection.execute("DROP TABLE %s" % table) |
938 | + self.connection.commit() |
939 | + except: |
940 | + self.connection.rollback() |
941 | + |
942 | + |
943 | + def test_multiple_start(self): |
944 | + """ |
945 | + Check that start raises an exception when the store is already started. |
946 | + """ |
947 | + self.assertRaises(RuntimeError, self.store.start) |
948 | + |
949 | + |
950 | + def test_get(self): |
951 | + """ |
952 | + Try to get an object from the store and check its attributes. |
953 | + """ |
954 | + def cb(result): |
955 | + self.assertEquals(result.title, u"Title 30") |
956 | + self.assertEquals(result.id, 10) |
957 | + return self.store.get(Foo, 10).addCallback(cb) |
958 | + |
959 | + |
960 | + def test_add(self): |
961 | + """ |
962 | + Add an object to the store. |
963 | + """ |
964 | + foo = Foo() |
965 | + foo.title = u"Great title" |
966 | + foo.id = 11 |
967 | + def cb_add(ign): |
968 | + return self.store.get(Foo, 11).addCallback(cb_get) |
969 | + def cb_get(result): |
970 | + self.assertEquals(result.title, u"Great title") |
971 | + self.assertEquals(result.id, 11) |
972 | + return self.store.add(foo).addCallback(cb_add) |
973 | + |
974 | + |
975 | + def test_add_default_value(self): |
976 | + """ |
977 | + When adding an object to the store, the default values from the |
978 | + database are retrieved and put into the object. |
979 | + """ |
980 | + foo = Foo() |
981 | + foo.id = 11 |
982 | + def cb_add(result): |
983 | + self.assertIdentical(result, None) |
984 | + return self.store.get(Foo, 11).addCallback(cb_get) |
985 | + def cb_get(result): |
986 | + self.assertEquals(result.title, u"Default Title") |
987 | + self.assertEquals(result.id, 11) |
988 | + return self.store.add(foo).addCallback(cb_add) |
989 | + |
990 | + |
991 | + def test_execute(self): |
992 | + """ |
993 | + Test a direct execute on the store, and the C{get_one} method of |
994 | + C{DeferredResult}. |
995 | + """ |
996 | + def cb_execute(result): |
997 | + return result.get_one().addCallback(cb_result) |
998 | + def cb_result(result): |
999 | + self.assertEquals(result, (u"Title 30",)) |
1000 | + return self.store.execute("SELECT title FROM foo WHERE id=10" |
1001 | + ).addCallback(cb_execute) |
1002 | + |
1003 | + |
1004 | + def test_execute_all(self): |
1005 | + """ |
1006 | + Test a direct execute on the store, and the C{get_all} method of |
1007 | + C{DeferredResult}. |
1008 | + """ |
1009 | + def cb_execute(result): |
1010 | + return result.get_all().addCallback(cb_result) |
1011 | + def cb_result(result): |
1012 | + self.assertEquals(result, [(u"Title 50",), (u"Title 40",)]) |
1013 | + return self.store.execute("SELECT title FROM bar" |
1014 | + ).addCallback(cb_execute) |
1015 | + |
1016 | + |
1017 | + def test_remove(self): |
1018 | + """ |
1019 | + Try removing an object from the database. |
1020 | + """ |
1021 | + def cb_get(result): |
1022 | + return self.store.remove(result).addCallback(cb_remove) |
1023 | + def cb_remove(ign): |
1024 | + return self.store.get(Foo, 10).addCallback(cb_get_after_remove) |
1025 | + def cb_get_after_remove(result): |
1026 | + self.assertIdentical(result, None) |
1027 | + return self.store.get(Foo, 10).addCallback(cb_get) |
1028 | + |
1029 | + |
1030 | + def test_find(self): |
1031 | + """ |
1032 | + Try to find a list of objects using the store. |
1033 | + """ |
1034 | + def cb_find(results): |
1035 | + return results.all().addCallback(cb_all) |
1036 | + def cb_all(results): |
1037 | + self.assertEquals(len(results), 2) |
1038 | + titles = [results[0].title, results[1].title] |
1039 | + titles.sort() |
1040 | + self.assertEquals(titles, [u"Title 40", u"Title 50"]) |
1041 | + return self.store.find(Bar).addCallback(cb_find) |
1042 | + |
1043 | + |
1044 | + def test_find_first(self): |
1045 | + """ |
1046 | + Try to get the first object matching a query. |
1047 | + """ |
1048 | + def cb_find(results): |
1049 | + results.order_by(Bar.title) |
1050 | + return results.first().addCallback(cb_all) |
1051 | + def cb_all(result): |
1052 | + self.assertEquals(result.title, u"Title 40") |
1053 | + self.assertEquals(result.id, 11) |
1054 | + return self.store.find(Bar).addCallback(cb_find) |
1055 | + |
1056 | + |
1057 | + def test_find_last(self): |
1058 | + """ |
1059 | + Try to get the last object matching a query. |
1060 | + """ |
1061 | + def cb_find(results): |
1062 | + results.order_by(Bar.title) |
1063 | + return results.last().addCallback(cb_all) |
1064 | + def cb_all(result): |
1065 | + self.assertEquals(result.title, u"Title 50") |
1066 | + self.assertEquals(result.id, 10) |
1067 | + return self.store.find(Bar).addCallback(cb_find) |
1068 | + |
1069 | + |
1070 | + def test_find_any(self): |
1071 | + """ |
1072 | + Try to get an object matching a query using the C{any} method. |
1073 | + """ |
1074 | + def cb_find(results): |
1075 | + return results.any().addCallback(cb_all) |
1076 | + def cb_all(result): |
1077 | + self.assertEquals(result.title, u"Title 50") |
1078 | + self.assertEquals(result.id, 10) |
1079 | + return self.store.find(Bar).addCallback(cb_find) |
1080 | + |
1081 | + |
1082 | + def test_find_max(self): |
1083 | + """ |
1084 | + Try to get the maximum of a value after a find. |
1085 | + """ |
1086 | + def cb_find(results): |
1087 | + return results.max(Egg.value).addCallback(cb_all) |
1088 | + def cb_all(result): |
1089 | + self.assertEquals(result, 7) |
1090 | + return self.store.find(Egg).addCallback(cb_find) |
1091 | + |
1092 | + |
1093 | + def test_find_min(self): |
1094 | + """ |
1095 | + Try to get the minimum of a value after a find. |
1096 | + """ |
1097 | + def cb_find(results): |
1098 | + return results.min(Egg.value).addCallback(cb_all) |
1099 | + def cb_all(result): |
1100 | + self.assertEquals(result, 3) |
1101 | + return self.store.find(Egg).addCallback(cb_find) |
1102 | + |
1103 | + |
1104 | + def test_find_avg(self): |
1105 | + """ |
1106 | + Try to get the average of a value after a find. |
1107 | + """ |
1108 | + def cb_find(results): |
1109 | + return results.avg(Egg.value).addCallback(cb_all) |
1110 | + def cb_all(result): |
1111 | + self.assertEquals(result, 4.75) |
1112 | + return self.store.find(Egg).addCallback(cb_find) |
1113 | + |
1114 | + |
1115 | + def test_find_sum(self): |
1116 | + """ |
1117 | + Try to get the sum of a value after a find. |
1118 | + """ |
1119 | + def cb_find(results): |
1120 | + return results.sum(Egg.value).addCallback(cb_all) |
1121 | + def cb_all(result): |
1122 | + self.assertEquals(result, 19) |
1123 | + return self.store.find(Egg).addCallback(cb_find) |
1124 | + |
1125 | + |
1126 | + def test_find_count(self): |
1127 | + """ |
1128 | + Try to get the count of a result after a find. |
1129 | + """ |
1130 | + def cb_find(results): |
1131 | + return results.count().addCallback(cb_all) |
1132 | + def cb_all(result): |
1133 | + self.assertEquals(result, 2) |
1134 | + return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find) |
1135 | + |
1136 | + |
1137 | + def test_find_remove(self): |
1138 | + """ |
1139 | + Remove the result of a find query. |
1140 | + """ |
1141 | + def cb_find(results): |
1142 | + return results.remove().addCallback(cb_remove) |
1143 | + def cb_remove(ignore): |
1144 | + return self.store.find(Egg).addCallback(cb_find_after_remove) |
1145 | + def cb_find_after_remove(results): |
1146 | + return results.all().addCallback(cb_all) |
1147 | + def cb_all(results): |
1148 | + self.assertEquals(len(results), 2) |
1149 | + return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find) |
1150 | + |
1151 | + |
1152 | + def test_find_limit(self): |
1153 | + """ |
1154 | + Put a limit on the number of results of a find. |
1155 | + """ |
1156 | + def cb_find(results): |
1157 | + results.config(limit=3) |
1158 | + return results.all().addCallback(cb_all) |
1159 | + def cb_all(results): |
1160 | + self.assertEquals(len(results), 3) |
1161 | + return self.store.find(Egg).addCallback(cb_find) |
1162 | + |
1163 | + |
1164 | + def test_find_union(self): |
1165 | + """ |
1166 | + Call C{union} on two different C{DeferredResultSet} objects. |
1167 | + """ |
1168 | + def cb_find(results): |
1169 | + result1, result2 = results |
1170 | + results = result1.union(result2._resultSet) |
1171 | + return results.all().addCallback(cb_all) |
1172 | + def cb_all(results): |
1173 | + self.assertEquals(len(results), 3) |
1174 | + d1 = self.store.find(Egg, Egg.value >= 5) |
1175 | + d2 = self.store.find(Egg, Egg.value == 3) |
1176 | + return gatherResults([d1, d2]).addCallback(cb_find) |
1177 | + |
1178 | + |
1179 | + def test_find_difference(self): |
1180 | + """ |
1181 | + Call C{difference} on two different C{DeferredResultSet} objects. |
1182 | + """ |
1183 | + if self.__class__.__name__.startswith("MySQL"): |
1184 | + return |
1185 | + def cb_find(results): |
1186 | + result1, result2 = results |
1187 | + results = result1.difference(result2._resultSet) |
1188 | + return results.all().addCallback(cb_all) |
1189 | + def cb_all(results): |
1190 | + self.assertEquals(len(results), 1) |
1191 | + self.assertEquals(results[0].value, 5) |
1192 | + d1 = self.store.find(Egg, Egg.value >= 5) |
1193 | + d2 = self.store.find(Egg, Egg.value == 7) |
1194 | + return gatherResults([d1, d2]).addCallback(cb_find) |
1195 | + |
1196 | + |
1197 | + def test_find_intersection(self): |
1198 | + """ |
1199 | + Call C{intersection} on two different C{DeferredResultSet} objects. |
1200 | + """ |
1201 | + if self.__class__.__name__.startswith("MySQL"): |
1202 | + return |
1203 | + def cb_find(results): |
1204 | + result1, result2 = results |
1205 | + results = result1.intersection(result2._resultSet) |
1206 | + return results.all().addCallback(cb_all) |
1207 | + def cb_all(results): |
1208 | + self.assertEquals(len(results), 1) |
1209 | + self.assertEquals(results[0].value, 7) |
1210 | + d1 = self.store.find(Egg, Egg.value >= 5) |
1211 | + d2 = self.store.find(Egg, Egg.value == 7) |
1212 | + return gatherResults([d1, d2]).addCallback(cb_find) |
1213 | + |
1214 | + |
1215 | + def test_find_values(self): |
1216 | + """ |
1217 | + Filter the fields returned by a find using the values method. |
1218 | + """ |
1219 | + def cb_find(results): |
1220 | + return results.values(Bar.title).addCallback(cb_all) |
1221 | + def cb_all(titles): |
1222 | + titles.sort() |
1223 | + self.assertEquals(titles, [u"Title 40", u"Title 50"]) |
1224 | + return self.store.find(Bar).addCallback(cb_find) |
1225 | + |
1226 | + |
1227 | + def test_find_and_set(self): |
1228 | + """ |
1229 | + The C{set} method of a C{ResultSet} should update the specified fields |
1230 | + in a thread. |
1231 | + """ |
1232 | + def cb_find(results): |
1233 | + return results.set(title=u"Title").addCallback(cb_set, results) |
1234 | + def cb_set(ignore, results): |
1235 | + return results.values(Bar.title).addCallback(cb_all) |
1236 | + def cb_all(titles): |
1237 | + titles.sort() |
1238 | + self.assertEquals(titles, [u"Title", u"Title"]) |
1239 | + return self.store.find(Bar).addCallback(cb_find) |
1240 | + |
1241 | + |
1242 | + def test_find_offset(self): |
1243 | + """ |
1244 | + Apply an offset to the results of a find. |
1245 | + """ |
1246 | + def cb_find(results): |
1247 | + results.config(offset=2) |
1248 | + return results.all().addCallback(cb_all) |
1249 | + def cb_all(results): |
1250 | + self.assertEquals(len(results), 2) |
1251 | + return self.store.find(Egg).addCallback(cb_find) |
1252 | + |
1253 | + |
1254 | + def test_find_offset_limit(self): |
1255 | + """ |
1256 | + Apply both an offset and a limit to the results of a find. |
1257 | + """ |
1258 | + def cb_find(results): |
1259 | + results.config(offset=1, limit=2) |
1260 | + return results.all().addCallback(cb_all) |
1261 | + def cb_all(results): |
1262 | + self.assertEquals(len(results), 2) |
1263 | + return self.store.find(Egg).addCallback(cb_find) |
1264 | + |
1265 | + |
1266 | + @deferredGenerator |
1267 | + def test_find_defgen(self): |
1268 | + """ |
1269 | + Do a find, add an object, then do another find: this to ensure that the |
1270 | + connection remains in the dedicated thread. |
1271 | + """ |
1272 | + d = self.store.find(Bar) |
1273 | + wfd = waitForDeferred(d) |
1274 | + yield wfd |
1275 | + results = wfd.getResult() |
1276 | + d = results.all() |
1277 | + wfd = waitForDeferred(d) |
1278 | + yield wfd |
1279 | + wfd.getResult() |
1280 | + foo = Foo() |
1281 | + foo.title = u"Great title" |
1282 | + foo.id = 11 |
1283 | + d = self.store.add(foo) |
1284 | + wfd = waitForDeferred(d) |
1285 | + yield wfd |
1286 | + wfd.getResult() |
1287 | + d = self.store.find(Foo) |
1288 | + wfd = waitForDeferred(d) |
1289 | + yield wfd |
1290 | + results = wfd.getResult() |
1291 | + |
1292 | + |
1293 | + def test_find_order_by(self): |
1294 | + """ |
1295 | + Try to find a list of objects using the store, then order the result |
1296 | + set. |
1297 | + """ |
1298 | + def cb_find(results): |
1299 | + results.order_by(Bar.title) |
1300 | + return results.all().addCallback(cb_all) |
1301 | + def cb_all(results): |
1302 | + self.assertEquals(len(results), 2) |
1303 | + titles = [results[0].title, results[1].title] |
1304 | + self.assertEquals(titles, [u"Title 40", u"Title 50"]) |
1305 | + return self.store.find(Bar).addCallback(cb_find) |
1306 | + |
1307 | + |
1308 | + def test_find_and_rollback(self): |
1309 | + """ |
1310 | + Accessing an object outside of a transaction fails because the object |
1311 | + hasn't been resolved yet. |
1312 | + """ |
1313 | + def cb_find(results): |
1314 | + results.order_by(Bar.title) |
1315 | + return results.all().addCallback(cb_all) |
1316 | + def cb_all(results): |
1317 | + return self.store.rollback().addCallback(cbRollback, results) |
1318 | + def cbRollback(ign, results): |
1319 | + self.assertEquals(len(results), 2) |
1320 | + self.assertRaises(RuntimeError, getattr, results[0], "title") |
1321 | + return self.store.find(Bar).addCallback(cb_find) |
1322 | + |
1323 | + |
1324 | + def test_find_is_empty(self): |
1325 | + """ |
1326 | + DeferredResultSet.is_empty returns a Deferred that fires with True or |
1327 | + False depending on whether the matched result set is empty. |
1328 | + """ |
1329 | + def cb_find(results): |
1330 | + return results.is_empty().addCallback(self.assertEquals, False) |
1331 | + |
1332 | + return self.store.find(Bar).addCallback(cb_find) |
1333 | + |
1334 | + |
1335 | + def test_find_group_by(self): |
1336 | + """ |
1337 | + DeferredResultSet.group_by is a simple wrapper around the group_by |
1338 | + method of the underlying result set. |
1339 | + """ |
1340 | + def cb_find(results): |
1341 | + results.group_by(Bar.foo_id) |
1342 | + return results.all().addCallback(check) |
1343 | + |
1344 | + def check(result): |
1345 | + self.assertEquals(result, [(2, 10)]) |
1346 | + |
1347 | + return self.store.find((Count(Bar.id), Bar.foo_id) |
1348 | + ).addCallback(cb_find) |
1349 | + |
1350 | + |
1351 | + def test_find_having(self): |
1352 | + """ |
1353 | + DeferredResultSet.having is a simple wrapper around the having method |
1354 | + of the underlying result set. |
1355 | + """ |
1356 | + connection = self.database.connect() |
1357 | + connection.execute("INSERT INTO egg VALUES (5, 7)") |
1358 | + connection.commit() |
1359 | + |
1360 | + def cb_find(results): |
1361 | + results.group_by(Egg.value) |
1362 | + results.having(Egg.value >= 5) |
1363 | + results.order_by(Egg.value) |
1364 | + return results.all().addCallback(check) |
1365 | + |
1366 | + def check(result): |
1367 | + self.assertEquals(result, [(1, 5), (2, 7)]) |
1368 | + |
1369 | + return self.store.find((Count(Egg.id), Egg.value) |
1370 | + ).addCallback(cb_find) |
1371 | + |
1372 | + |
1373 | + def test_reference(self): |
1374 | + """ |
1375 | + Try to get a reference of an object using C{DeferredReference}. |
1376 | + """ |
1377 | + def cb_getBar(result): |
1378 | + return result.foo.addCallback(cb_getFoo) |
1379 | + def cb_getFoo(fooResult): |
1380 | + return self.store.get(Foo, 10).addCallback(cb_getFooBar, fooResult) |
1381 | + def cb_getFooBar(result, fooResult): |
1382 | + self.assertIdentical(fooResult, result) |
1383 | + # The result should be valid too |
1384 | + self.assertEquals(fooResult.title, u"Title 30") |
1385 | + return self.store.get(Bar, 10).addCallback(cb_getBar) |
1386 | + |
1387 | + |
1388 | + def test_reference_setting(self): |
1389 | + """ |
1390 | + Try to set a reference of an object. |
1391 | + """ |
1392 | + connection = self.database.connect() |
1393 | + connection.execute("INSERT INTO foo VALUES (20, 'Title 20')") |
1394 | + connection.commit() |
1395 | + def cb_getBar(result): |
1396 | + return self.store.get(Foo, 20).addCallback(cb_getFooBar, result) |
1397 | + def cb_getFooBar(result, barResult): |
1398 | + self.assertRaises(RuntimeError, setattr, barResult, "foo", result) |
1399 | + return self.store.get(Bar, 10).addCallback(cb_getBar) |
1400 | + |
1401 | + |
1402 | + def test_reference_set_unordered(self): |
1403 | + """ |
1404 | + Get a reference set and call various wrapped methods on it. |
1405 | + """ |
1406 | + # find test |
1407 | + def cb_find(results): |
1408 | + return results.all().addCallback(cb_all) |
1409 | + |
1410 | + def cb_all(results): |
1411 | + self.assertEquals(len(results), 2) |
1412 | + titles = [results[0].title, results[1].title] |
1413 | + titles.sort() |
1414 | + self.assertEquals(titles, [u"Title 40", u"Title 50"]) |
1415 | + |
1416 | + def cb_any(result): |
1417 | + self.assertTrue(result) |
1418 | + |
1419 | + def cb_values(titles): |
1420 | + titles.sort() |
1421 | + self.assertEquals(titles, [u"Title 40", u"Title 50"]) |
1422 | + |
1423 | + def do_tests(results): |
1424 | + results = results.bars |
1425 | + dfrs = [ |
1426 | + results.find().addCallback(cb_find), |
1427 | + results.any().addCallback(cb_any), |
1428 | + results.values(Bar.title).addCallback(cb_values), |
1429 | + ] |
1430 | + return DeferredList(dfrs) |
1431 | + |
1432 | + return self.store.get(FooRefSet, 10).addCallback(do_tests) |
1433 | + |
1434 | + |
1435 | + def test_reference_set_ordered(self): |
1436 | + """ |
1437 | + A DeferredReferenceSet has an order_by method which returns a Deferred |
1438 | + firing when the reference set is ordered. |
1439 | + """ |
1440 | + def do_tests(result): |
1441 | + dfrs = [ |
1442 | + result.first().addCallback(lambda t: |
1443 | + self.assertEquals(t.title, u"Title 40")), |
1444 | + result.last().addCallback(lambda t: |
1445 | + self.assertEquals(t.title, u"Title 50")), |
1446 | + ] |
1447 | + return DeferredList(dfrs) |
1448 | + |
1449 | + def order(results): |
1450 | + dfr = results.bars.order_by("title") |
1451 | + return dfr.addCallback(do_tests) |
1452 | + |
1453 | + return self.store.get(FooRefSet, 10).addCallback(order) |
1454 | + |
1455 | + |
1456 | + def test_reference_set_add_remove(self): |
1457 | + """ |
1458 | + A DeferredReferenceSet has an add method which returns a Deferred firing |
1459 | + once the object has been added, and a remove method which does the |
1460 | + same for removal. Add and then remove an object asynchronously. |
1461 | + """ |
1462 | + def add_one(result): |
1463 | + bar = Bar() |
1464 | + bar.title = u"Yeah" |
1465 | + return result.bars.add(bar).addCallback(remove_one, result, bar) |
1466 | + |
1467 | + def remove_one(add_result, result, bar): |
1468 | + return result.bars.remove(bar).addCallback(get_all, result) |
1469 | + |
1470 | + def get_all(ignore, result): |
1471 | + return result.bars.all().addCallback(check) |
1472 | + |
1473 | + def check(result): |
1474 | + self.assertEquals(len(result), 2) |
1475 | + |
1476 | + return self.store.get(FooRefSet, 10).addCallback(add_one) |
1477 | + |
1478 | + |
1479 | + def test_reference_set_clear(self): |
1480 | + """ |
1481 | + A DeferredReferenceSet has a clear method which removes all elements |
1482 | + from the reference set and fires the returned Deferred when done. |
1483 | + """ |
1484 | + def first_cb(result): |
1485 | + refs = result.bars |
1486 | + return check_count(refs, 2).addCallback(clear_cb, refs) |
1487 | + |
1488 | + def check_count(ref_set, num): |
1489 | + return ref_set.count().addCallback(self.assertEquals, num) |
1490 | + |
1491 | + def clear_cb(result, refs): |
1492 | + return refs.clear().addCallback(lambda x: check_count(refs, 0)) |
1493 | + |
1494 | + return self.store.get(FooRefSet, 10).addCallback(first_cb) |
1495 | + |
1496 | + |
1497 | + def test_reference_set_one(self): |
1498 | + """ |
1499 | + Call C{one} on a C{DeferredBoundReference}. |
1500 | + """ |
1501 | + connection = self.database.connect() |
1502 | + connection.execute("INSERT INTO foo VALUES (11, 'Title 40')") |
1503 | + connection.execute("INSERT INTO bar VALUES (20, 11, 'Title 50')") |
1504 | + connection.commit() |
1505 | + def cb_get(result): |
1506 | + return result.bars.one().addCallback(cb_one) |
1507 | + def cb_one(result): |
1508 | + return self.store.get(Bar, 20).addCallback(check, result) |
1509 | + def check(result, expected): |
1510 | + self.assertIdentical(result, expected) |
1511 | + return self.store.get(FooRefSet, 11).addCallback(cb_get) |
1512 | + |
1513 | + |
1514 | + def test_reference_set_first(self): |
1515 | + """ |
1516 | + Call C{first} on an ordered C{DeferredBoundReference}. |
1517 | + """ |
1518 | + def cb_get(result): |
1519 | + return result.bars.first().addCallback(cb_one) |
1520 | + def cb_one(result): |
1521 | + return self.store.get(Bar, 10).addCallback(check, result) |
1522 | + def check(result, expected): |
1523 | + self.assertIdentical(result, expected) |
1524 | + return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get) |
1525 | + |
1526 | + |
1527 | + def test_reference_set_last(self): |
1528 | + """ |
1529 | + Call C{last} on an ordered C{DeferredBoundReference}. |
1530 | + """ |
1531 | + def cb_get(result): |
1532 | + return result.bars.last().addCallback(cb_one) |
1533 | + def cb_one(result): |
1534 | + return self.store.get(Bar, 11).addCallback(check, result) |
1535 | + def check(result, expected): |
1536 | + self.assertIdentical(result, expected) |
1537 | + return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get) |
1538 | + |
1539 | + |
1540 | + def test_commit(self): |
1541 | + """ |
1542 | + Make some changes and commit them. |
1543 | + """ |
1544 | + def cb_get(result): |
1545 | + return self.store.remove(result).addCallback(cb_remove) |
1546 | + def cb_remove(ign): |
1547 | + return self.store.commit().addCallback(cb_commit) |
1548 | + def cb_commit(ign): |
1549 | + # To be sure the data is no longer in the database, check it |
1550 | + # through a separate, direct connection |
1551 | + connection = self.database.connect() |
1552 | + result = connection.execute("SELECT * FROM foo") |
1553 | + self.assertEquals(list(result), []) |
1554 | + return self.store.get(Foo, 10).addCallback(cb_get) |
1555 | + |
1556 | + |
1557 | + def test_rollback(self): |
1558 | + """ |
1559 | + Make some changes and roll them back. |
1560 | + """ |
1561 | + def cb_get(result): |
1562 | + return self.store.remove(result).addCallback(cb_remove) |
1563 | + def cb_remove(ign): |
1564 | + return self.store.rollback().addCallback(cbRollback) |
1565 | + def cbRollback(ign): |
1566 | + connection = self.database.connect() |
1567 | + result = connection.execute("SELECT * FROM foo") |
1568 | + self.assertEquals(list(result), [(10, u"Title 30")]) |
1569 | + return self.store.get(Foo, 10).addCallback(cb_get) |
1570 | + |
1571 | + |
1572 | + def test_deferred_reference_multithread(self): |
1573 | + """ |
1574 | + If a store is restarted, the objects in the cache should still be |
1575 | + usable; in particular, an object shouldn't store a reference to the |
1576 | + store thread, as it can change. |
1577 | + """ |
1578 | + def test(ignore): |
1579 | + # get a bar and retrieve the deferred reference |
1580 | + def get_foo(bar): |
1581 | + return bar.foo |
1582 | + |
1583 | + return self.store.find(Bar, Bar.id == 10).addCallback(lambda x: |
1584 | + x.one()).addCallback(get_foo) |
1585 | + |
1586 | + def _restart_store(res): |
1587 | + def stopped(res): |
1588 | + self.store = DeferredStore(self.database) |
1589 | + return self.store.start() |
1590 | + return self.store.stop().addCallback(stopped) |
1591 | + |
1592 | + def check(foo): |
1593 | + self.assertEquals(foo.id, 10) |
1594 | + self.assertEquals(foo.title, u"Title 30") |
1595 | + |
1596 | + return test(None |
1597 | + ).addCallback(_restart_store |
1598 | + ).addCallback(test |
1599 | + ).addCallback(check) |
1600 | + |
1601 | + |
1602 | + def test_of(self): |
1603 | + """ |
1604 | + The DeferredStore associated with an object is returned by the static |
1605 | + C{of} method. |
1606 | + """ |
1607 | + def cb_get(result): |
1608 | + store = DeferredStore.of(result) |
1609 | + self.assertIdentical(self.store, store) |
1610 | + |
1611 | + return self.store.get(Foo, 10).addCallback(cb_get) |
1612 | + |
1613 | + |
1614 | + def test_thread_check(self): |
1615 | + """ |
1616 | + A L{ThreadSafetyError} is raised when attempting to do an unsafe |
1617 | + operation, like accessing a C{Reference} attribute via a |
1618 | + C{DeferredStore}. |
1619 | + """ |
1620 | + class WrongBar(Bar): |
1621 | + foo_sync = Reference(Bar.foo_id, Foo.id) |
1622 | + d = self.store.find(WrongBar).addCallback(lambda result: result.all()) |
1623 | + |
1624 | + def check(results): |
1625 | + self.assertRaises( |
1626 | + ThreadSafetyError, getattr, results[0], "foo_sync") |
1627 | + return d.addCallback(check) |
1628 | + |
1629 | + |
1630 | + |
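The `StoreThreadTestCase` tests pin down three behaviours of `StoreThread`:

- a function passed to `defer_to_thread` runs in the thread and its return value is delivered;
- an exception raised there becomes a failure;
- calls made after `stop()` fail with `AlreadyStopped`.

What follows is a minimal sketch of a worker thread with those semantics. It uses a `queue.Queue` work loop and `concurrent.futures.Future` results instead of Twisted Deferreds. `SketchStoreThread` is a hypothetical name, and the `AlreadyStopped` class here is a local stand-in for the one imported from `storm.twisted.store`. The tests treat the real `stop()` as returning a Deferred; this sketch's `stop()` simply blocks until the worker exits.

```python
import queue
import threading
from concurrent.futures import Future


class AlreadyStopped(Exception):
    """Local stand-in: raised for work submitted after stop()."""


class SketchStoreThread:
    """Hypothetical single-thread executor mirroring the tested behaviour."""

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._stopped = False

    def start(self):
        self._thread.start()

    def _run(self):
        # Process queued calls one by one, always in this same thread.
        while True:
            item = self._queue.get()
            if item is None:  # sentinel posted by stop()
                break
            future, func, args, kwargs = item
            try:
                future.set_result(func(*args, **kwargs))
            except Exception as exc:
                # The exception is delivered to the caller, not raised here.
                future.set_exception(exc)

    def defer_to_thread(self, func, *args, **kwargs):
        future = Future()
        if self._stopped:
            # Fail immediately, like defer_to_thread after stop() in the tests.
            future.set_exception(AlreadyStopped())
            return future
        self._queue.put((future, func, args, kwargs))
        return future

    def stop(self):
        if not self._stopped:
            self._stopped = True
            self._queue.put(None)
        self._thread.join()
```

Because stop() enqueues the sentinel behind any pending work, calls submitted before it still complete, while anything submitted afterwards fails at once.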
1631 | +class StoreThreadTestCase(TestCase): |
1632 | + """ |
1633 | + Tests for L{StoreThread}. |
1634 | + """ |
1635 | + |
1636 | + def setUp(self): |
1637 | + """ |
1638 | + Create an instance of C{StoreThread} and start it. |
1639 | + """ |
1640 | + self.thread = StoreThread() |
1641 | + self.thread.start() |
1642 | + |
1643 | + |
1644 | + def tearDown(self): |
1645 | + """ |
1646 | + Kill the running thread. |
1647 | + """ |
1648 | + self.thread.stop() |
1649 | + |
1650 | + |
1651 | + def test_defer_after_stop(self): |
1652 | + """ |
1653 | + Deferring calls after store is stopped raises C{AlreadyStopped}. |
1654 | + """ |
1655 | + def cb_stop(r): |
1656 | + self.assertFailure(self.thread.defer_to_thread(lambda f : None), |
1657 | + AlreadyStopped) |
1658 | + return self.thread.stop().addCallback(cb_stop) |
1659 | + |
1660 | + |
1661 | + def test_callback(self): |
1662 | + """ |
1663 | + Fire a simple function in a thread and check its result. |
1664 | + """ |
1665 | + def testfunc(): |
1666 | + return 1 |
1667 | + return self.thread.defer_to_thread(testfunc |
1668 | + ).addCallback(self.assertEquals, 1) |
1669 | + |
1670 | + |
1671 | + def test_errback(self): |
1672 | + """ |
1673 | + Raising an exception in a thread returns a failure. |
1674 | + """ |
1675 | + def testfunc(): |
1676 | + raise RuntimeError("Error!") |
1677 | + return self.assertFailure(self.thread.defer_to_thread(testfunc), |
1678 | + RuntimeError) |
1679 | + |
1680 | + |
1681 | + |
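`StorePoolTest` checks four pool behaviours:

- it hands out distinct stores;
- it creates stores on demand up to its maximum;
- once the maximum is reached, callers wait until a store is returned with `put()`;
- it rejects a negative minimum, or a minimum above the maximum.

What follows is a hypothetical sketch of a bounded pool with that behaviour. It is not the branch's `StorePool`. A generic `factory` callable stands in for starting a `DeferredStore`, and `Future` objects replace Deferreds, so `adjust_size` raises `ValueError` directly rather than returning a failed Deferred. Following the approach described in revision 221 of this branch, the creation counter is incremented before a new resource is built, so the total never exceeds `max_size`.

```python
import threading
from collections import deque
from concurrent.futures import Future


class SketchStorePool:
    """Hypothetical bounded pool: reuses idle resources, creates new ones up
    to max_size, and queues callers beyond that until put() is called."""

    def __init__(self, factory, min_size, max_size):
        self._factory = factory
        self._lock = threading.Lock()
        self._idle = deque()
        self._waiters = deque()
        self._created = 0
        self.adjust_size(min_size, max_size)

    def adjust_size(self, min_size, max_size):
        # Same sanity checks as test_adjust_size_*, raised synchronously here.
        if min_size < 0:
            raise ValueError("min_size must be nonnegative")
        if min_size > max_size:
            raise ValueError("min_size must not exceed max_size")
        with self._lock:
            self.min_size, self.max_size = min_size, max_size
            # Pre-create resources up to the minimum (this sketch never
            # destroys resources when the limits shrink).
            while self._created < min_size:
                self._idle.append(self._factory())
                self._created += 1

    def get(self):
        future = Future()
        create = False
        with self._lock:
            if self._idle:
                resource = self._idle.popleft()
            elif self._created < self.max_size:
                # Reserve the slot under the lock, build it outside the lock.
                self._created += 1
                create = True
            else:
                # At capacity: the caller waits until put() hands one back.
                self._waiters.append(future)
                return future
        if create:
            resource = self._factory()
        future.set_result(resource)
        return future

    def put(self, resource):
        with self._lock:
            waiter = self._waiters.popleft() if self._waiters else None
            if waiter is None:
                self._idle.append(resource)
        if waiter is not None:
            # Resolve outside the lock so callbacks may call get() or put().
            waiter.set_result(resource)
```

A waiting caller receives the returned resource directly instead of it going back to the idle queue, which matches `test_waiting_for_store`: the store that was put back is the one the pending `get()` receives.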
1682 | +class StorePoolTest(object): |
1683 | + """ |
1684 | + Tests for L{StorePool}. |
1685 | + """ |
1686 | + |
1687 | + def setUp(self): |
1688 | + """ |
1689 | + Build a database with data, and create a pool. |
1690 | + """ |
1691 | + self.create_database() |
1692 | + connection = self.connection = self.database.connect() |
1693 | + self.drop_tables() |
1694 | + self.create_tables() |
1695 | + connection.execute("INSERT INTO foo VALUES (10, 'Title 30')") |
1696 | + connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 40')") |
1697 | + connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 50')") |
1698 | + connection.commit() |
1699 | + self.pool = StorePool(self.database, 2, 5) |
1700 | + return self.pool.start() |
1701 | + |
1702 | + |
1703 | + def tearDown(self): |
1704 | + """ |
1705 | + Stop the pool. |
1706 | + """ |
1707 | + def _drop(ign): |
1708 | + self.drop_tables() |
1709 | + return self.pool.stop().addCallback(_drop) |
1710 | + |
1711 | + |
1712 | + def drop_tables(self): |
1713 | + for table in ["foo", "bar"]: |
1714 | + try: |
1715 | + self.connection.execute("DROP TABLE %s" % table) |
1716 | + self.connection.commit() |
1717 | + except: |
1718 | + self.connection.rollback() |
1719 | + |
1720 | + |
1721 | + def test_already_started(self): |
1722 | + """ |
1723 | + Check that the pool can't be restarted multiple times. |
1724 | + """ |
1725 | + self.assertRaises(RuntimeError, self.pool.start) |
1726 | + |
1727 | + |
1728 | + def test_get(self): |
1729 | + """ |
1730 | + C{get} should return different stores when several are available. |
1731 | + """ |
1732 | + def cb_get1(store1): |
1733 | + return self.pool.get().addCallback(cb_get2, store1) |
1734 | + def cb_get2(store2, store1): |
1735 | + self.assertNotIdentical(store1, store2) |
1736 | + self.assertTrue(store1.started) |
1737 | + self.assertTrue(store2.started) |
1738 | + return self.pool.get().addCallback(cb_get1) |
1739 | + |
1740 | + |
1741 | + def test_get_not_started(self): |
1742 | + """ |
1743 | + If no stores are available, the pool should create one. |
1744 | + """ |
1745 | + def cb(ign): |
1746 | + self.assertEquals(self.pool._stores_created, 0) |
1747 | + return self.pool.adjust_size(0, 5).addCallback(cb_add) |
1748 | + def cb_add(ign): |
1749 | + return self.pool.get().addCallback(cb_get) |
1750 | + def cb_get(store): |
1751 | + self.assertTrue(store.started) |
1752 | + return self.pool.adjust_size(0, 0).addCallback(cb) |
1753 | + |
1754 | + |
1755 | + def test_waiting_for_store(self): |
1756 | + """ |
1757 | + Test waiting for a store availability. |
1758 | + """ |
1759 | + def cb(ign): |
1760 | + return self.pool.get().addCallback(cb_get1) |
1761 | + def cb_get1(store1): |
1762 | + self.assertTrue(store1.started) |
1763 | + # Now that we hold the only store (the max size is 1), the pool |
1764 | + # shouldn't hand out any store until we put this one back |
1765 | + d1 = self.pool.get().addCallback(cb_get2, store1) |
1766 | + d2 = self.pool.put(store1) |
1767 | + return gatherResults([d1, d2]) |
1768 | + def cb_get2(store2, store1): |
1769 | + self.assertIdentical(store1, store2) |
1770 | + return self.pool.adjust_size(1, 1).addCallback(cb) |
1771 | + |
1772 | + |
1773 | + def test_adjust_size_minmax(self): |
1774 | + """ |
1775 | + Test the sanity check on min/max, i.e. that min <= max. |
1776 | + """ |
1777 | + return self.assertFailure(self.pool.adjust_size(2, 1), ValueError) |
1778 | + |
1779 | + |
1780 | + def test_adjust_size_nonnegative(self): |
1781 | + """ |
1782 | + Test the sanity check for a nonnegative min. |
1783 | + """ |
1784 | + return self.assertFailure(self.pool.adjust_size(-1), ValueError) |
1785 | + |
1786 | + |
1787 | + @deferredGenerator |
1788 | + def test_concurrent_data(self): |
1789 | + """ |
1790 | + Test that different stores have different states: if the first store |
1791 | + hasn't yet committed, the second one shouldn't get the new data. |
1792 | + """ |
1793 | + foo = Foo() |
1794 | + foo.title = u"Great title" |
1795 | + foo.id = 11 |
1796 | + d = self.pool.get() |
1797 | + wfd = waitForDeferred(d) |
1798 | + yield wfd |
1799 | + store1 = wfd.getResult() |
1800 | + |
1801 | + d = self.pool.get() |
1802 | + wfd = waitForDeferred(d) |
1803 | + yield wfd |
1804 | + store2 = wfd.getResult() |
1805 | + |
1806 | + d = store1.add(foo) |
1807 | + wfd = waitForDeferred(d) |
1808 | + yield wfd |
1809 | + wfd.getResult() |
1810 | + |
1811 | + d = store1.get(Foo, 11) |
1812 | + wfd = waitForDeferred(d) |
1813 | + yield wfd |
1814 | + foo2 = wfd.getResult() |
1815 | + |
1816 | + # The object is already in the store cache |
1817 | + self.assertIdentical(foo2, foo) |
1818 | + |
1819 | + d = store2.get(Foo, 11) |
1820 | + wfd = waitForDeferred(d) |
1821 | + yield wfd |
1822 | + foo3 = wfd.getResult() |
1823 | + |
1824 | + # The object isn't in the db yet |
1825 | + self.assertIdentical(foo3, None) |
1826 | + |
1827 | + # Roll back, because even a SELECT opens a transaction |
1828 | + d = store2.rollback() |
1829 | + wfd = waitForDeferred(d) |
1830 | + yield wfd |
1831 | + wfd.getResult() |
1832 | + |
1833 | + # Let's commit |
1834 | + d = store1.commit() |
1835 | + wfd = waitForDeferred(d) |
1836 | + yield wfd |
1837 | + wfd.getResult() |
1838 | + |
1839 | + d = store2.get(Foo, 11) |
1840 | + wfd = waitForDeferred(d) |
1841 | + yield wfd |
1842 | + foo4 = wfd.getResult() |
1843 | + |
1844 | + # The objects must be different |
1845 | + self.assertNotIdentical(foo4, foo) |
1846 | + # But the content must be the same |
1847 | + self.assertEquals(foo4.title, u"Great title") |
1848 | + |
1849 | + |
1850 | + def test_no_overflow(self): |
1851 | + """ |
1852 | + Test that the pool does not allocate more connections than store_max. |
1853 | + """ |
1854 | + ds = [] |
1855 | + stores = set() |
1856 | + |
1857 | + def cb_get(store): |
1858 | + stores.add(store) |
1859 | + return self.pool.put(store) |
1860 | + |
1861 | + for i in range(10): |
1862 | + ds.append(self.pool.get().addCallback(cb_get)) |
1863 | + |
1864 | + def checkInstances(result): |
1865 | + self.assertEquals(len(stores), 5) |
1866 | + |
1867 | + return gatherResults(ds).addCallback(checkInstances) |
1868 | + |
1869 | + |
1870 | + def test_start_failure(self): |
1871 | + """ |
1872 | + If a store fails to start, the number of allocated connections doesn't |
1873 | + grow, so we're still able to start more stores later. |
1874 | + """ |
1875 | + ds = [] |
1876 | + stores = set() |
1877 | + |
1878 | + class DontStartStore(DeferredStore): |
1879 | + def start(self): |
1880 | + return fail(RuntimeError("oops")) |
1881 | + |
1882 | + calls = [] |
1883 | + |
1884 | + def vicious_store_factory(database): |
1885 | + if not calls: |
1886 | + store = DontStartStore(database) |
1887 | + else: |
1888 | + store = DeferredStore(database) |
1889 | + calls.append(None) |
1890 | + return store |
1891 | + |
1892 | + self.pool.store_factory = vicious_store_factory |
1893 | + |
1894 | + def cb_get(store): |
1895 | + stores.add(store) |
1896 | + return self.pool.put(store) |
1897 | + |
1898 | + errors = [] |
1899 | + |
1900 | + def save_errors(failure): |
1901 | + errors.append(failure) |
1902 | + |
1903 | + for i in range(6): |
1904 | + ds.append( |
1905 | + self.pool.get().addCallback(cb_get).addErrback(save_errors)) |
1906 | + |
1907 | + def checkInstances(result): |
1908 | + self.assertEquals(len(stores), 5) |
1909 | + self.assertEquals(len(errors), 1) |
1910 | + errors[0].trap(RuntimeError) |
1911 | + |
1912 | + return gatherResults(ds).addCallback(checkInstances) |
1913 | + |
1914 | + |
1915 | + def test_arguments(self): |
1916 | + """ |
1917 | + Arguments are passed along to the transacted function when a store |
1918 | + is available. |
1919 | + """ |
1920 | + def tx(store, a, b=None): |
1921 | + self.assertIsInstance(store, DeferredStore) |
1922 | + self.assertEquals(1, a) |
1923 | + self.assertEquals(2, b) |
1924 | + return succeed("ok") |
1925 | + return self.pool.transact(tx, 1, b=2) |
1926 | + |
1927 | + |
1928 | + def test_commit(self): |
1929 | + """ |
1930 | + Changes made inside a successful transaction are committed. |
1931 | + """ |
1932 | + @deferredGenerator |
1933 | + def check(result): |
1934 | + d = self.pool.get() |
1935 | + wfd = waitForDeferred(d) |
1936 | + yield wfd |
1937 | + store = wfd.getResult() |
1938 | + d = store.execute("SELECT * FROM foo ORDER BY id", []) |
1939 | + wfd = waitForDeferred(d) |
1940 | + yield wfd |
1941 | + dr = wfd.getResult() |
1942 | + d = dr.get_all() |
1943 | + wfd = waitForDeferred(d) |
1944 | + yield wfd |
1945 | + results = wfd.getResult() |
1946 | + self.assertEquals(2, len(results)) |
1947 | + self.assertEquals(1, results[0][0]) |
1948 | + self.assertEquals("test", results[0][1]) |
1949 | + |
1950 | + def tx(store): |
1951 | + d = store.execute("INSERT INTO foo(id, title) " |
1952 | + "VALUES (1, 'test')", noresult=True) |
1953 | + return d |
1954 | + |
1955 | + return self.pool.transact(tx).addCallback(check) |
1956 | + |
1957 | + |
1958 | + def test_rollback(self): |
1959 | + """ |
1960 | + Changes made inside a failed transaction are rolled back. |
1961 | + """ |
1962 | + @deferredGenerator |
1963 | + def check(reason): |
1964 | + d = self.pool.get() |
1965 | + wfd = waitForDeferred(d) |
1966 | + yield wfd |
1967 | + store = wfd.getResult() |
1968 | + d = store.execute("SELECT * FROM foo", []) |
1969 | + wfd = waitForDeferred(d) |
1970 | + yield wfd |
1971 | + dr = wfd.getResult() |
1972 | + d = dr.get_all() |
1973 | + wfd = waitForDeferred(d) |
1974 | + yield wfd |
1975 | + results = wfd.getResult() |
1976 | + self.assertEquals(1, len(results)) |
1977 | + |
1978 | + @deferredGenerator |
1979 | + def tx(store): |
1980 | + d = store.execute("INSERT INTO foo(id, title) " |
1981 | + "VALUES (1, 'test')", []) |
1982 | + wfd = waitForDeferred(d) |
1983 | + yield wfd |
1984 | + wfd.getResult() |
1985 | + d = store.execute("INSERT INTO foo(id, title) " |
1986 | + "VALUES (1, 'test')", []) |
1987 | + wfd = waitForDeferred(d) |
1988 | + yield wfd |
1989 | + wfd.getResult() |
1990 | + |
1991 | + return self.pool.transact(tx).addBoth(check) |
1992 | + |
1993 | + |
1994 | + def test_return_value(self): |
1995 | + """ |
1996 | + The final result should match the return value of a successful call to |
1997 | + the transacted function. |
1998 | + """ |
1999 | + def check(result): |
2000 | + self.assertEquals("completed", result) |
2001 | + |
2002 | + def cb(result): |
2003 | + return "completed" |
2004 | + |
2005 | + def tx(store): |
2006 | + return store.execute("INSERT INTO foo(id, title) " |
2007 | + "VALUES (1, 'test')", [] |
2008 | + ).addCallback(cb) |
2009 | + |
2010 | + return self.pool.transact(tx).addCallback(check) |
2011 | + |
2012 | + |
2013 | + def test_poolsize_after_success(self): |
2014 | + """ |
2015 | + After a successful transaction, the pool size should be the same as before. |
2016 | + """ |
2017 | + size = len(self.pool._stores) |
2018 | + |
2019 | + def check(result): |
2020 | + self.assertEquals(size, len(self.pool._stores)) |
2021 | + |
2022 | + def tx(store): |
2023 | + d = store.execute("SELECT * from foo", []) |
2024 | + return d.addCallback(lambda result: result.get_all()) |
2025 | + |
2026 | + return self.pool.transact(tx).addCallback(check) |
2027 | + |
2028 | + |
2029 | + def test_poolsize_after_failure(self): |
2030 | + """ |
2031 | + After a failed transaction, the pool size is restored to its initial value. |
2032 | + """ |
2033 | + size = len(self.pool._stores) |
2034 | + |
2035 | + def check(reason): |
2036 | + self.assertEquals(size, len(self.pool._stores)) |
2037 | + |
2038 | + def tx(store): |
2039 | + return store.execute("SELECT * from not_a_table", []) |
2040 | + |
2041 | + d = self.assertFailure(self.pool.transact(tx), OperationalError) |
2042 | + return d.addCallback(check) |
2043 | + |
2044 | + |
2045 | + def test_failure_propagation(self): |
2046 | + """ |
2047 | + A custom exception is propagated by a C{transact} call. |
2048 | + """ |
2049 | + class MyException(Exception): |
2050 | + pass |
2051 | + |
2052 | + def tx(store): |
2053 | + raise MyException("Bad things happened") |
2054 | + |
2055 | + return self.assertFailure(self.pool.transact(tx), MyException) |
2056 | + |
2057 | + |
2058 | + def test_non_deferred_function(self): |
2059 | + """ |
2060 | + C{transact} can handle functions that don't return a C{Deferred}. |
2061 | + """ |
2062 | + def tx(store): |
2063 | + return "foo" |
2064 | + return self.pool.transact(tx).addCallback(self.assertEquals, "foo") |
2065 | + |
2066 | + |
2067 | + def test_rollback_failure(self): |
2068 | + """ |
2069 | + If C{rollback} fails, the store is put back into the pool. |
2070 | + """ |
2071 | + |
2072 | + def cb_get(store): |
2073 | + store.rollback = lambda: fail(RuntimeError("oops")) |
2074 | + d = self.assertFailure(self.pool.put(store), RuntimeError) |
2075 | + return d.addCallback(get_all) |
2076 | + |
2077 | + def get_all(ignore): |
2078 | + dl = [] |
2079 | + for i in range(5): |
2080 | + dl.append(self.pool.get()) |
2081 | + return gatherResults(dl).addCallback(check) |
2082 | + |
2083 | + def check(stores): |
2084 | + # The fact that we're here shows that the test succeeds, because we |
2085 | + # didn't hang waiting for a store |
2086 | + self.assertEquals(len(stores), 5) |
2087 | + |
2088 | + return self.pool.get().addCallback(cb_get) |
2089 | |
2090 | === added file 'tests/twisted/mysql.py' |
2091 | --- tests/twisted/mysql.py 1970-01-01 00:00:00 +0000 |
2092 | +++ tests/twisted/mysql.py 2009-12-27 11:31:15 +0000 |
2093 | @@ -0,0 +1,62 @@ |
2094 | +import os |
2095 | + |
2096 | +from storm.database import create_database |
2097 | + |
2098 | +from tests.twisted.base import DeferredStoreTest, StorePoolTest |
2099 | + |
2100 | +from twisted.trial.unittest import TestCase |
2101 | + |
2102 | + |
2103 | +class MySQLDeferredStoreTest(TestCase, DeferredStoreTest): |
2104 | + |
2105 | + def setUp(self): |
2106 | + return DeferredStoreTest.setUp(self) |
2107 | + |
2108 | + def tearDown(self): |
2109 | + return DeferredStoreTest.tearDown(self) |
2110 | + |
2111 | + def is_supported(self): |
2112 | + return bool(os.environ.get("STORM_MYSQL_URI")) |
2113 | + |
2114 | + def create_database(self): |
2115 | + self.database = create_database(os.environ["STORM_MYSQL_URI"]) |
2116 | + |
2117 | + def create_tables(self): |
2118 | + connection = self.connection |
2119 | + connection.execute("CREATE TABLE foo " |
2120 | + "(id INT PRIMARY KEY AUTO_INCREMENT," |
2121 | + " title VARCHAR(50) DEFAULT 'Default Title')" |
2122 | + " ENGINE=InnoDB") |
2123 | + connection.execute("CREATE TABLE bar " |
2124 | + "(id INT PRIMARY KEY AUTO_INCREMENT," |
2125 | + " foo_id INTEGER, title VARCHAR(50))" |
2126 | + " ENGINE=InnoDB") |
2127 | + connection.execute("CREATE TABLE egg " |
2128 | + "(id INT PRIMARY KEY AUTO_INCREMENT, value INTEGER)" |
2129 | + " ENGINE=InnoDB") |
2130 | + |
2131 | + |
2132 | +class MySQLStorePoolTest(TestCase, StorePoolTest): |
2133 | + |
2134 | + def setUp(self): |
2135 | + return StorePoolTest.setUp(self) |
2136 | + |
2137 | + def tearDown(self): |
2138 | + return StorePoolTest.tearDown(self) |
2139 | + |
2140 | + def is_supported(self): |
2141 | + return bool(os.environ.get("STORM_MYSQL_URI")) |
2142 | + |
2143 | + def create_database(self): |
2144 | + self.database = create_database(os.environ["STORM_MYSQL_URI"]) |
2145 | + |
2146 | + def create_tables(self): |
2147 | + connection = self.connection |
2148 | + connection.execute("CREATE TABLE foo " |
2149 | + "(id INT PRIMARY KEY AUTO_INCREMENT," |
2150 | + " title VARCHAR(50) DEFAULT 'Default Title')" |
2151 | + " ENGINE=InnoDB") |
2152 | + connection.execute("CREATE TABLE bar " |
2153 | + "(id INT PRIMARY KEY AUTO_INCREMENT," |
2154 | + " foo_id INTEGER, title VARCHAR(50))" |
2155 | + " ENGINE=InnoDB") |
2156 | |
2157 | === added file 'tests/twisted/postgres.py' |
2158 | --- tests/twisted/postgres.py 1970-01-01 00:00:00 +0000 |
2159 | +++ tests/twisted/postgres.py 2009-12-27 11:31:15 +0000 |
2160 | @@ -0,0 +1,57 @@ |
2161 | +import os |
2162 | + |
2163 | +from storm.database import create_database |
2164 | + |
2165 | +from tests.twisted.base import DeferredStoreTest, StorePoolTest |
2166 | + |
2167 | +from twisted.trial.unittest import TestCase |
2168 | + |
2169 | + |
2170 | +class PostgresDeferredStoreTest(TestCase, DeferredStoreTest): |
2171 | + |
2172 | + def setUp(self): |
2173 | + return DeferredStoreTest.setUp(self) |
2174 | + |
2175 | + def tearDown(self): |
2176 | + return DeferredStoreTest.tearDown(self) |
2177 | + |
2178 | + def is_supported(self): |
2179 | + return bool(os.environ.get("STORM_POSTGRES_URI")) |
2180 | + |
2181 | + def create_database(self): |
2182 | + self.database = create_database(os.environ["STORM_POSTGRES_URI"]) |
2183 | + |
2184 | + def create_tables(self): |
2185 | + connection = self.connection |
2186 | + connection.execute("CREATE TABLE foo " |
2187 | + "(id SERIAL PRIMARY KEY," |
2188 | + " title VARCHAR DEFAULT 'Default Title')") |
2189 | + connection.execute("CREATE TABLE bar " |
2190 | + "(id SERIAL PRIMARY KEY," |
2191 | + " foo_id INTEGER, title VARCHAR)") |
2192 | + connection.execute("CREATE TABLE egg " |
2193 | + "(id SERIAL PRIMARY KEY, value INTEGER)") |
2194 | + |
2195 | + |
2196 | +class PostgresStorePoolTest(TestCase, StorePoolTest): |
2197 | + |
2198 | + def setUp(self): |
2199 | + return StorePoolTest.setUp(self) |
2200 | + |
2201 | + def tearDown(self): |
2202 | + return StorePoolTest.tearDown(self) |
2203 | + |
2204 | + def is_supported(self): |
2205 | + return bool(os.environ.get("STORM_POSTGRES_URI")) |
2206 | + |
2207 | + def create_database(self): |
2208 | + self.database = create_database(os.environ["STORM_POSTGRES_URI"]) |
2209 | + |
2210 | + def create_tables(self): |
2211 | + connection = self.connection |
2212 | + connection.execute("CREATE TABLE foo " |
2213 | + "(id SERIAL PRIMARY KEY," |
2214 | + " title VARCHAR DEFAULT 'Default Title')") |
2215 | + connection.execute("CREATE TABLE bar " |
2216 | + "(id SERIAL PRIMARY KEY," |
2217 | + " foo_id INTEGER, title VARCHAR)") |
2218 | |
2219 | === added file 'tests/twisted/sqlite.py' |
2220 | --- tests/twisted/sqlite.py 1970-01-01 00:00:00 +0000 |
2221 | +++ tests/twisted/sqlite.py 2009-12-27 11:31:15 +0000 |
2222 | @@ -0,0 +1,65 @@ |
2223 | +from storm.databases.sqlite import SQLite |
2224 | +from storm.uri import URI |
2225 | + |
2226 | +from tests.twisted.base import DeferredStoreTest, StorePoolTest |
2227 | +from tests.helper import TestHelper, MakePath |
2228 | + |
2229 | +from twisted.trial.unittest import TestCase |
2230 | + |
2231 | + |
2232 | +class SQLiteDeferredStoreTest(TestCase, TestHelper, DeferredStoreTest): |
2233 | + |
2234 | + helpers = [MakePath] |
2235 | + |
2236 | + def setUp(self): |
2237 | + TestHelper.setUp(self) |
2238 | + return DeferredStoreTest.setUp(self) |
2239 | + |
2240 | + def tearDown(self): |
2241 | + def cb(passthrough): |
2242 | + TestHelper.tearDown(self) |
2243 | + return passthrough |
2244 | + return DeferredStoreTest.tearDown(self).addBoth(cb) |
2245 | + |
2246 | + def create_database(self): |
2247 | + self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % |
2248 | + self.make_path())) |
2249 | + |
2250 | + def create_tables(self): |
2251 | + connection = self.connection |
2252 | + connection.execute("CREATE TABLE foo " |
2253 | + "(id INTEGER PRIMARY KEY," |
2254 | + " title VARCHAR DEFAULT 'Default Title')") |
2255 | + connection.execute("CREATE TABLE bar " |
2256 | + "(id INTEGER PRIMARY KEY," |
2257 | + " foo_id INTEGER, title VARCHAR)") |
2258 | + connection.execute("CREATE TABLE egg " |
2259 | + "(id INTEGER PRIMARY KEY, value INTEGER)") |
2260 | + |
2261 | + |
2262 | +class SQLiteStorePoolTest(TestCase, TestHelper, StorePoolTest): |
2263 | + |
2264 | + helpers = [MakePath] |
2265 | + |
2266 | + def setUp(self): |
2267 | + TestHelper.setUp(self) |
2268 | + return StorePoolTest.setUp(self) |
2269 | + |
2270 | + def tearDown(self): |
2271 | + def cb(passthrough): |
2272 | + TestHelper.tearDown(self) |
2273 | + return passthrough |
2274 | + return StorePoolTest.tearDown(self).addBoth(cb) |
2275 | + |
2276 | + def create_database(self): |
2277 | + self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % |
2278 | + self.make_path())) |
2279 | + |
2280 | + def create_tables(self): |
2281 | + connection = self.connection |
2282 | + connection.execute("CREATE TABLE foo " |
2283 | + "(id INTEGER PRIMARY KEY," |
2284 | + " title VARCHAR DEFAULT 'Default Title')") |
2285 | + connection.execute("CREATE TABLE bar " |
2286 | + "(id INTEGER PRIMARY KEY," |
2287 | + " foo_id INTEGER, title VARCHAR)") |
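The pool behaviour that StorePoolTest pins down can be illustrated without Twisted. This is a hypothetical sketch: SimplePool and FakeStore are illustrative names, not the StorePool API from this branch, and plain synchronous callbacks stand in for Deferreds. It reserves a slot before starting a store (the approach described in revision 221), gives the slot back if the start fails, queues get() callers once max_size stores exist, and, as test_rollback_failure suggests, rolls a store back in put() while still returning it to the pool if that rollback raises.

```python
from collections import deque


class SimplePool(object):
    """Hypothetical bounded pool mirroring what StorePoolTest checks.

    Synchronous and callback-based; not the StorePool API from this branch.
    """

    def __init__(self, factory, max_size):
        self.factory = factory    # builds and starts a store; may raise
        self.max_size = max_size
        self.created = 0          # stores created or currently starting
        self.idle = []            # stores ready to be handed out
        self.waiters = deque()    # get() callbacks queued while the pool is full

    def get(self, callback):
        if self.idle:
            callback(self.idle.pop())
        elif self.created < self.max_size:
            # Reserve the slot *before* starting the store, so that other
            # get() calls made while it starts cannot exceed max_size.
            self.created += 1
            try:
                store = self.factory()
            except Exception:
                # A failed start gives its slot back (see test_start_failure).
                self.created -= 1
                raise
            callback(store)
        else:
            # Every store is checked out: wait for the next put().
            self.waiters.append(callback)

    def put(self, store):
        try:
            # Discard uncommitted work before the store is reused.
            store.rollback()
        finally:
            # Even if rollback() raised, hand the store back so that
            # queued get() callers never hang (see test_rollback_failure).
            if self.waiters:
                self.waiters.popleft()(store)
            else:
                self.idle.append(store)


class FakeStore(object):
    """Stand-in store that only records rollbacks."""

    def __init__(self):
        self.rollbacks = 0

    def rollback(self):
        self.rollbacks += 1


pool = SimplePool(FakeStore, max_size=2)
handed_out = []
for _ in range(3):
    pool.get(handed_out.append)  # the third call is queued...
pool.put(handed_out[0])          # ...and served by this put()
```

Counting the slot before the (in the real pool, asynchronous) start is what keeps test_no_overflow's ten get() calls from producing more than five stores; incrementing only once the start completes would let every pending call see spare capacity.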
[1] Why setDaemon(True)?
[2] Can you use underscores for DeferredStore?
[3] Add a docstring for get explaining that it loads all attributes.
[4] It may be useful to extend the core Store API so that the wrapper doesn't have to access internal attributes; for example, Store.disable_lazy_loading(). I dunno.
[5] A global mapping of Storm store to deferred store? Maybe?
[6] Make objects that came from a store which was then put back into the pool "broken", so that any later attribute access fails.
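One possible shape for item [6] is a proxy that refuses attribute access once the store it came from has been returned to the pool. This is a hypothetical sketch: Lease, GuardedProxy, StaleObjectError and Row are illustrative names, not Storm or Twisted API, and a real pool would have to call Lease.expire() from its put().

```python
class StaleObjectError(Exception):
    """Raised when an object is used after its store went back to the pool."""


class Lease(object):
    """Shared by everything handed out during one checkout of a store.

    Hypothetical: a pool would call expire() from its put() method.
    """

    def __init__(self):
        self.valid = True

    def expire(self):
        self.valid = False


class GuardedProxy(object):
    """Forwards attribute access to a target until its lease expires."""

    def __init__(self, target, lease):
        # Use object.__setattr__ to bypass our own __setattr__ below.
        object.__setattr__(self, "_target", target)
        object.__setattr__(self, "_lease", lease)

    def _check(self):
        if not object.__getattribute__(self, "_lease").valid:
            raise StaleObjectError("store was returned to the pool")

    def __getattr__(self, name):
        # Only invoked for names not found on the proxy itself.
        self._check()
        return getattr(object.__getattribute__(self, "_target"), name)

    def __setattr__(self, name, value):
        self._check()
        setattr(object.__getattribute__(self, "_target"), name, value)


class Row(object):
    """Stand-in for a Storm model instance such as Foo."""

    def __init__(self, title):
        self.title = title


lease = Lease()
row = GuardedProxy(Row(u"Great title"), lease)
row.title = u"New title"   # allowed while the store is checked out
lease.expire()             # the pool's put() would do this
# Any further row.title access now raises StaleObjectError.
```

A proxy keeps the check out of the model classes, at the cost that isinstance(row, Row) is False for the wrapped object.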