Merge lp:~therve/storm/twisted-integration into lp:storm

Proposed by Christopher Armstrong
Status: Work in progress
Proposed branch: lp:~therve/storm/twisted-integration
Merge into: lp:storm
Diff against target: 2287 lines (+2201/-1)
9 files modified
storm/database.py (+18/-1)
storm/exceptions.py (+4/-0)
storm/twisted/__init__.py (+23/-0)
storm/twisted/store.py (+480/-0)
storm/twisted/wrapper.py (+206/-0)
tests/twisted/base.py (+1286/-0)
tests/twisted/mysql.py (+62/-0)
tests/twisted/postgres.py (+57/-0)
tests/twisted/sqlite.py (+65/-0)
To merge this branch: bzr merge lp:~therve/storm/twisted-integration
Reviewer                            Review Type  Date Requested  Status
Christopher Armstrong (community)                                Abstain
Review via email: mp+3733@code.launchpad.net
Christopher Armstrong (radix) wrote :

[1] Why setDaemon(True)?

[2] Can you use underscores for DeferredStore?

[3] Doc for get, explaining that it loads all attributes.

[4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_lazy_loading(). I dunno.

[5] A global mapping of Storm store to deferred store? Maybe?

[6] Make it so that objects that came from stores which were then put back into a pool are "broken" so that any attribute access breaks in the future.

review: Abstain
Gabriel (gabriel-rossetti) wrote :

Christopher Armstrong wrote:
> Christopher Armstrong has proposed merging lp:~therve/storm/twisted-integration into lp:storm.
>
> Requested reviews:
> Storm Developers (storm)
>
I second that :-)

Gabriel

Christopher Armstrong (radix) wrote :

[7] Expose reload to DeferredStore

[8] Maybe add deferredStore.commit(reload=True)

review: Abstain
Drew Smathers (djfroofy) wrote :

> [4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_lazy_loading(). I dunno.

That sounds like a great idea to me. Otherwise, having to refetch objects just to avoid lazy value errors is less than ideal.

lp:~therve/storm/twisted-integration updated
215. By Thomas Herve

Merge from trunk.

216. By Thomas Herve

Move the tests to be able to run Postgres/MySQL tests as well

217. By Thomas Herve

Add postgres tests

218. By Thomas Herve

Add mysql, some tests failing.

219. By Thomas Herve

Use InnoDB to be able to use transactions.

220. By Thomas Herve

Merge from trunk.

221. By Thomas Herve

Force respect of the max_stores value, by incrementing the number of stores
once start_store is called, instead of doing it in the callback [f=373816]

222. By Thomas Herve

Use TestHelper/MakePath to manage sqlite file.

223. By Thomas Herve

Fix doc typos

224. By Thomas Herve

Merge from trunk

225. By Thomas Herve

Some tests cleanup

226. By Thomas Herve

Merge from trunk

227. By Thomas Herve

Cleanups

228. By Thomas Herve

Check thread in Connection to catch problems.

229. By Thomas Herve

Handle errors in rollback

230. By Thomas Herve

Typo, cleanup

231. By Thomas Herve

Make start/stop methods private
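Revision 221 moves the `_stores_created` increment out of the start-up callback and into `_start_store` itself. The stdlib-only model below sketches why that matters. `FakePool`, `simulate` and their parameters are hypothetical: they only imitate `StorePool`'s sizing and hand-off logic, and store start-up is simulated with a list of pending callbacks instead of a Twisted `Deferred`. If the count is bumped only after start-up completes, back-to-back `get()` calls all see spare capacity and the pool overshoots `max_stores`.

```python
class FakePool(object):
    """Hypothetical model of StorePool's sizing; not the branch's code."""

    def __init__(self, max_stores, count_early):
        self.max_stores = max_stores
        self.count_early = count_early  # True mimics revision 221
        self.created = 0      # stores counted against max_stores
        self.available = []   # started stores nobody is using
        self.starting = []    # get() callbacks whose store is still starting
        self.waiters = []     # get() callbacks waiting for a put()

    def get(self, callback):
        if self.available:
            callback(self.available.pop())
        elif self.created < self.max_stores:
            if self.count_early:
                # Count the store before its start-up completes.
                self.created += 1
            self.starting.append(callback)
        else:
            self.waiters.append(callback)

    def finish_startups(self):
        """Simulate every pending asynchronous start-up completing."""
        while self.starting:
            callback = self.starting.pop(0)
            if not self.count_early:
                # Pre-221 behaviour: count only once the store exists.
                self.created += 1
            callback(object())  # hand a fresh "store" to the caller

    def put(self, store):
        # Give the store to the oldest waiter first, else make it available.
        if self.waiters:
            self.waiters.pop(0)(store)
        else:
            self.available.append(store)


def simulate(count_early, max_stores=1, requests=3):
    """Issue several get() calls before any start-up has completed."""
    pool = FakePool(max_stores, count_early)
    received = []
    for _ in range(requests):
        pool.get(received.append)
    pool.finish_startups()
    return pool.created, len(received), len(pool.waiters)


print(simulate(count_early=False))  # (3, 3, 0): max_stores exceeded
print(simulate(count_early=True))   # (1, 1, 2): extra requests wait
```

With `count_early=False`, three requests against `max_stores=1` create three stores. With `count_early=True`, only one store is created, and the other two requests wait until a store is `put()` back. Waiters are served oldest first, as in `StorePool._cb_put`.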

Unmerged revisions: 222 to 231 (the same entries as listed above).

Preview Diff

=== modified file 'storm/database.py'
--- storm/database.py 2009-07-30 06:19:27 +0000
+++ storm/database.py 2009-12-27 11:31:14 +0000
@@ -24,12 +24,13 @@
 This is the common code for database support; specific databases are
 supported in modules in L{storm.databases}.
 """
+import threading
 
 from storm.expr import Expr, State, compile
 from storm.tracer import trace
 from storm.variables import Variable
 from storm.exceptions import (
-    ClosedError, DatabaseError, DisconnectionError, Error)
+    ClosedError, DatabaseError, DisconnectionError, Error, ThreadSafetyError)
 from storm.uri import URI
 import storm
 
@@ -180,6 +181,20 @@
         self._database = database # Ensures deallocation order.
         self._event = event
         self._raw_connection = self._database.raw_connect()
+        self._thread = threading.currentThread()
+
+    def _check_thread(self):
+        """
+        Check if the current thread is the same thread that created the
+        connection. This is a safety check that prevents breaking thread
+        boundaries which can create weird bugs. C{execute}, C{commit} and
+        C{rollback} should call it, directly or indirectly.
+        """
+        thread = threading.currentThread()
+        if thread is not self._thread:
+            raise ThreadSafetyError(
+                "'%s' is not the connection thread '%s'" %
+                (thread, self._thread))
 
     def __del__(self):
         """Close the connection."""
@@ -240,6 +255,7 @@
 
     def rollback(self):
         """Rollback the connection."""
+        self._check_thread()
         if self._state == STATE_CONNECTED:
             try:
                 self._raw_connection.rollback()
@@ -314,6 +330,7 @@
         If the connection is marked as dead, or if we can't reconnect,
         then raise DisconnectionError.
         """
+        self._check_thread()
         if self._state == STATE_CONNECTED:
             return
         elif self._state == STATE_DISCONNECTED:
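The `_check_thread()` guard added to `Connection` above can be tried in isolation. The sketch below is a standalone model, not Storm code. `Connection`, `execute()` and `call_from_other_thread()` are hypothetical, and `ThreadSafetyError` only mirrors the exception the patch adds to `storm/exceptions.py`. The model uses `threading.current_thread()`, the modern spelling of the `currentThread()` call in the patch. It shows that the connection works in the thread that created it and raises as soon as another thread touches it.

```python
import threading


class ThreadSafetyError(Exception):
    """Stand-in mirroring storm.exceptions.ThreadSafetyError."""


class Connection(object):
    """Hypothetical connection that remembers the thread that created it."""

    def __init__(self):
        self._thread = threading.current_thread()

    def _check_thread(self):
        # Same test as the patch: identity of the current thread object.
        thread = threading.current_thread()
        if thread is not self._thread:
            raise ThreadSafetyError(
                "%r is not the connection thread %r" % (thread, self._thread))

    def execute(self, statement):
        self._check_thread()
        return "executed: %s" % statement


def call_from_other_thread(connection, statement):
    """Run connection.execute() in a new thread; return its result or error."""
    outcome = []

    def worker():
        try:
            outcome.append(connection.execute(statement))
        except ThreadSafetyError as error:
            outcome.append(error)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    return outcome[0]


connection = Connection()
print(connection.execute("SELECT 1"))  # executed: SELECT 1
print(type(call_from_other_thread(connection, "SELECT 1")).__name__)
# ThreadSafetyError
```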
=== modified file 'storm/exceptions.py'
--- storm/exceptions.py 2009-05-15 08:43:21 +0000
+++ storm/exceptions.py 2009-12-27 11:31:14 +0000
@@ -82,6 +82,10 @@
     pass
 
 
+class ThreadSafetyError(StormError):
+    """Exception raised when cross-thread operations are attempted."""
+
+
 class Error(StormError):
     pass
 
=== added directory 'storm/twisted'
=== added file 'storm/twisted/__init__.py'
--- storm/twisted/__init__.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/__init__.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,23 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Asynchronous wrapper around storm to be used within a Twisted application.
+"""
=== added file 'storm/twisted/store.py'
--- storm/twisted/store.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/store.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,480 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Store wrapper and custom thread runner.
+"""
+
+from threading import Thread
+from Queue import Queue
+
+from storm.store import Store, AutoReload
+from storm.info import get_obj_info
+from storm.twisted.wrapper import partial, DeferredResult, DeferredResultSet
+
+from twisted.internet.defer import Deferred, deferredGenerator, maybeDeferred
+from twisted.internet.defer import waitForDeferred, succeed, fail
+from twisted.python.failure import Failure
+
+
+
+class AlreadyStopped(Exception):
+    """
+    Exception raised when a job is submitted to an already stopped thread.
+    """
+
+
+
+class StoreThread(Thread):
+    """
+    A thread class that wraps method calls and fires deferreds in the reactor
+    thread.
+    """
+    STOP = object()
+
+    def __init__(self):
+        """
+        Initialize the thread, and create a L{Queue} to queue jobs.
+        """
+        Thread.__init__(self)
+        self.setDaemon(True)
+        self._queue = Queue()
+        self._stop_deferred = None
+        self.stopped = False
+
+
+    def defer_to_thread(self, f, *args, **kwargs):
+        """
+        Run the given function in the thread, wrapping the result with a
+        L{Deferred}.
+
+        @return: a deferred whose result will be the result of the function.
+        @rtype: C{Deferred}
+        """
+        if self.stopped:
+            # Prevent pending calls once the thread has been stopped.
+            return fail(AlreadyStopped(f))
+        d = Deferred()
+        self._queue.put((d, f, args, kwargs))
+        return d
+
+
+    def run(self):
+        """
+        Main execution loop: retrieve jobs from the queue and run them.
+        """
+        from twisted.internet import reactor
+        o = self._queue.get()
+        while o is not self.STOP:
+            d, f, args, kwargs = o
+            try:
+                result = f(*args, **kwargs)
+            except:
+                failure = Failure()
+                reactor.callFromThread(d.errback, failure)
+            else:
+                reactor.callFromThread(d.callback, result)
+            o = self._queue.get()
+        reactor.callFromThread(self._stop_deferred.callback, None)
+
+
+    def stop(self):
+        """
+        Stop the thread.
+        """
+        if self.stopped:
+            return self._stop_deferred
+        self._stop_deferred = Deferred()
+        self._queue.put(self.STOP)
+        self.stopped = True
+        return self._stop_deferred
+
+
+
+class DeferredStore(object):
+    """
+    A wrapper around L{Store} providing asynchronous operations.
+    """
+    store = None
+
+    def __init__(self, database):
+        """
+        @param database: instance of database providing connection, used to
+            instantiate the store later.
+        @type database: L{storm.database.Database}
+        """
+        self.thread = StoreThread()
+        self.database = database
+        self.started = False
+
+
+    def start(self):
+        """
+        Start the store.
+
+        @return: a deferred that will fire once the store is started.
+        """
+        if not self.started:
+            self.started = True
+            self.thread.start()
+            # Add an event trigger to be sure that the thread is stopped
+            from twisted.internet import reactor
+            reactor.addSystemEventTrigger(
+                "before", "shutdown", self.stop)
+            return self.thread.defer_to_thread(Store, self.database
+                ).addCallback(self._got_store)
+        else:
+            raise RuntimeError("Already started")
+
+
+    def _got_store(self, store):
+        """
+        Internal method called when the store is created, initializing most of
+        the API methods.
+        """
+        self.store = store
+        # Maybe not ?
+        self.store._deferredStore = self
+        for methodName in ("commit", "flush", "remove", "reload",
+                           "rollback"):
+            method = partial(self.thread.defer_to_thread,
+                             getattr(self.store, methodName))
+            setattr(self, methodName, method)
+
+        self._do_resolve_lazy_value = self.store._resolve_lazy_value
+        self.store._resolve_lazy_value = self._resolve_lazy_value
+
+
+    def get(self, cls, key):
+        def _get():
+            obj = self.store.get(cls, key)
+            if obj is not None:
+                obj_info = get_obj_info(obj)
+                self._do_resolve_lazy_value(obj_info, None, AutoReload)
+            return obj
+        return self.thread.defer_to_thread(_get)
+
+
+    def add(self, obj):
+        """
+        Specific add method that returns no result, so that callers are not
+        tempted to treat the return value as something usable.
+        """
+        def _add():
+            self.store.add(obj)
+        return self.thread.defer_to_thread(_add)
+
+
+    def execute(self, *args, **kwargs):
+        """
+        Wrapper around C{execute} returning a C{DeferredResult} instead of the
+        standard L{storm.database.Result} object.
+        """
+        if self.store is None:
+            raise RuntimeError("Store not started")
+        return self.thread.defer_to_thread(
+            self.store.execute, *args, **kwargs
+            ).addCallback(self._cb_execute)
+
+
+    def _cb_execute(self, result):
+        """
+        Wrap the result with a C{DeferredResult}.
+        """
+        if result is not None:
+            return DeferredResult(self.thread, result)
+
+
+    def find(self, *args, **kwargs):
+        """
+        Wrapper around C{find}.
+        """
+        if self.store is None:
+            raise RuntimeError("Store not started")
+        return self.thread.defer_to_thread(
+            self.store.find, *args, **kwargs
+            ).addCallback(self._cb_find)
+
+
+    def _cb_find(self, resultSet):
+        """
+        Wrap the result set with a C{DeferredResultSet}.
+        """
+        return DeferredResultSet(self.thread, resultSet)
+
+
+    def stop(self):
+        """
+        Stop the store.
+        """
+        if self.thread.stopped:
+            return succeed(None)
+        def close():
+            self.store.rollback()
+            self.store.close()
+        return self.thread.defer_to_thread(close
+            ).addCallback(lambda ign: self.thread.stop())
+
+
+    def _resolve_lazy_value(self, *args):
+        raise RuntimeError(
+            "Resolving lazy values with the Twisted wrapper is not possible "
+            "right now! Please refetch your object using "
+            "store.get/store.find")
+
+
+    @staticmethod
+    def of(obj):
+        """
+        Get the DeferredStore the given object is associated with.
+
+        If the given object has not been associated with a DeferredStore,
+        return None.
+        """
+        store = Store.of(obj)
+        if not store:
+            return
+        return getattr(store, '_deferredStore', None)
+
+
+
+class StorePool(object):
+    """
+    A pool of started stores, maintaining persistent connections.
+    """
+    started = False
+    store_factory = DeferredStore
+
+    def __init__(self, database, min_stores=0, max_stores=10):
+        """
+        @param database: instance of database providing connection, used to
+            instantiate the store later.
+        @type database: L{storm.database.Database}
+
+        @param min_stores: initial number of stores.
+        @type min_stores: C{int}
+
+        @param max_stores: maximum number of stores.
+        @type max_stores: C{int}
+        """
+        self.database = database
+        self.min_stores = min_stores
+        self.max_stores = max_stores
+        self._stores = []
+        self._stores_created = 0
+        self._pending_get = []
+        self._store_refs = []
+
+
+    def start(self):
+        """
+        Start the pool.
+        """
+        if self.started:
+            raise RuntimeError("Already started")
+        self.started = True
+        return self.adjust_size()
+
+
+    def stop(self):
+        """
+        Stop the pool: this is not a total stop, it just tries to kill the
+        current stores.
+        """
+        return self.adjust_size(0, 0, self._store_refs)
+
+
+    def _start_store(self):
+        """
+        Create a new store.
+        """
+        store = self.store_factory(self.database)
+        # Increment here, so that other simultaneous calls don't make the
+        # number of connections exceed the maximum
+        self._stores_created += 1
+        return store.start(
+            ).addCallback(self._cb_start_store, store
+            ).addErrback(self._eb_start_store)
+
+
+    def _cb_start_store(self, ign, store):
+        """
+        Add the created store to the list of available stores.
+        """
+        self._stores.append(store)
+        self._store_refs.append(store)
+
+
+    def _eb_start_store(self, failure):
+        """
+        Decrement the count of created stores, and let the failure propagate.
+        """
+        self._stores_created -= 1
+        return failure
+
+
+    def _stop_store(self, stores=None):
+        """
+        Stop a store and remove it from the available stores.
+        """
+        if stores is None:
+            stores = self._stores
+        self._stores_created -= 1
+        store = stores.pop()
+        return store.stop()
+
+
+    @deferredGenerator
+    def adjust_size(self, min_stores=None, max_stores=None, stores=None):
+        """
+        Change the number of available stores, shrinking or growing as
+        necessary.
+        """
+        if min_stores is None:
+            min_stores = self.min_stores
+        if max_stores is None:
+            max_stores = self.max_stores
+        if stores is None:
+            stores = self._stores
+
+        if min_stores < 0:
+            raise ValueError('minimum is negative')
+        if min_stores > max_stores:
+            raise ValueError('minimum is greater than maximum')
+
+        self.min_stores = min_stores
+        self.max_stores = max_stores
+        if not self.started:
+            return
+
+        # Kill off some stores if we have too many.
+        while self._stores_created > self.max_stores and stores:
+            wfd = waitForDeferred(self._stop_store(stores))
+            yield wfd
+            wfd.getResult()
+        # Start some stores if we have too few.
+        while self._stores_created < self.min_stores:
+            wfd = waitForDeferred(self._start_store())
+            yield wfd
+            wfd.getResult()
+
+
+    def get(self):
+        """
+        Return a started store from the pool, or start a new one if necessary.
+        A store retrieved this way must be put back using the C{put}
+        method, or it won't be reused.
+        """
+        if not self.started:
+            raise RuntimeError("Not started")
+        if self._stores:
+            store = self._stores.pop()
+            return succeed(store)
+        elif self._stores_created < self.max_stores:
+            return self._start_store().addCallback(self._cb_get)
+        else:
+            # Maybe all stores are consumed?
+            return self.adjust_size().addCallback(self._cb_get)
+
+
+    def _cb_get(self, ign):
+        """
+        If the previous operation added a store, return it, or return a pending
+        C{Deferred}.
+        """
+        if self._stores:
+            store = self._stores.pop()
+            return store
+        else:
+            # All stores are in use, wait
+            d = Deferred()
+            self._pending_get.append(d)
+            return d
+
+
+    def put(self, store):
+        """
+        Make a store available again.
+
+        This must be done explicitly to get the store back into the pool.
+        The recommended way to use the pool is:
+
+        >>> d1 = pool.get()
+
+        >>> # once d1 has fired with a store:
+        >>> d2 = store.add(foo)
+        >>> d2.addCallback(doSomething).addErrback(manageErrors)
+        >>> d2.addBoth(lambda x: pool.put(store))
+        """
+        return store.rollback().addBoth(self._cb_put, store)
+
+
+    def _cb_put(self, passthrough, store):
+        """
+        Once the rollback has finished, the store is really available.
+        """
+        if self._pending_get:
+            # Callers are waiting: fire the oldest one with the store
+            d = self._pending_get.pop(0)
+            d.callback(store)
+        else:
+            self._stores.append(store)
+        return passthrough
+
+
+    def transact(self, f, *args, **kwargs):
+        """
+        Call function C{f} with a L{DeferredStore} and arguments C{args} and
+        C{kwargs}, in a transaction bound to the acquired store. If the
+        transaction succeeds, the store is committed. The store is returned to
+        this pool after the call to C{f} completes.
+
+        C{f} should return a L{Deferred} when it does asynchronous work.
+
+        @param f: function to call in transaction
+        @param args: positional arguments to function C{f}
+        @param kwargs: keyword arguments to function C{f}
+        """
+        return self.get(
+            ).addCallback(self._cb_transact_start, f, args, kwargs)
+
+
+    def _cb_transact_start(self, store, f, args, kwargs):
+        """
+        Call transacted function with acquired store.
+        """
+        result = maybeDeferred(f, store, *args, **kwargs)
+        result.addCallback(self._cb_transact_success, store)
+        result.addBoth(self._cb_transact_stop, store)
+        return result
+
+
+    def _cb_transact_success(self, result, store):
+        """
+        Commit and pass through function result.
+        """
+        return store.commit().addCallback(lambda ignore: result)
+
+
+    def _cb_transact_stop(self, result, store):
+        """
+        Return the store back to the pool and pass through the result again.
+        """
+        return self.put(store).addCallback(lambda ignore: result)
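`StoreThread` funnels every Storm call for one store through a single dedicated thread. Because of this, the connection that the store opens is always used from the thread that opened it, which is exactly what the new `_check_thread()` guard enforces. The following is a Python 3, stdlib-only sketch of that pattern, not the branch's code. `WorkerThread` is a hypothetical name, and `concurrent.futures.Future` stands in for Twisted's `Deferred`. Results are also set directly from the worker thread instead of being handed back with `reactor.callFromThread`.

```python
import threading
from concurrent.futures import Future
from queue import Queue


class AlreadyStopped(Exception):
    """Raised for jobs submitted after stop(), as in the branch."""


class WorkerThread(threading.Thread):
    """Sketch of the StoreThread pattern: one thread runs every job in order."""

    _SENTINEL = object()  # tells run() to exit, like StoreThread.STOP

    def __init__(self):
        threading.Thread.__init__(self)
        self.daemon = True  # same effect as setDaemon(True)
        self._jobs = Queue()
        self._stop_future = None
        self.stopped = False

    def defer_to_thread(self, f, *args, **kwargs):
        future = Future()
        if self.stopped:
            # Refuse new work once stop() has been requested.
            future.set_exception(AlreadyStopped(f))
        else:
            self._jobs.put((future, f, args, kwargs))
        return future

    def run(self):
        job = self._jobs.get()
        while job is not self._SENTINEL:
            future, f, args, kwargs = job
            try:
                result = f(*args, **kwargs)
            except Exception as error:
                future.set_exception(error)
            else:
                future.set_result(result)
            job = self._jobs.get()
        self._stop_future.set_result(None)

    def stop(self):
        if not self.stopped:
            self._stop_future = Future()
            self.stopped = True
            self._jobs.put(self._SENTINEL)
        return self._stop_future


worker = WorkerThread()
worker.start()
jobs = [worker.defer_to_thread(pow, 2, n) for n in range(5)]
print([job.result(timeout=5) for job in jobs])  # [1, 2, 4, 8, 16]
failed = worker.defer_to_thread(int, "not a number")
print(type(failed.exception(timeout=5)).__name__)  # ValueError
worker.stop().result(timeout=5)
late = worker.defer_to_thread(int, "1")
print(type(late.exception()).__name__)  # AlreadyStopped
```

Jobs run in submission order on one thread. Once `stop()` has been called, further submissions fail with `AlreadyStopped` rather than queueing behind the sentinel. This mirrors the guard at the top of `StoreThread.defer_to_thread`.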
=== added file 'storm/twisted/wrapper.py'
--- storm/twisted/wrapper.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/wrapper.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,206 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Asynchronous wrapper around storm.
+"""
+
+from storm.store import Store
+from storm.references import Reference, ReferenceSet
+
+
+try:
+    from functools import partial
+except ImportError:
+    # For Python < 2.5
+    class partial(object):
+        def __init__(self, fn, *args, **kw):
+            self.fn = fn
+            self.args = args
+            self.kw = kw
+
+        def __call__(self, *args, **kw):
+            if kw and self.kw:
+                d = self.kw.copy()
+                d.update(kw)
+            else:
+                d = kw or self.kw
+            return self.fn(*(self.args + args), **d)
+
+
+
+class DeferredResult(object):
+    """
+    Proxy for a storm result, running the blocking methods in a thread and
+    returning C{Deferred}s.
+    """
+
+    def __init__(self, thread, result):
+        """
+        @param thread: the running thread of the store
+        @type thread: C{StoreThread}
+
+        @param result: the result instance to be wrapped.
+        @type result: C{storm.database.Result}
+        """
+        self.result = result
+        for methodName in ("get_one", "get_all"):
+            method = partial(thread.defer_to_thread,
+                             getattr(result, methodName))
+            setattr(self, methodName, method)
+
+
+
+class DeferredResultSet(object):
+    """
+    Wrapper for a L{storm.store.ResultSet}.
+    """
+
+    def __init__(self, thread, resultSet):
+        """
+        Wrap C{resultSet}, running its blocking methods in C{thread}.
+        """
+        self._thread = thread
+        self._resultSet = resultSet
+        for methodName in ("any", "one", "first", "last", "remove", "count",
+                           "max", "min", "avg", "sum", "set", "is_empty"):
+            method = partial(thread.defer_to_thread,
+                             getattr(resultSet, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("union", "difference", "intersection"):
+            method = partial(self._set_expr,
+                             getattr(resultSet, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("order_by", "config", "group_by", "having"):
+            setattr(self, methodName, getattr(self._resultSet, methodName))
+
+
+    def all(self):
+        """
+        Specific method to emulate C{__iter__}.
+        """
+        return self._thread.defer_to_thread(list, self._resultSet)
+
+
+    def values(self, *columns):
+        """
+        Wrapper around C{values} that consumes the iterator and returns a list
+        instead.
+        """
+        def _get_values():
+            return list(self._resultSet.values(*columns))
+        return self._thread.defer_to_thread(_get_values)
+
+
+    def _set_expr(self, method, other, all=False):
+        """
+        Wrap a set expression with a C{DeferredResultSet}.
+        """
+        return DeferredResultSet(self._thread, method(other, all))
+
+
+
+class DeferredReference(Reference):
+    """
+    A reference property whose value is returned within a C{Deferred}.
+    """
+
+    def __get__(self, local, cls=None):
+        """
+        Wrapper around C{Reference.__get__}.
+        """
+        store = Store.of(local)
+        if store is None:
+            return None
+        _thread = store._deferredStore.thread
+        return _thread.defer_to_thread(Reference.__get__, self, local, cls)
+
+
+    def __set__(self, local, remote):
+        """
+        Setting a C{DeferredReference} is not supported.
+        """
+        raise RuntimeError("Can't set a DeferredReference")
+
+
+
+class DeferredReferenceSet(ReferenceSet):
+    """
+    A C{ReferenceSet} whose blocking operations return C{Deferred}s.
+    """
+
+    def __get__(self, local, cls=None):
+        """
+        Wrapper around C{ReferenceSet.__get__}.
+        """
+        store = Store.of(local)
+        if store is None:
+            return None
+        _thread = store._deferredStore.thread
+        boundReference = ReferenceSet.__get__(self, local, cls)
+        return DeferredBoundReference(_thread, boundReference)
+
+
+
+class DeferredBoundReference(object):
+    """
+    Wrapper around C{BoundReferenceSet} and C{BoundIndirectReferenceSet}.
+    """
+
+    def __init__(self, thread, boundReference):
+        """
+        Create the reference with given C{StoreThread} and the reference to
+        wrap.
+        """
+        self._thread = thread
+        self._boundReference = boundReference
+        for methodName in ("clear", "add", "remove", "any", "count", "one",
+                           "first", "last"):
+            method = partial(thread.defer_to_thread,
+                             getattr(boundReference, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("order_by", "find"):
+            method = partial(self._defer_and_wrap_result,
+                             getattr(boundReference, methodName))
+            setattr(self, methodName, method)
+
+
+    def all(self):
+        """
+        Specific method to emulate C{__iter__}.
+        """
+        return self._thread.defer_to_thread(list, self._boundReference)
+
+
+    def values(self, *columns):
+        """
+        Emulate the values method.
+        """
+        def _get_values():
+            return list(self._boundReference.values(*columns))
+        return self._thread.defer_to_thread(_get_values)
+
+
+    def _defer_and_wrap_result(self, method, *args, **kwargs):
+        """
+        Helper for methods returning another C{ResultSet}.
+        """
+        return self._thread.defer_to_thread(method, *args, **kwargs
+            ).addCallback(lambda x: DeferredResultSet(self._thread, x))
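`wrapper.py` builds its proxies by binding each blocking method to `defer_to_thread` with `partial`, and it ships a fallback `partial` for Python older than 2.5. The sketch below checks that the fallback merges keyword arguments the way `functools.partial` does, with call-time keywords taking precedence. It also mirrors how `DeferredResult` replaces `get_one`/`get_all` with partials. `Result`, `run_inline` and `ResultProxy` are hypothetical stand-ins, and `run_inline` runs the call immediately instead of queuing it on a `StoreThread`.

```python
import functools


class compat_partial(object):
    """Copy of the pre-2.5 fallback from storm/twisted/wrapper.py."""

    def __init__(self, fn, *args, **kw):
        self.fn = fn
        self.args = args
        self.kw = kw

    def __call__(self, *args, **kw):
        if kw and self.kw:
            # Call-time keywords override the stored ones.
            d = self.kw.copy()
            d.update(kw)
        else:
            d = kw or self.kw
        return self.fn(*(self.args + args), **d)


def describe(*args, **kw):
    """Echo the arguments a call received, with keywords in sorted order."""
    return args, sorted(kw.items())


class Result(object):
    """Hypothetical blocking result object."""

    def get_one(self):
        return ("row",)

    def get_all(self):
        return [("row",), ("row2",)]


def run_inline(f, *args, **kwargs):
    """Hypothetical stand-in for StoreThread.defer_to_thread: runs f now and
    tags the value so the wrapping is visible."""
    return ("deferred", f(*args, **kwargs))


class ResultProxy(object):
    """Mirrors DeferredResult: each listed method is replaced by a partial."""

    def __init__(self, defer, result):
        self.result = result
        for name in ("get_one", "get_all"):
            setattr(self, name, functools.partial(defer, getattr(result, name)))


print(compat_partial(describe, 1, a=1)(2, a=3, b=4))
# ((1, 2), [('a', 3), ('b', 4)])
print(ResultProxy(run_inline, Result()).get_one())
# ('deferred', ('row',))
```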
=== added directory 'tests/twisted'
=== added file 'tests/twisted/__init__.py'
=== added file 'tests/twisted/base.py'
--- tests/twisted/base.py 1970-01-01 00:00:00 +0000
+++ tests/twisted/base.py 2009-12-27 11:31:15 +0000
@@ -0,0 +1,1286 @@
1#
2# Copyright (c) 2007 Canonical
3# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
4#
5# This file is part of Storm Object Relational Mapper.
6#
7# Storm is free software; you can redistribute it and/or modify
8# it under the terms of the GNU Lesser General Public License as
9# published by the Free Software Foundation; either version 2.1 of
10# the License, or (at your option) any later version.
11#
12# Storm is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU Lesser General Public License for more details.
16#
17# You should have received a copy of the GNU Lesser General Public License
18# along with this program. If not, see <http://www.gnu.org/licenses/>.
19#
20
21"""
22Test for twistorm.
23"""
24
25from storm.properties import Int, Unicode
26from storm.expr import Count
27from storm.references import Reference
28from storm.exceptions import OperationalError, ThreadSafetyError
29
30from storm.twisted.store import (
31 DeferredStore, StoreThread, StorePool, AlreadyStopped )
32from storm.twisted.wrapper import DeferredReference, DeferredReferenceSet
33
34from twisted.trial.unittest import TestCase
35from twisted.internet.defer import gatherResults, deferredGenerator
36from twisted.internet.defer import waitForDeferred, DeferredList
37from twisted.internet.defer import succeed, fail
38
39
40class Foo(object):
41 """
42 Test table.
43 """
44 __storm_table__ = "foo"
45 id = Int(primary=True)
46 title = Unicode()
47
48
49
50class Bar(object):
51 """
52 Test table referencing to C{Foo}
53 """
54 __storm_table__ = "bar"
55 id = Int(primary=True)
56 title = Unicode()
57 foo_id = Int()
58 foo = DeferredReference(foo_id, Foo.id)
59
60
61
62class FooRefSet(Foo):
63 """
64 A C{Foo} class with a C{DeferredReferenceSet} to get all the bars related.
65 """
66 bars = DeferredReferenceSet(Foo.id, Bar.foo_id)
67
68
69
70class FooRefSetOrderID(Foo):
71 """
72 A C{Foo} class with an order C{DeferredReferenceSet} to C{Bar}.
73 """
74 bars = DeferredReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id)
75
76
77
78class Egg(object):
79 """
80 Test table.
81 """
82 __storm_table__ = "egg"
83 id = Int(primary=True)
84 value = Int()
85
86
87
88class DeferredStoreTest(object):
89 """
90 Tests for L{DeferredStore}.
91 """
92
93 def setUp(self):
94 """
95 Create a test sqlite database, and insert some data.
96 """
97 self.create_database()
98 connection = self.connection = self.database.connect()
99 self.drop_tables()
100 self.create_tables()
101 connection.execute("INSERT INTO foo VALUES (10, 'Title 30')")
102 connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 50')")
103 connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 40')")
104 connection.execute("INSERT INTO egg VALUES (1, 4)")
105 connection.execute("INSERT INTO egg VALUES (2, 3)")
106 connection.execute("INSERT INTO egg VALUES (3, 7)")
107 connection.execute("INSERT INTO egg VALUES (4, 5)")
108 connection.commit()
109 self.store = DeferredStore(self.database)
110 return self.store.start()
111
112
113 def tearDown(self):
114 """
115 Kill the store (and its underlying thread).
116 """
117 def _stop(ign):
118 return self.store.stop().addCallback(_drop)
119 def _drop(ign):
120 self.drop_tables()
121 return self.store.rollback().addCallback(_stop)
122
123
124 def create_database(self):
125 raise NotImplementedError()
126
127
128 def create_tables(self):
129 raise NotImplementedError()
130
131
132 def drop_tables(self):
133 for table in ["foo", "bar", "egg"]:
134 try:
135 self.connection.execute("DROP TABLE %s" % table)
136 self.connection.commit()
137 except:
138 self.connection.rollback()
139
140
141 def test_multiple_start(self):
142 """
143 Check that start raises an exception when the store is already started.
144 """
145 self.assertRaises(RuntimeError, self.store.start)
146
147
148 def test_get(self):
149 """
150 Try to get an object from the store and check its attributes.
151 """
152 def cb(result):
153 self.assertEquals(result.title, u"Title 30")
154 self.assertEquals(result.id, 10)
155 return self.store.get(Foo, 10).addCallback(cb)
156
157
158 def test_add(self):
159 """
160 Add an object to the store.
161 """
162 foo = Foo()
163 foo.title = u"Great title"
164 foo.id = 11
165 def cb_add(ign):
166 return self.store.get(Foo, 11).addCallback(cb_get)
167 def cb_get(result):
168 self.assertEquals(result.title, u"Great title")
169 self.assertEquals(result.id, 11)
170 return self.store.add(foo).addCallback(cb_add)
171
172
173 def test_add_default_value(self):
174 """
175 When adding an object to the store, the default values from the
176 database are retrieved and put into the object.
177 """
178 foo = Foo()
179 foo.id = 11
180 def cb_add(result):
181 self.assertIdentical(result, None)
182 return self.store.get(Foo, 11).addCallback(cb_get)
183 def cb_get(result):
184 self.assertEquals(result.title, u"Default Title")
185 self.assertEquals(result.id, 11)
186 return self.store.add(foo).addCallback(cb_add)
187
188
189 def test_execute(self):
190 """
191 Test a direct execute on the store, and the C{get_one} method of
192 C{DeferredResult}.
193 """
194 def cb_execute(result):
195 return result.get_one().addCallback(cb_result)
196 def cb_result(result):
197 self.assertEquals(result, (u"Title 30",))
198 return self.store.execute("SELECT title FROM foo WHERE id=10"
199 ).addCallback(cb_execute)
200
201
202 def test_execute_all(self):
203 """
204 Test a direct execute on the store, and the C{get_all} method of
205 C{DeferredResult}.
206 """
207 def cb_execute(result):
208 return result.get_all().addCallback(cb_result)
209 def cb_result(result):
210 self.assertEquals(result, [(u"Title 50",), (u"Title 40",)])
211 return self.store.execute("SELECT title FROM bar ORDER BY id"
212 ).addCallback(cb_execute)
213
214
215 def test_remove(self):
216 """
217 Try removing an object from the database.
218 """
219 def cb_get(result):
220 return self.store.remove(result).addCallback(cb_remove)
221 def cb_remove(ign):
222 return self.store.get(Foo, 10).addCallback(cb_get_after_remove)
223 def cb_get_after_remove(result):
224 self.assertIdentical(result, None)
225 return self.store.get(Foo, 10).addCallback(cb_get)
226
227
228 def test_find(self):
229 """
230 Try to find a list of objects using the store.
231 """
232 def cb_find(results):
233 return results.all().addCallback(cb_all)
234 def cb_all(results):
235 self.assertEquals(len(results), 2)
236 titles = [results[0].title, results[1].title]
237 titles.sort()
238 self.assertEquals(titles, [u"Title 40", u"Title 50"])
239 return self.store.find(Bar).addCallback(cb_find)
240
241
242 def test_find_first(self):
243 """
244 Try to get the first object matching a query.
245 """
246 def cb_find(results):
247 results.order_by(Bar.title)
248 return results.first().addCallback(cb_all)
249 def cb_all(result):
250 self.assertEquals(result.title, u"Title 40")
251 self.assertEquals(result.id, 11)
252 return self.store.find(Bar).addCallback(cb_find)
253
254
255 def test_find_last(self):
256 """
257 Try to get the last object matching a query.
258 """
259 def cb_find(results):
260 results.order_by(Bar.title)
261 return results.last().addCallback(cb_all)
262 def cb_all(result):
263 self.assertEquals(result.title, u"Title 50")
264 self.assertEquals(result.id, 10)
265 return self.store.find(Bar).addCallback(cb_find)
266
267
268 def test_find_any(self):
269 """
270 Try to get an object matching a query using the C{any} method.
271 """
272 def cb_find(results):
273 return results.any().addCallback(cb_all)
274 def cb_all(result):
275 self.assertEquals(result.title, u"Title 50")
276 self.assertEquals(result.id, 10)
277 return self.store.find(Bar).addCallback(cb_find)
278
279
280 def test_find_max(self):
281 """
282 Try to get the maximum of a value after a find.
283 """
284 def cb_find(results):
285 return results.max(Egg.value).addCallback(cb_all)
286 def cb_all(result):
287 self.assertEquals(result, 7)
288 return self.store.find(Egg).addCallback(cb_find)
289
290
291 def test_find_min(self):
292 """
293 Try to get the minimum of a value after a find.
294 """
295 def cb_find(results):
296 return results.min(Egg.value).addCallback(cb_all)
297 def cb_all(result):
298 self.assertEquals(result, 3)
299 return self.store.find(Egg).addCallback(cb_find)
300
301
302 def test_find_avg(self):
303 """
304 Try to get the average of a value after a find.
305 """
306 def cb_find(results):
307 return results.avg(Egg.value).addCallback(cb_all)
308 def cb_all(result):
309 self.assertEquals(result, 4.75)
310 return self.store.find(Egg).addCallback(cb_find)
311
312
313 def test_find_sum(self):
314 """
315 Try to get the sum of a value after a find.
316 """
317 def cb_find(results):
318 return results.sum(Egg.value).addCallback(cb_all)
319 def cb_all(result):
320 self.assertEquals(result, 19)
321 return self.store.find(Egg).addCallback(cb_find)
322
323
324 def test_find_count(self):
325 """
326 Try to get the count of a result after a find.
327 """
328 def cb_find(results):
329 return results.count().addCallback(cb_all)
330 def cb_all(result):
331 self.assertEquals(result, 2)
332 return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find)
333
334
335 def test_find_remove(self):
336 """
337 Remove the result of a find query.
338 """
339 def cb_find(results):
340 return results.remove().addCallback(cb_remove)
341 def cb_remove(ignore):
342 return self.store.find(Egg).addCallback(cb_find_after_remove)
343 def cb_find_after_remove(results):
344 return results.all().addCallback(cb_all)
345 def cb_all(results):
346 self.assertEquals(len(results), 2)
347 return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find)
348
349
350 def test_find_limit(self):
351 """
352 Put a limit on the number of results of a find.
353 """
354 def cb_find(results):
355 results.config(limit=3)
356 return results.all().addCallback(cb_all)
357 def cb_all(results):
358 self.assertEquals(len(results), 3)
359 return self.store.find(Egg).addCallback(cb_find)
360
361
362 def test_find_union(self):
363 """
364 Call C{union} on two different C{DeferredResultSet}s.
365 """
366 def cb_find(results):
367 result1, result2 = results
368 results = result1.union(result2._resultSet)
369 return results.all().addCallback(cb_all)
370 def cb_all(results):
371 self.assertEquals(len(results), 3)
372 d1 = self.store.find(Egg, Egg.value >= 5)
373 d2 = self.store.find(Egg, Egg.value == 3)
374 return gatherResults([d1, d2]).addCallback(cb_find)
375
376
377 def test_find_difference(self):
378 """
379 Call C{difference} on two different C{DeferredResultSet}s.
380 """
381 if self.__class__.__name__.startswith("MySQL"):
382 return
383 def cb_find(results):
384 result1, result2 = results
385 results = result1.difference(result2._resultSet)
386 return results.all().addCallback(cb_all)
387 def cb_all(results):
388 self.assertEquals(len(results), 1)
389 self.assertEquals(results[0].value, 5)
390 d1 = self.store.find(Egg, Egg.value >= 5)
391 d2 = self.store.find(Egg, Egg.value == 7)
392 return gatherResults([d1, d2]).addCallback(cb_find)
393
394
395 def test_find_intersection(self):
396 """
397 Call C{intersection} on two different C{DeferredResultSet}s.
398 """
399 if self.__class__.__name__.startswith("MySQL"):
400 return
401 def cb_find(results):
402 result1, result2 = results
403 results = result1.intersection(result2._resultSet)
404 return results.all().addCallback(cb_all)
405 def cb_all(results):
406 self.assertEquals(len(results), 1)
407 self.assertEquals(results[0].value, 7)
408 d1 = self.store.find(Egg, Egg.value >= 5)
409 d2 = self.store.find(Egg, Egg.value == 7)
410 return gatherResults([d1, d2]).addCallback(cb_find)
411
412
413 def test_find_values(self):
414 """
415 Filter the fields returned by a find using the values method.
416 """
417 def cb_find(results):
418 return results.values(Bar.title).addCallback(cb_all)
419 def cb_all(titles):
420 titles.sort()
421 self.assertEquals(titles, [u"Title 40", u"Title 50"])
422 return self.store.find(Bar).addCallback(cb_find)
423
424
425 def test_find_and_set(self):
426 """
427 The C{set} method of a C{ResultSet} should update the specified fields
428 in a thread.
429 """
430 def cb_find(results):
431 return results.set(title=u"Title").addCallback(cb_set, results)
432 def cb_set(ignore, results):
433 return results.values(Bar.title).addCallback(cb_all)
434 def cb_all(titles):
435 titles.sort()
436 self.assertEquals(titles, [u"Title", u"Title"])
437 return self.store.find(Bar).addCallback(cb_find)
438
439
440 def test_find_offset(self):
441 """
442 Put an offset on the number of results of a find.
443 """
444 def cb_find(results):
445 results.config(offset=2)
446 return results.all().addCallback(cb_all)
447 def cb_all(results):
448 self.assertEquals(len(results), 2)
449 return self.store.find(Egg).addCallback(cb_find)
450
451
452 def test_find_offset_limit(self):
453 """
454 Put an offset and limit in the number of results of a find.
455 """
456 def cb_find(results):
457 results.config(offset=1, limit=2)
458 return results.all().addCallback(cb_all)
459 def cb_all(results):
460 self.assertEquals(len(results), 2)
461 return self.store.find(Egg).addCallback(cb_find)
462
463
464 @deferredGenerator
465 def test_find_defgen(self):
466 """
467 Do a find, add an object, then do another find: this ensures that the
468 connection remains in the dedicated thread.
469 """
470 d = self.store.find(Bar)
471 wfd = waitForDeferred(d)
472 yield wfd
473 results = wfd.getResult()
474 d = results.all()
475 wfd = waitForDeferred(d)
476 yield wfd
477 wfd.getResult()
478 foo = Foo()
479 foo.title = u"Great title"
480 foo.id = 11
481 d = self.store.add(foo)
482 wfd = waitForDeferred(d)
483 yield wfd
484 wfd.getResult()
485 d = self.store.find(Foo)
486 wfd = waitForDeferred(d)
487 yield wfd
488 results = wfd.getResult()
489
490
491 def test_find_order_by(self):
492 """
493 Try to find a list of objects using the store, then order the result
494 set.
495 """
496 def cb_find(results):
497 results.order_by(Bar.title)
498 return results.all().addCallback(cb_all)
499 def cb_all(results):
500 self.assertEquals(len(results), 2)
501 titles = [results[0].title, results[1].title]
502 self.assertEquals(titles, [u"Title 40", u"Title 50"])
503 return self.store.find(Bar).addCallback(cb_find)
504
505
506 def test_find_and_rollback(self):
507 """
508 Accessing an object outside of a transaction fails because the object
509 hasn't been resolved yet.
510 """
511 def cb_find(results):
512 results.order_by(Bar.title)
513 return results.all().addCallback(cb_all)
514 def cb_all(results):
515 return self.store.rollback().addCallback(cbRollback, results)
516 def cbRollback(ign, results):
517 self.assertEquals(len(results), 2)
518 self.assertRaises(RuntimeError, getattr, results[0], "title")
519 return self.store.find(Bar).addCallback(cb_find)
520
521
522 def test_find_is_empty(self):
523 """
524 DeferredResultSet.is_empty returns a Deferred that fires with True or
525 False depending on whether the matched result set is empty.
526 """
527 def cb_find(results):
528 return results.is_empty().addCallback(self.assertEquals, False)
529
530 return self.store.find(Bar).addCallback(cb_find)
531
532
533 def test_find_group_by(self):
534 """
535 DeferredResultSet.group_by is a simple wrapper around the group_by
536 method of the result set.
537 """
538 def cb_find(results):
539 results.group_by(Bar.foo_id)
540 return results.all().addCallback(check)
541
542 def check(result):
543 self.assertEquals(result, [(2, 10)])
544
545 return self.store.find((Count(Bar.id), Bar.foo_id)
546 ).addCallback(cb_find)
547
548
549 def test_find_having(self):
550 """
551 DeferredResultSet.having is a simple wrapper around the having method
552 of the result set.
553 """
554 connection = self.database.connect()
555 connection.execute("INSERT INTO egg VALUES (5, 7)")
556 connection.commit()
557
558 def cb_find(results):
559 results.group_by(Egg.value)
560 results.having(Egg.value >= 5)
561 results.order_by(Egg.value)
562 return results.all().addCallback(check)
563
564 def check(result):
565 self.assertEquals(result, [(1, 5), (2, 7)])
566
567 return self.store.find((Count(Egg.id), Egg.value)
568 ).addCallback(cb_find)
569
570
571 def test_reference(self):
572 """
573 Try to get a reference of an object using C{DeferredReference}.
574 """
575 def cb_getBar(result):
576 return result.foo.addCallback(cb_getFoo)
577 def cb_getFoo(fooResult):
578 return self.store.get(Foo, 10).addCallback(cb_getFooBar, fooResult)
579 def cb_getFooBar(result, fooResult):
580 self.assertIdentical(fooResult, result)
581 # The result should be valid too
582 self.assertEquals(fooResult.title, u"Title 30")
583 return self.store.get(Bar, 10).addCallback(cb_getBar)
584
585
586 def test_reference_setting(self):
587 """
588 Try to set a reference of an object.
589 """
590 connection = self.database.connect()
591 connection.execute("INSERT INTO foo VALUES (20, 'Title 20')")
592 connection.commit()
593 def cb_getBar(result):
594 return self.store.get(Foo, 20).addCallback(cb_getFooBar, result)
595 def cb_getFooBar(result, barResult):
596 self.assertRaises(RuntimeError, setattr, barResult, "foo", result)
597 return self.store.get(Bar, 10).addCallback(cb_getBar)
598
599
600 def test_reference_set_unordered(self):
601 """
602 Get a reference set and call various wrapped methods on it.
603 """
604 # find test
605 def cb_find(results):
606 return results.all().addCallback(cb_all)
607
608 def cb_all(results):
609 self.assertEquals(len(results), 2)
610 titles = [results[0].title, results[1].title]
611 titles.sort()
612 self.assertEquals(titles, [u"Title 40", u"Title 50"])
613
614 def cb_any(result):
615 self.assertTrue(result)
616
617 def cb_values(titles):
618 titles.sort()
619 self.assertEquals(titles, [u"Title 40", u"Title 50"])
620
621 def do_tests(results):
622 results = results.bars
623 dfrs = [
624 results.find().addCallback(cb_find),
625 results.any().addCallback(cb_any),
626 results.values(Bar.title).addCallback(cb_values),
627 ]
628 return DeferredList(dfrs)
629
630 return self.store.get(FooRefSet, 10).addCallback(do_tests)
631
632
633 def test_reference_set_ordered(self):
634 """
635 A DeferredReferenceSet has an order_by method which returns a Deferred
636 firing when the reference set is ordered.
637 """
638 def do_tests(result):
639 dfrs = [
640 result.first().addCallback(lambda t:
641 self.assertEquals(t.title, u"Title 40")),
642 result.last().addCallback(lambda t:
643 self.assertEquals(t.title, u"Title 50")),
644 ]
645 return DeferredList(dfrs)
646
647 def order(results):
648 dfr = results.bars.order_by("title")
649 return dfr.addCallback(do_tests)
650
651 return self.store.get(FooRefSet, 10).addCallback(order)
652
653
654 def test_reference_set_add_remove(self):
655 """
656 A DeferredReferenceSet has an add method which returns a Deferred firing
657 once the object has been added, and a matching remove method. Add an
658 object to the reference set asynchronously, then remove it again.
659 """
660 def add_one(result):
661 bar = Bar()
662 bar.title = u"Yeah"
663 return result.bars.add(bar).addCallback(remove_one, result, bar)
664
665 def remove_one(add_result, result, bar):
666 return result.bars.remove(bar).addCallback(get_all, result)
667
668 def get_all(ignore, result):
669 return result.bars.all().addCallback(check)
670
671 def check(result):
672 self.assertEquals(len(result), 2)
673
674 return self.store.get(FooRefSet, 10).addCallback(add_one)
675
676
677 def test_reference_set_clear(self):
678 """
679 A DeferredReferenceSet has a clear method which removes all elements
680 from the reference set and fires the returned Deferred when done.
681 """
682 def first_cb(result):
683 refs = result.bars
684 return check_count(refs, 2).addCallback(clear_cb, refs)
685
686 def check_count(ref_set, num):
687 return ref_set.count().addCallback(self.assertEquals, num)
688
689 def clear_cb(result, refs):
690 return refs.clear().addCallback(lambda x: check_count(refs, 0))
691
692 return self.store.get(FooRefSet, 10).addCallback(first_cb)
693
694
695 def test_reference_set_one(self):
696 """
697 Call C{one} on a C{DeferredBoundReference}.
698 """
699 connection = self.database.connect()
700 connection.execute("INSERT INTO foo VALUES (11, 'Title 40')")
701 connection.execute("INSERT INTO bar VALUES (20, 11, 'Title 50')")
702 connection.commit()
703 def cb_get(result):
704 return result.bars.one().addCallback(cb_one)
705 def cb_one(result):
706 return self.store.get(Bar, 20).addCallback(check, result)
707 def check(result, expected):
708 self.assertIdentical(result, expected)
709 return self.store.get(FooRefSet, 11).addCallback(cb_get)
710
711
712 def test_reference_set_first(self):
713 """
714 Call C{first} on an ordered C{DeferredBoundReference}.
715 """
716 def cb_get(result):
717 return result.bars.first().addCallback(cb_one)
718 def cb_one(result):
719 return self.store.get(Bar, 10).addCallback(check, result)
720 def check(result, expected):
721 self.assertIdentical(result, expected)
722 return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get)
723
724
725 def test_reference_set_last(self):
726 """
727 Call C{last} on an ordered C{DeferredBoundReference}.
728 """
729 def cb_get(result):
730 return result.bars.last().addCallback(cb_one)
731 def cb_one(result):
732 return self.store.get(Bar, 11).addCallback(check, result)
733 def check(result, expected):
734 self.assertIdentical(result, expected)
735 return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get)
736
737
738 def test_commit(self):
739 """
740 Make some changes and commit them.
741 """
742 def cb_get(result):
743 return self.store.remove(result).addCallback(cb_remove)
744 def cb_remove(ign):
745 return self.store.commit().addCallback(cb_commit)
746 def cb_commit(ign):
747 # To be sure the data is no longer in the db, query it through a
748 # separate, direct connection
749 connection = self.database.connect()
750 result = connection.execute("SELECT * FROM foo")
751 self.assertEquals(list(result), [])
752 return self.store.get(Foo, 10).addCallback(cb_get)
753
754
755 def test_rollback(self):
756 """
757 Make some changes and roll them back.
758 """
759 def cb_get(result):
760 return self.store.remove(result).addCallback(cb_remove)
761 def cb_remove(ign):
762 return self.store.rollback().addCallback(cbRollback)
763 def cbRollback(ign):
764 connection = self.database.connect()
765 result = connection.execute("SELECT * FROM foo")
766 self.assertEquals(list(result), [(10, u"Title 30")])
767 return self.store.get(Foo, 10).addCallback(cb_get)
768
769
770 def test_deferred_reference_multithread(self):
771 """
772 If a store is restarted, the objects in the cache should still be
773 usable; in particular, an object shouldn't store a reference to the
774 store thread, as it can change.
775 """
776 def test(ignore):
777 # get a bar and retrieve the deferred reference
778 def get_foo(bar):
779 return bar.foo
780
781 return self.store.find(Bar, Bar.id == 10).addCallback(lambda x:
782 x.one()).addCallback(get_foo)
783
784 def _restart_store(res):
785 def stopped(res):
786 self.store = DeferredStore(self.database)
787 return self.store.start()
788 return self.store.stop().addCallback(stopped)
789
790 def check(foo):
791 self.assertEquals(foo.id, 10)
792 self.assertEquals(foo.title, u"Title 30")
793
794 return test(None
795 ).addCallback(_restart_store
796 ).addCallback(test
797 ).addCallback(check)
798
799
800 def test_of(self):
801 """
802 The DeferredStore associated with an object is returned by the static
803 C{of} method.
804 """
805 def cb_get(result):
806 store = DeferredStore.of(result)
807 self.assertIdentical(self.store, store)
808
809 return self.store.get(Foo, 10).addCallback(cb_get)
810
811
812 def test_thread_check(self):
813 """
814 A L{ThreadSafetyError} is raised when attempting to do an unsafe
815 operation, like accessing a C{Reference} attribute via a
816 C{DeferredStore}.
817 """
818 class WrongBar(Bar):
819 foo_sync = Reference(Bar.foo_id, Foo.id)
820 d = self.store.find(WrongBar).addCallback(lambda result: result.all())
821
822 def check(results):
823 self.assertRaises(
824 ThreadSafetyError, getattr, results[0], "foo_sync")
825 return d.addCallback(check)
826
827
828
829class StoreThreadTestCase(TestCase):
830 """
831 Tests for L{StoreThread}.
832 """
833
834 def setUp(self):
835 """
836 Create an instance of C{StoreThread} and start it.
837 """
838 self.thread = StoreThread()
839 self.thread.start()
840
841
842 def tearDown(self):
843 """
844 Kill the running thread.
845 """
846 self.thread.stop()
847
848
849 def test_defer_after_stop(self):
850 """
851 Deferring calls after the thread is stopped fails with C{AlreadyStopped}.
852 """
853 def cb_stop(r):
854 return self.assertFailure(self.thread.defer_to_thread(lambda: None),
855 AlreadyStopped)
856 return self.thread.stop().addCallback(cb_stop)
857
858
859 def test_callback(self):
860 """
861 Fire a simple function in a thread and check its result.
862 """
863 def testfunc():
864 return 1
865 return self.thread.defer_to_thread(testfunc
866 ).addCallback(self.assertEquals, 1)
867
868
869 def test_errback(self):
870 """
871 Raising an exception in a thread returns a failure.
872 """
873 def testfunc():
874 raise RuntimeError("Error!")
875 return self.assertFailure(self.thread.defer_to_thread(testfunc),
876 RuntimeError)
877
878
879
880class StorePoolTest(object):
881 """
882 Tests for L{StorePool}.
883 """
884
885 def setUp(self):
886 """
887 Build a database with data, and create a pool.
888 """
889 self.create_database()
890 connection = self.connection = self.database.connect()
891 self.drop_tables()
892 self.create_tables()
893 connection.execute("INSERT INTO foo VALUES (10, 'Title 30')")
894 connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 40')")
895 connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 50')")
896 connection.commit()
897 self.pool = StorePool(self.database, 2, 5)
898 return self.pool.start()
899
900
901 def tearDown(self):
902 """
903 Stop the pool.
904 """
905 def _drop(ign):
906 self.drop_tables()
907 return self.pool.stop().addCallback(_drop)
908
909
910 def drop_tables(self):
911 for table in ["foo", "bar"]:
912 try:
913 self.connection.execute("DROP TABLE %s" % table)
914 self.connection.commit()
915 except Exception:
916 self.connection.rollback()
917
918
919 def test_already_started(self):
920 """
921 Check that the pool can't be started multiple times.
922 """
923 self.assertRaises(RuntimeError, self.pool.start)
924
925
926 def test_get(self):
927 """
928 get should return different stores if available.
929 """
930 def cb_get1(store1):
931 return self.pool.get().addCallback(cb_get2, store1)
932 def cb_get2(store2, store1):
933 self.assertNotIdentical(store1, store2)
934 self.assertTrue(store1.started)
935 self.assertTrue(store2.started)
936 return self.pool.get().addCallback(cb_get1)
937
938
939 def test_get_not_started(self):
940 """
941 If no store is available, the pool should create a store.
942 """
943 def cb(ign):
944 self.assertEquals(self.pool._stores_created, 0)
945 return self.pool.adjust_size(0, 5).addCallback(cb_add)
946 def cb_add(ign):
947 return self.pool.get().addCallback(cb_get)
948 def cb_get(store):
949 self.assertTrue(store.started)
950 return self.pool.adjust_size(0, 0).addCallback(cb)
951
952
953 def test_waiting_for_store(self):
954 """
955 Test waiting for a store to become available.
956 """
957 def cb(ign):
958 return self.pool.get().addCallback(cb_get1)
959 def cb_get1(store1):
960 self.assertTrue(store1.started)
961 # Now we have a store, no store should be returned by the pool
962 # until we put it back
963 d1 = self.pool.get().addCallback(cb_get2, store1)
964 d2 = self.pool.put(store1)
965 return gatherResults([d1, d2])
966 def cb_get2(store2, store1):
967 self.assertIdentical(store1, store2)
968 return self.pool.adjust_size(1, 1).addCallback(cb)
969
970
971 def test_adjust_size_minmax(self):
972 """
973 Test sanity check on min/max - i.e. min <= max.
974 """
975 return self.assertFailure(self.pool.adjust_size(2, 1), ValueError)
976
977
978 def test_adjust_size_nonnegative(self):
979 """
980 Test sanity check for nonnegative min.
981 """
982 return self.assertFailure(self.pool.adjust_size(-1), ValueError)
983
984
985 @deferredGenerator
986 def test_concurrent_data(self):
987 """
988 Test that different stores have different states: if the first store
989 hasn't yet committed, the second one shouldn't get the new data.
990 """
991 foo = Foo()
992 foo.title = u"Great title"
993 foo.id = 11
994 d = self.pool.get()
995 wfd = waitForDeferred(d)
996 yield wfd
997 store1 = wfd.getResult()
998
999 d = self.pool.get()
1000 wfd = waitForDeferred(d)
1001 yield wfd
1002 store2 = wfd.getResult()
1003
1004 d = store1.add(foo)
1005 wfd = waitForDeferred(d)
1006 yield wfd
1007 wfd.getResult()
1008
1009 d = store1.get(Foo, 11)
1010 wfd = waitForDeferred(d)
1011 yield wfd
1012 foo2 = wfd.getResult()
1013
1014 # The object is already in the store cache
1015 self.assertIdentical(foo2, foo)
1016
1017 d = store2.get(Foo, 11)
1018 wfd = waitForDeferred(d)
1019 yield wfd
1020 foo3 = wfd.getResult()
1021
1022 # The object isn't in the db yet
1023 self.assertIdentical(foo3, None)
1024
1025 # Roll back, because even a SELECT opens a transaction
1026 d = store2.rollback()
1027 wfd = waitForDeferred(d)
1028 yield wfd
1029 wfd.getResult()
1030
1031 # Let's commit
1032 d = store1.commit()
1033 wfd = waitForDeferred(d)
1034 yield wfd
1035 wfd.getResult()
1036
1037 d = store2.get(Foo, 11)
1038 wfd = waitForDeferred(d)
1039 yield wfd
1040 foo4 = wfd.getResult()
1041
1042 # The objects must be different
1043 self.assertNotIdentical(foo4, foo)
1044 # But the content must be the same
1045 self.assertEquals(foo4.title, u"Great title")
1046
1047
1048 def test_no_overflow(self):
1049 """
1050 Test that pool does not allocate more connections than store_max.
1051 """
1052 ds = []
1053 stores = set()
1054
1055 def cb_get(store):
1056 stores.add(store)
1057 return self.pool.put(store)
1058
1059 for i in range(10):
1060 ds.append(self.pool.get().addCallback(cb_get))
1061
1062 def checkInstances(result):
1063 self.assertEquals(len(stores), 5)
1064
1065 return gatherResults(ds).addCallback(checkInstances)
1066
1067
1068 def test_start_failure(self):
1069 """
1070 If a store failed to start, the number of allocated connections doesn't
1071 grow, so we're later able to start more stores.
1072 """
1073 ds = []
1074 stores = set()
1075
1076 class DontStartStore(DeferredStore):
1077 def start(self):
1078 return fail(RuntimeError("oops"))
1079
1080 calls = []
1081
1082 def vicious_store_factory(database):
1083 if not calls:
1084 store = DontStartStore(database)
1085 else:
1086 store = DeferredStore(database)
1087 calls.append(None)
1088 return store
1089
1090 self.pool.store_factory = vicious_store_factory
1091
1092 def cb_get(store):
1093 stores.add(store)
1094 return self.pool.put(store)
1095
1096 errors = []
1097
1098 def save_errors(failure):
1099 errors.append(failure)
1100
1101 for i in range(6):
1102 ds.append(
1103 self.pool.get().addCallback(cb_get).addErrback(save_errors))
1104
1105 def checkInstances(result):
1106 self.assertEquals(len(stores), 5)
1107 self.assertEquals(len(errors), 1)
1108 errors[0].trap(RuntimeError)
1109
1110 return gatherResults(ds).addCallback(checkInstances)
1111
1112
1113 def test_arguments(self):
1114 """
1115 Arguments are passed along to the transacted function when a store
1116 is available.
1117 """
1118 def tx(store, a, b=None):
1119 self.assertIsInstance(store, DeferredStore)
1120 self.assertEquals(1, a)
1121 self.assertEquals(2, b)
1122 return succeed("ok")
1123 return self.pool.transact(tx, 1, b=2)
1124
1125
1126 def test_commit(self):
1127 """
1128 Changes made inside a successful transaction are committed.
1129 """
1130 @deferredGenerator
1131 def check(result):
1132 d = self.pool.get()
1133 wfd = waitForDeferred(d)
1134 yield wfd
1135 store = wfd.getResult()
1136 d = store.execute("SELECT * FROM foo ORDER BY id", [])
1137 wfd = waitForDeferred(d)
1138 yield wfd
1139 dr = wfd.getResult()
1140 d = dr.get_all()
1141 wfd = waitForDeferred(d)
1142 yield wfd
1143 results = wfd.getResult()
1144 self.assertEquals(2, len(results))
1145 self.assertEquals(1, results[0][0])
1146 self.assertEquals("test", results[0][1])
1147
1148 def tx(store):
1149 d = store.execute("INSERT INTO foo(id, title) "
1150 "VALUES (1, 'test')", noresult=True)
1151 return d
1152
1153 return self.pool.transact(tx).addCallback(check)
1154
1155
1156 def test_rollback(self):
1157 """
1158 Changes made inside a failed transaction are rolled back.
1159 """
1160 @deferredGenerator
1161 def check(reason):
1162 d = self.pool.get()
1163 wfd = waitForDeferred(d)
1164 yield wfd
1165 store = wfd.getResult()
1166 d = store.execute("SELECT * FROM foo", [])
1167 wfd = waitForDeferred(d)
1168 yield wfd
1169 dr = wfd.getResult()
1170 d = dr.get_all()
1171 wfd = waitForDeferred(d)
1172 yield wfd
1173 results = wfd.getResult()
1174 self.assertEquals(1, len(results))
1175
1176 @deferredGenerator
1177 def tx(store):
1178 d = store.execute("INSERT INTO foo(id, title) "
1179 "VALUES (1, 'test')", [])
1180 wfd = waitForDeferred(d)
1181 yield wfd
1182 wfd.getResult()
1183 d = store.execute("INSERT INTO foo(id, title) "
1184 "VALUES (1, 'test')", [])
1185 wfd = waitForDeferred(d)
1186 yield wfd
1187 wfd.getResult()
1188
1189 return self.pool.transact(tx).addBoth(check)
1190
1191
1192 def test_return_value(self):
1193 """
1194 Final return value should match return value from successful call to
1195 transacted function.
1196 """
1197 def check(result):
1198 self.assertEquals("completed", result)
1199
1200 def cb(result):
1201 return "completed"
1202
1203 def tx(store):
1204 return store.execute("INSERT INTO foo(id, title) "
1205 "VALUES (1, 'test')", []
1206 ).addCallback(cb)
1207
1208 return self.pool.transact(tx).addCallback(check)
1209
1210
1211 def test_poolsize_after_success(self):
1212 """
1213 After successful transaction, pool size should be same size as before.
1214 """
1215 size = len(self.pool._stores)
1216
1217 def check(result):
1218 self.assertEquals(size, len(self.pool._stores))
1219
1220 def tx(store):
1221 d = store.execute("SELECT * from foo", [])
1222 return d.addCallback(lambda result: result.get_all())
1223
1224 return self.pool.transact(tx).addCallback(check)
1225
1226
1227 def test_poolsize_after_failure(self):
1228 """
1229 After failed transaction, pool size is restored to the initial value.
1230 """
1231 size = len(self.pool._stores)
1232
1233 def check(reason):
1234 self.assertEquals(size, len(self.pool._stores))
1235
1236 def tx(store):
1237 return store.execute("SELECT * from not_a_table", [])
1238
1239 d = self.assertFailure(self.pool.transact(tx), OperationalError)
1240 return d.addCallback(check)
1241
1242
1243 def test_failure_propagation(self):
1244 """
1245 A custom exception is propagated by a C{transact} call.
1246 """
1247 class MyException(Exception):
1248 pass
1249
1250 def tx(store):
1251 raise MyException("Bad things happened")
1252
1253 return self.assertFailure(self.pool.transact(tx), MyException)
1254
1255
1256 def test_non_deferred_function(self):
1257 """
1258 C{transact} can handle functions that don't return a C{Deferred}.
1259 """
1260 def tx(store):
1261 return "foo"
1262 return self.pool.transact(tx).addCallback(self.assertEquals, "foo")
1263
1264
1265 def test_rollback_failure(self):
1266 """
1267 If C{rollback} fails, the store is put back into the pool.
1268 """
1269
1270 def cb_get(store):
1271 store.rollback = lambda: fail(RuntimeError("oops"))
1272 d = self.assertFailure(self.pool.put(store), RuntimeError)
1273 return d.addCallback(get_all)
1274
1275 def get_all(ignore):
1276 dl = []
1277 for i in range(5):
1278 dl.append(self.pool.get())
1279 return gatherResults(dl).addCallback(check)
1280
1281 def check(stores):
1282 # Getting here shows that the test succeeds, because we
1283 # didn't hang waiting for a store
1284 self.assertEquals(len(stores), 5)
1285
1286 return self.pool.get().addCallback(cb_get)
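The StoreThreadTestCase tests above fix a small contract for StoreThread: defer_to_thread runs a callable on one dedicated thread and delivers its result or exception asynchronously, and calls made after stop() fail with AlreadyStopped. As a stdlib-only sketch of that contract, assuming concurrent.futures.Future in place of a Twisted Deferred (DedicatedThread and AlreadyStoppedError are hypothetical names, not the branch's API):

```python
import queue
import threading
from concurrent.futures import Future


class AlreadyStoppedError(Exception):
    """Hypothetical stand-in for the AlreadyStopped exception in the tests."""


class DedicatedThread:
    """Run submitted callables, one at a time, on a single long-lived thread."""

    _STOP = object()  # sentinel that tells the worker loop to exit

    def __init__(self):
        self._queue = queue.Queue()
        self._stopped = False
        # daemon=True so a forgotten stop() can't keep the interpreter alive.
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def _run(self):
        # Every queued call executes here, so all of them share one thread.
        while True:
            item = self._queue.get()
            if item is self._STOP:
                return
            future, func, args, kwargs = item
            try:
                future.set_result(func(*args, **kwargs))
            except Exception as exc:
                future.set_exception(exc)

    def defer_to_thread(self, func, *args, **kwargs):
        """Queue func for the worker thread; return a Future for its result."""
        future = Future()
        if self._stopped:
            # Fail through the returned Future instead of raising, the way
            # the tests expect a failed Deferred after stop().
            future.set_exception(AlreadyStoppedError("thread already stopped"))
        else:
            self._queue.put((future, func, args, kwargs))
        return future

    def stop(self):
        """Let already-queued calls finish, then end the worker thread."""
        self._stopped = True
        self._queue.put(self._STOP)
        self._thread.join()
```

Because every call runs on the same worker, successive calls observe the same thread identity, which is the property test_find_defgen relies on when it checks that the connection remains in the dedicated thread.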
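The transact tests in StorePoolTest (test_arguments through test_rollback_failure) describe one contract: the function receives a store plus any extra arguments, its return value or exception is passed through, changes are committed only on success, and the store always goes back to the pool, even when rollback itself fails. The sketch below models that contract synchronously with the standard library only. SimplePool and FakeStore are hypothetical names, a blocking queue.Queue stands in for the Deferred-based waiting, and the branch's StorePool may well be structured differently.

```python
import queue


class FakeStore:
    """Hypothetical store that only records which calls it received."""

    def __init__(self):
        self.log = []

    def commit(self):
        self.log.append("commit")

    def rollback(self):
        self.log.append("rollback")


class SimplePool:
    """Blocking sketch of a fixed-size pool with get/put/transact."""

    def __init__(self, store_factory, size):
        self._stores = queue.Queue()
        for _ in range(size):
            self._stores.put(store_factory())

    def get(self):
        # Blocks until some caller puts a store back.
        return self._stores.get()

    def put(self, store):
        # Discard any open transaction; the finally clause returns the
        # store to the pool even if rollback() raises.
        try:
            store.rollback()
        finally:
            self._stores.put(store)

    def transact(self, func, *args, **kwargs):
        """Call func(store, *args, **kwargs); commit if it returns normally."""
        store = self.get()
        try:
            result = func(store, *args, **kwargs)
            store.commit()
            return result
        finally:
            # Reached on success and on failure: a failing func skips
            # commit(), so put() rolls its changes back.
            self.put(store)
```

Doing the rollback inside put() means a store handed back after a manual get() is cleaned up too, and the finally clause is what keeps the pool size constant, which is what test_poolsize_after_failure and test_rollback_failure check.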
=== added file 'tests/twisted/mysql.py'
--- tests/twisted/mysql.py 1970-01-01 00:00:00 +0000
+++ tests/twisted/mysql.py 2009-12-27 11:31:15 +0000
@@ -0,0 +1,62 @@
+import os
+
+from storm.database import create_database
+
+from tests.twisted.base import DeferredStoreTest, StorePoolTest
+
+from twisted.trial.unittest import TestCase
+
+
+class MySQLDeferredStoreTest(TestCase, DeferredStoreTest):
+
+    def setUp(self):
+        return DeferredStoreTest.setUp(self)
+
+    def tearDown(self):
+        return DeferredStoreTest.tearDown(self)
+
+    def is_supported(self):
+        return bool(os.environ.get("STORM_MYSQL_URI"))
+
+    def create_database(self):
+        self.database = create_database(os.environ["STORM_MYSQL_URI"])
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id INT PRIMARY KEY AUTO_INCREMENT,"
+                           " title VARCHAR(50) DEFAULT 'Default Title')"
+                           " ENGINE=InnoDB")
+        connection.execute("CREATE TABLE bar "
+                           "(id INT PRIMARY KEY AUTO_INCREMENT,"
+                           " foo_id INTEGER, title VARCHAR(50))"
+                           " ENGINE=InnoDB")
+        connection.execute("CREATE TABLE egg "
+                           "(id INT PRIMARY KEY AUTO_INCREMENT, value INTEGER)"
+                           " ENGINE=InnoDB")
+
+
+class MySQLStorePoolTest(TestCase, StorePoolTest):
+
+    def setUp(self):
+        return StorePoolTest.setUp(self)
+
+    def tearDown(self):
+        return StorePoolTest.tearDown(self)
+
+    def is_supported(self):
+        return bool(os.environ.get("STORM_MYSQL_URI"))
+
+    def create_database(self):
+        self.database = create_database(os.environ["STORM_MYSQL_URI"])
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id INT PRIMARY KEY AUTO_INCREMENT,"
+                           " title VARCHAR(50) DEFAULT 'Default Title')"
+                           " ENGINE=InnoDB")
+        connection.execute("CREATE TABLE bar "
+                           "(id INT PRIMARY KEY AUTO_INCREMENT,"
+                           " foo_id INTEGER, title VARCHAR(50))"
+                           " ENGINE=InnoDB")
=== added file 'tests/twisted/postgres.py'
--- tests/twisted/postgres.py 1970-01-01 00:00:00 +0000
+++ tests/twisted/postgres.py 2009-12-27 11:31:15 +0000
@@ -0,0 +1,57 @@
+import os
+
+from storm.database import create_database
+
+from tests.twisted.base import DeferredStoreTest, StorePoolTest
+
+from twisted.trial.unittest import TestCase
+
+
+class PostgresDeferredStoreTest(TestCase, DeferredStoreTest):
+
+    def setUp(self):
+        return DeferredStoreTest.setUp(self)
+
+    def tearDown(self):
+        return DeferredStoreTest.tearDown(self)
+
+    def is_supported(self):
+        return bool(os.environ.get("STORM_POSTGRES_URI"))
+
+    def create_database(self):
+        self.database = create_database(os.environ["STORM_POSTGRES_URI"])
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id SERIAL PRIMARY KEY,"
+                           " title VARCHAR DEFAULT 'Default Title')")
+        connection.execute("CREATE TABLE bar "
+                           "(id SERIAL PRIMARY KEY,"
+                           " foo_id INTEGER, title VARCHAR)")
+        connection.execute("CREATE TABLE egg "
+                           "(id SERIAL PRIMARY KEY, value INTEGER)")
+
+
+class PostgresStorePoolTest(TestCase, StorePoolTest):
+
+    def setUp(self):
+        return StorePoolTest.setUp(self)
+
+    def tearDown(self):
+        return StorePoolTest.tearDown(self)
+
+    def is_supported(self):
+        return bool(os.environ.get("STORM_POSTGRES_URI"))
+
+    def create_database(self):
+        self.database = create_database(os.environ["STORM_POSTGRES_URI"])
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id SERIAL PRIMARY KEY,"
+                           " title VARCHAR DEFAULT 'Default Title')")
+        connection.execute("CREATE TABLE bar "
+                           "(id SERIAL PRIMARY KEY,"
+                           " foo_id INTEGER, title VARCHAR)")
=== added file 'tests/twisted/sqlite.py'
--- tests/twisted/sqlite.py 1970-01-01 00:00:00 +0000
+++ tests/twisted/sqlite.py 2009-12-27 11:31:15 +0000
@@ -0,0 +1,65 @@
+from storm.databases.sqlite import SQLite
+from storm.uri import URI
+
+from tests.twisted.base import DeferredStoreTest, StorePoolTest
+from tests.helper import TestHelper, MakePath
+
+from twisted.trial.unittest import TestCase
+
+
+class SQLiteDeferredStoreTest(TestCase, TestHelper, DeferredStoreTest):
+
+    helpers = [MakePath]
+
+    def setUp(self):
+        TestHelper.setUp(self)
+        return DeferredStoreTest.setUp(self)
+
+    def tearDown(self):
+        def cb(passthrough):
+            TestHelper.tearDown(self)
+            return passthrough
+        return DeferredStoreTest.tearDown(self).addBoth(cb)
+
+    def create_database(self):
+        self.database = SQLite(URI("sqlite:%s?synchronous=OFF" %
+                                   self.make_path()))
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id INTEGER PRIMARY KEY,"
+                           " title VARCHAR DEFAULT 'Default Title')")
+        connection.execute("CREATE TABLE bar "
+                           "(id INTEGER PRIMARY KEY,"
+                           " foo_id INTEGER, title VARCHAR)")
+        connection.execute("CREATE TABLE egg "
+                           "(id INTEGER PRIMARY KEY, value INTEGER)")
+
+
+class SQLiteStorePoolTest(TestCase, TestHelper, StorePoolTest):
+
+    helpers = [MakePath]
+
+    def setUp(self):
+        TestHelper.setUp(self)
+        return StorePoolTest.setUp(self)
+
+    def tearDown(self):
+        def cb(passthrough):
+            TestHelper.tearDown(self)
+            return passthrough
+        return StorePoolTest.tearDown(self).addBoth(cb)
+
+    def create_database(self):
+        self.database = SQLite(URI("sqlite:%s?synchronous=OFF" %
+                                   self.make_path()))
+
+    def create_tables(self):
+        connection = self.connection
+        connection.execute("CREATE TABLE foo "
+                           "(id INTEGER PRIMARY KEY,"
+                           " title VARCHAR DEFAULT 'Default Title')")
+        connection.execute("CREATE TABLE bar "
+                           "(id INTEGER PRIMARY KEY,"
+                           " foo_id INTEGER, title VARCHAR)")
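test_rollback_failure in tests/twisted/base.py requires that a store is returned to the pool even when its rollback fails. Otherwise the five later get() calls would hang. The sketch below shows one simple way to provide that guarantee, using try/finally. It is a hypothetical synchronous stand-in built on the standard library's queue.Queue. SimpleStorePool and FakeStore are invented names. The sketch does not reproduce the branch's Deferred-based StorePool API and makes no claim about how that class is actually implemented.

```python
import queue


class SimpleStorePool:
    """Hypothetical synchronous pool; NOT the branch's Deferred-based StorePool."""

    def __init__(self, stores):
        self._available = queue.Queue()
        for store in stores:
            self._available.put(store)

    def get(self, timeout=1.0):
        # Block until a store is free; raise queue.Empty after `timeout`
        # seconds instead of hanging forever.
        return self._available.get(timeout=timeout)

    def put(self, store):
        # Roll back pending changes, but return the store to the pool even
        # if rollback raises; the exception still propagates to the caller.
        try:
            store.rollback()
        finally:
            self._available.put(store)


class FakeStore:
    """Hypothetical stand-in for a Storm Store whose rollback always fails."""

    def rollback(self):
        raise RuntimeError("oops")


if __name__ == "__main__":
    pool = SimpleStorePool([FakeStore() for _ in range(5)])
    store = pool.get()
    try:
        pool.put(store)
    except RuntimeError:
        pass  # the failure still reaches the caller
    # All five stores can be fetched again without blocking.
    stores = [pool.get() for _ in range(5)]
    print(len(stores))  # 5
```

Letting the exception escape the try/finally, rather than swallowing it, keeps the failure visible to the caller. This mirrors the test's `assertFailure(self.pool.put(store), RuntimeError)` check.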