Merge lp:~therve/storm/twisted-integration into lp:storm

Proposed by Christopher Armstrong
Status: Work in progress
Proposed branch: lp:~therve/storm/twisted-integration
Merge into: lp:storm
Diff against target: 2287 lines (+2201/-1)
9 files modified
storm/database.py (+18/-1)
storm/exceptions.py (+4/-0)
storm/twisted/__init__.py (+23/-0)
storm/twisted/store.py (+480/-0)
storm/twisted/wrapper.py (+206/-0)
tests/twisted/base.py (+1286/-0)
tests/twisted/mysql.py (+62/-0)
tests/twisted/postgres.py (+57/-0)
tests/twisted/sqlite.py (+65/-0)
To merge this branch: bzr merge lp:~therve/storm/twisted-integration
Reviewer                           Review Type  Date Requested  Status
Christopher Armstrong (community)  -            -               Abstain
Review via email: mp+3733@code.launchpad.net
Christopher Armstrong (radix) wrote :

[1] Why setDaemon(True)?

[2] Can you use underscores for DeferredStore?

[3] Add a docstring to get() explaining that it loads all attributes.

[4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_lazy_loading(). I dunno.

[5] A global mapping of Storm store to deferred store? Maybe?

[6] Make objects that came from a store which was later put back into the pool "broken", so that any future attribute access on them fails.

review: Abstain
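On point [1]: StoreThread, in the diff below, is a single worker thread that takes jobs from a Queue and exits when it dequeues the STOP sentinel. The following stdlib-only sketch reproduces that pattern in Python 3 without Twisted; the `Worker` class and its `submit` method are illustrative names, not part of the branch, and a plain callback stands in for the Deferred. Because the thread is a daemon, a forgotten stop() cannot keep the interpreter from exiting; the branch additionally registers a "before shutdown" reactor trigger so the thread is normally stopped in an orderly way.

```python
import queue
import threading

STOP = object()  # sentinel telling the worker loop to exit


class Worker(threading.Thread):
    """Hypothetical stand-in for StoreThread: runs jobs serially in one thread."""

    def __init__(self):
        # daemon=True: the interpreter may exit even if stop() is never called
        super().__init__(daemon=True)
        self._queue = queue.Queue()

    def submit(self, callback, f, *args, **kwargs):
        # Queue a job; callback(ok, value) is invoked from the worker thread
        self._queue.put((callback, f, args, kwargs))

    def run(self):
        while True:
            job = self._queue.get()
            if job is STOP:
                break
            callback, f, args, kwargs = job
            try:
                result = f(*args, **kwargs)
            except Exception as exc:
                callback(False, exc)
            else:
                callback(True, result)

    def stop(self):
        # Jobs queued before the sentinel still run; then the loop exits
        self._queue.put(STOP)


results = []
w = Worker()
w.start()
w.submit(lambda ok, v: results.append((ok, v)), sum, [1, 2, 3])
w.submit(lambda ok, v: results.append((ok, type(v).__name__)), int, "x")
w.stop()
w.join(timeout=5)
print(results)  # [(True, 6), (False, 'ValueError')]
```

In the branch, the worker hands results back with reactor.callFromThread(d.callback, result), or d.errback with a Failure, so Deferreds always fire in the reactor thread rather than in the store thread.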
Gabriel (gabriel-rossetti) wrote :

Christopher Armstrong wrote:
> Christopher Armstrong has proposed merging lp:~therve/storm/twisted-integration into lp:storm.
>
> Requested reviews:
> Storm Developers (storm)
>
I second that :-)

Gabriel

Christopher Armstrong (radix) wrote :

[7] Expose reload to DeferredStore

[8] Maybe add deferredStore.commit(reload=True)

review: Abstain
Drew Smathers (djfroofy) wrote :

> [4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_lazy_loading(). I dunno.

That sounds like a great idea to me. Otherwise, having to refetch objects just to avoid lazy value errors is less than ideal.
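Point [4] and the reply above concern the fact that DeferredStore replaces the store's `_resolve_lazy_value` with a function that raises, while its get() loads all attributes eagerly in the store thread. A store-level switch such as the proposed Store.disable_lazy_loading() would make that behaviour explicit. The sketch below is purely hypothetical (`LazyStore`, `Row`, `disable_lazy_loading` and `LazyValueError` are not Storm APIs) and only illustrates the intended semantics: eagerly loaded attributes stay readable, while any attribute that would trigger a lazy load raises instead of quietly running a query from the wrong thread.

```python
class LazyValueError(RuntimeError):
    """Raised when a lazy attribute would be loaded while loading is disabled."""


class LazyStore:
    """Hypothetical store: attributes that were not preloaded load on access."""

    def __init__(self, table):
        self._table = table  # {object id: {column: value}}
        self._lazy_enabled = True

    def disable_lazy_loading(self):
        # Hypothetical API suggested in review point [4]; not part of Storm
        self._lazy_enabled = False

    def get(self, obj_id, preload=()):
        row = Row(self, obj_id)
        for column in preload:
            # Eager load, analogous to DeferredStore.get resolving everything
            row.__dict__[column] = self._table[obj_id][column]
        return row

    def resolve(self, row, column):
        if not self._lazy_enabled:
            raise LazyValueError(
                "%r is not loaded; refetch the object instead" % column)
        return self._table[row.id][column]


class Row:
    def __init__(self, store, obj_id):
        self.__dict__["_store"] = store
        self.__dict__["id"] = obj_id

    def __getattr__(self, name):
        # Called only when normal lookup fails, i.e. the value is not loaded
        if name.startswith("__"):
            raise AttributeError(name)
        value = self._store.resolve(self, name)
        self.__dict__[name] = value
        return value


store = LazyStore({10: {"title": "Title 30", "body": "not preloaded"}})
foo = store.get(10, preload=("title",))
store.disable_lazy_loading()
print(foo.title)  # Title 30
try:
    foo.body
except LazyValueError as exc:
    print(exc)    # 'body' is not loaded; refetch the object instead
```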

lp:~therve/storm/twisted-integration updated
215. By Thomas Herve

Merge from trunk.

216. By Thomas Herve

Move the tests to be able to run Postgres/MySQL tests as well

217. By Thomas Herve

Add postgres tests

218. By Thomas Herve

Add mysql, some tests failing.

219. By Thomas Herve

Use InnoDB to be able to use transactions.

220. By Thomas Herve

Merge from trunk.

221. By Thomas Herve

Enforce the max_stores limit by incrementing the number of stores as soon as
start_store is called, instead of in the callback [f=373816]
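The race behind revision 221: StorePool.get, in the diff below, only starts a new store while `_stores_created < max_stores`, but a store finishes starting later, in a callback. If the counter only moves in that callback, several get() calls made before any start completes all pass the check. The simulation below is a hypothetical, Twisted-free model (the `Pool` class and its `count_eagerly` switch are invented for illustration, and queued callables stand in for Deferreds) comparing the two orderings.

```python
class Pool:
    """Hypothetical pool in which store start-up completes asynchronously."""

    def __init__(self, max_stores, count_eagerly):
        self.max_stores = max_stores
        self.count_eagerly = count_eagerly
        self.created = 0
        self.pending = []  # completion callbacks that have not fired yet

    def _start_store(self):
        if self.count_eagerly:
            self.created += 1  # reserve the slot immediately

        def on_started():
            if not self.count_eagerly:
                self.created += 1  # count only once start-up has finished

        self.pending.append(on_started)

    def get(self):
        # Same guard as StorePool.get: only start a store below the maximum
        if self.created < self.max_stores:
            self._start_store()

    def finish_all(self):
        # Simulate the reactor later firing every pending start-up callback
        while self.pending:
            self.pending.pop(0)()


totals = {}
for eager in (False, True):
    pool = Pool(max_stores=2, count_eagerly=eager)
    for _ in range(5):  # five requests arrive before any store is ready
        pool.get()
    pool.finish_all()
    totals[eager] = pool.created

print(totals)  # {False: 5, True: 2}
```

The branch pairs the eager increment with _eb_start_store, which decrements the counter again if starting a store fails, so a failed start does not permanently consume a slot.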

222. By Thomas Herve

Use TestHelper/MakePath to manage sqlite file.

223. By Thomas Herve

Fix doc typos

224. By Thomas Herve

Merge from trunk

225. By Thomas Herve

Some tests cleanup

226. By Thomas Herve

Merge from trunk

227. By Thomas Herve

Cleanups

228. By Thomas Herve

Check thread in Connection to catch problems.
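Revision 228 corresponds to the thread check visible in the storm/database.py part of the diff: Connection remembers the thread that created it, and _check_thread raises ThreadSafetyError when another thread uses it (the diff calls it from rollback() and from the reconnection check; per its docstring, execute and commit should reach it directly or indirectly). Below is a standalone Python 3 sketch of the same idea; `ThreadBoundConnection` and its toy `execute` are illustrative stand-ins, not Storm code, and `threading.current_thread()` is the modern spelling of the `currentThread()` used in the diff.

```python
import threading


class ThreadSafetyError(Exception):
    """Raised when an object is used outside the thread that created it."""


class ThreadBoundConnection:
    """Hypothetical stand-in for storm.database.Connection's new check."""

    def __init__(self):
        # Remember the creating thread, as Connection.__init__ does in the diff
        self._thread = threading.current_thread()

    def _check_thread(self):
        thread = threading.current_thread()
        if thread is not self._thread:
            raise ThreadSafetyError(
                "%r is not the connection thread %r" % (thread, self._thread))

    def execute(self, statement):
        self._check_thread()
        return "executed: " + statement


conn = ThreadBoundConnection()
print(conn.execute("SELECT 1"))  # executed: SELECT 1

errors = []


def use_from_other_thread():
    try:
        conn.execute("SELECT 1")
    except ThreadSafetyError as exc:
        errors.append(exc)


t = threading.Thread(target=use_from_other_thread)
t.start()
t.join()
print(len(errors))  # 1
```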

229. By Thomas Herve

Handle errors in rollback
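The log alone does not say exactly what revision 229 changed. In the diff below, though, StorePool.put chains `store.rollback().addBoth(self._cb_put, store)`, so the store goes back to the pool (or to the oldest waiting get()) whether or not the rollback succeeded, and _cb_put returns its argument so that a rollback failure still reaches the caller. A synchronous sketch of that behaviour, with hypothetical `SimplePool` and `FlakyStore` classes and try/finally playing the role of addBoth:

```python
class SimplePool:
    """Hypothetical synchronous analogue of StorePool.put and _cb_put."""

    def __init__(self):
        self.available = []
        self.waiters = []  # callables waiting for a store, oldest first

    def put(self, store):
        try:
            store.rollback()
        finally:
            # Runs after success and failure alike, like Deferred.addBoth;
            # an exception from rollback() still propagates to the caller.
            if self.waiters:
                self.waiters.pop(0)(store)
            else:
                self.available.append(store)


class FlakyStore:
    def __init__(self, fail):
        self.fail = fail

    def rollback(self):
        if self.fail:
            raise RuntimeError("rollback failed")


pool = SimplePool()
propagated = []
pool.put(FlakyStore(fail=False))
try:
    pool.put(FlakyStore(fail=True))
except RuntimeError as exc:
    propagated.append(str(exc))

print(len(pool.available), propagated)  # 2 ['rollback failed']
```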

230. By Thomas Herve

Typo, cleanup

231. By Thomas Herve

Make start/stop methods private

Unmerged revisions

231 down to 222, the same revisions (with the same messages) as listed above.

Preview Diff

=== modified file 'storm/database.py'
--- storm/database.py 2009-07-30 06:19:27 +0000
+++ storm/database.py 2009-12-27 11:31:14 +0000
@@ -24,12 +24,13 @@
 This is the common code for database support; specific databases are
 supported in modules in L{storm.databases}.
 """
+import threading

 from storm.expr import Expr, State, compile
 from storm.tracer import trace
 from storm.variables import Variable
 from storm.exceptions import (
-    ClosedError, DatabaseError, DisconnectionError, Error)
+    ClosedError, DatabaseError, DisconnectionError, Error, ThreadSafetyError)
 from storm.uri import URI
 import storm

@@ -180,6 +181,20 @@
         self._database = database # Ensures deallocation order.
         self._event = event
         self._raw_connection = self._database.raw_connect()
+        self._thread = threading.currentThread()
+
+    def _check_thread(self):
+        """
+        Check if the current thread is the same thread that created the
+        connection. This is a safety check that prevents breaking thread
+        boundaries which can create weird bugs. C{execute}, C{commit} and
+        C{rollback} should call it, directly or indirectly.
+        """
+        thread = threading.currentThread()
+        if thread is not self._thread:
+            raise ThreadSafetyError(
+                "'%s' is not the connection thread '%s'" %
+                (thread, self._thread))

     def __del__(self):
         """Close the connection."""
@@ -240,6 +255,7 @@

     def rollback(self):
         """Rollback the connection."""
+        self._check_thread()
         if self._state == STATE_CONNECTED:
             try:
                 self._raw_connection.rollback()
@@ -314,6 +330,7 @@
         If the connection is marked as dead, or if we can't reconnect,
         then raise DisconnectionError.
         """
+        self._check_thread()
         if self._state == STATE_CONNECTED:
             return
         elif self._state == STATE_DISCONNECTED:

=== modified file 'storm/exceptions.py'
--- storm/exceptions.py 2009-05-15 08:43:21 +0000
+++ storm/exceptions.py 2009-12-27 11:31:14 +0000
@@ -82,6 +82,10 @@
     pass


+class ThreadSafetyError(StormError):
+    """Exception raised when cross-thread operations are attempted."""
+
+
 class Error(StormError):
     pass


=== added directory 'storm/twisted'
=== added file 'storm/twisted/__init__.py'
--- storm/twisted/__init__.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/__init__.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,23 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Asynchronous wrapper around storm to be used within a Twisted application.
+"""

=== added file 'storm/twisted/store.py'
--- storm/twisted/store.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/store.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,480 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Store wrapper and custom thread runner.
+"""
+
+from threading import Thread
+from Queue import Queue
+
+from storm.store import Store, AutoReload
+from storm.info import get_obj_info
+from storm.twisted.wrapper import partial, DeferredResult, DeferredResultSet
+
+from twisted.internet.defer import Deferred, deferredGenerator, maybeDeferred
+from twisted.internet.defer import waitForDeferred, succeed, fail
+from twisted.python.failure import Failure
+
+
+
+class AlreadyStopped(Exception):
+    """
+    Exception raised when a store is stopped multiple times.
+    """
+
+
+
+class StoreThread(Thread):
+    """
+    A thread class that wraps method calls and fires deferreds in the reactor
+    thread.
+    """
+    STOP = object()
+
+    def __init__(self):
+        """
+        Initialize the thread, and create a L{Queue} to stack jobs.
+        """
+        Thread.__init__(self)
+        self.setDaemon(True)
+        self._queue = Queue()
+        self._stop_deferred = None
+        self.stopped = False
+
+
+    def defer_to_thread(self, f, *args, **kwargs):
+        """
+        Run the given function in the thread, wrapping the result with a
+        L{Deferred}.
+
+        @return: a deferred whose result will be the result of the function.
+        @rtype: C{Deferred}
+        """
+        if self.stopped:
+            # to prevent having pending calls after the thread got stopped
+            return fail(AlreadyStopped(f))
+        d = Deferred()
+        self._queue.put((d, f, args, kwargs))
+        return d
+
+
+    def run(self):
+        """
+        Main execution loop: retrieve jobs from the queue and run them.
+        """
+        from twisted.internet import reactor
+        o = self._queue.get()
+        while o is not self.STOP:
+            d, f, args, kwargs = o
+            try:
+                result = f(*args, **kwargs)
+            except:
+                f = Failure()
+                reactor.callFromThread(d.errback, f)
+            else:
+                reactor.callFromThread(d.callback, result)
+            o = self._queue.get()
+        reactor.callFromThread(self._stop_deferred.callback, None)
+
+
+    def stop(self):
+        """
+        Stop the thread.
+        """
+        if self.stopped:
+            return self._stop_deferred
+        self._stop_deferred = Deferred()
+        self._queue.put(self.STOP)
+        self.stopped = True
+        return self._stop_deferred
+
+
+
+class DeferredStore(object):
+    """
+    A wrapper around L{Store} to have async operations.
+    """
+    store = None
+
+    def __init__(self, database):
+        """
+        @param database: instance of database providing connection, used to
+            instantiate the store later.
+        @type database: L{storm.database.Database}
+        """
+        self.thread = StoreThread()
+        self.database = database
+        self.started = False
+
+
+    def start(self):
+        """
+        Start the store.
+
+        @return: a deferred that will fire once the store is started.
+        """
+        if not self.started:
+            self.started = True
+            self.thread.start()
+            # Add an event trigger to be sure that the thread is stopped
+            from twisted.internet import reactor
+            reactor.addSystemEventTrigger(
+                "before", "shutdown", self.stop)
+            return self.thread.defer_to_thread(Store, self.database
+                ).addCallback(self._got_store)
+        else:
+            raise RuntimeError("Already started")
+
+
+    def _got_store(self, store):
+        """
+        Internal method called when the store is created, initializing most of
+        the API methods.
+        """
+        self.store = store
+        # Maybe not?
+        self.store._deferredStore = self
+        for methodName in ("commit", "flush", "remove", "reload",
+                           "rollback"):
+            method = partial(self.thread.defer_to_thread,
+                             getattr(self.store, methodName))
+            setattr(self, methodName, method)
+
+        self._do_resolve_lazy_value = self.store._resolve_lazy_value
+        self.store._resolve_lazy_value = self._resolve_lazy_value
+
+
+    def get(self, cls, key):
+        def _get():
+            obj = self.store.get(cls, key)
+            if obj is not None:
+                obj_info = get_obj_info(obj)
+                self._do_resolve_lazy_value(obj_info, None, AutoReload)
+            return obj
+        return self.thread.defer_to_thread(_get)
+
+
+    def add(self, obj):
+        """
+        Specific add method that doesn't return any result, so that callers
+        don't mistake the result for something usable.
+        """
+        def _add():
+            self.store.add(obj)
+        return self.thread.defer_to_thread(_add)
+
+
+    def execute(self, *args, **kwargs):
+        """
+        Wrapper around C{execute} to have a C{DeferredResult} instead of the
+        standard L{storm.database.Result} object.
+        """
+        if self.store is None:
+            raise RuntimeError("Store not started")
+        return self.thread.defer_to_thread(
+            self.store.execute, *args, **kwargs
+            ).addCallback(self._cb_execute)
+
+
+    def _cb_execute(self, result):
+        """
+        Wrap the result with a C{DeferredResult}.
+        """
+        if result is not None:
+            return DeferredResult(self.thread, result)
+
+
+    def find(self, *args, **kwargs):
+        """
+        Wrapper around C{find}.
+        """
+        if self.store is None:
+            raise RuntimeError("Store not started")
+        return self.thread.defer_to_thread(
+            self.store.find, *args, **kwargs
+            ).addCallback(self._cb_find)
+
+
+    def _cb_find(self, resultSet):
+        """
+        Wrap the result set with a C{DeferredResultSet}.
+        """
+        return DeferredResultSet(self.thread, resultSet)
+
+
+    def stop(self):
+        """
+        Stop the store.
+        """
+        if self.thread.stopped:
+            return succeed(None)
+        def close():
+            self.store.rollback()
+            self.store.close()
+        return self.thread.defer_to_thread(close
+            ).addCallback(lambda ign: self.thread.stop())
+
+
+    def _resolve_lazy_value(self, *args):
+        raise RuntimeError(
+            "Resolving lazy values with the Twisted wrapper is not possible "
+            "right now! Please refetch your object using "
+            "store.get/store.find")
+
+
+    @staticmethod
+    def of(obj):
+        """
+        Get the DeferredStore the object is associated with.
+
+        If the given object has not been associated with a DeferredStore,
+        return None.
+        """
+        store = Store.of(obj)
+        if not store:
+            return
+        return getattr(store, '_deferredStore', None)
+
+
+
+class StorePool(object):
+    """
+    A pool of started stores, maintaining persistent connections.
+    """
+    started = False
+    store_factory = DeferredStore
+
+    def __init__(self, database, min_stores=0, max_stores=10):
+        """
+        @param database: instance of database providing connection, used to
+            instantiate the store later.
+        @type database: L{storm.database.Database}
+
+        @param min_stores: initial number of stores.
+        @type min_stores: C{int}
+
+        @param max_stores: maximum number of stores.
+        @type max_stores: C{int}
+        """
+        self.database = database
+        self.min_stores = min_stores
+        self.max_stores = max_stores
+        self._stores = []
+        self._stores_created = 0
+        self._pending_get = []
+        self._store_refs = []
+
+
+    def start(self):
+        """
+        Start the pool.
+        """
+        if self.started:
+            raise RuntimeError("Already started")
+        self.started = True
+        return self.adjust_size()
+
+
+    def stop(self):
+        """
+        Stop the pool: this is not a total stop, it just tries to kill the
+        currently available stores.
+        """
+        return self.adjust_size(0, 0, self._store_refs)
+
+
+    def _start_store(self):
+        """
+        Create a new store.
+        """
+        store = self.store_factory(self.database)
+        # Increment here, so that other simultaneous calls don't make the
+        # number of connections exceed the maximum
+        self._stores_created += 1
+        return store.start(
+            ).addCallback(self._cb_start_store, store
+            ).addErrback(self._eb_start_store)
+
+
+    def _cb_start_store(self, ign, store):
+        """
+        Add the created store to the list of available stores.
+        """
+        self._stores.append(store)
+        self._store_refs.append(store)
+
+
+    def _eb_start_store(self, failure):
+        """
+        Reduce the number of created stores, and let the failure propagate.
+        """
+        self._stores_created -= 1
+        return failure
+
+
+    def _stop_store(self, stores=None):
+        """
+        Stop a store and remove it from the available stores.
+        """
+        if stores is None:
+            stores = self._stores
+        self._stores_created -= 1
+        store = stores.pop()
+        return store.stop()
+
+
+    @deferredGenerator
+    def adjust_size(self, min_stores=None, max_stores=None, stores=None):
+        """
+        Change the number of available stores, shrinking or growing as
+        necessary.
+        """
+        if min_stores is None:
+            min_stores = self.min_stores
+        if max_stores is None:
+            max_stores = self.max_stores
+        if stores is None:
+            stores = self._stores
+
+        if min_stores < 0:
+            raise ValueError('minimum is negative')
+        if min_stores > max_stores:
+            raise ValueError('minimum is greater than maximum')
+
+        self.min_stores = min_stores
+        self.max_stores = max_stores
+        if not self.started:
+            return
+
+        # Kill off some stores if we have too many.
+        while self._stores_created > self.max_stores and stores:
+            wfd = waitForDeferred(self._stop_store(stores))
+            yield wfd
+            wfd.getResult()
+        # Start some stores if we have too few.
+        while self._stores_created < self.min_stores:
+            wfd = waitForDeferred(self._start_store())
+            yield wfd
+            wfd.getResult()
+
+
+    def get(self):
+        """
+        Return a started store from the pool, or start a new one if necessary.
+        A store retrieved this way should be put back using the put
+        method, or it won't be used anymore.
+        """
+        if not self.started:
+            raise RuntimeError("Not started")
+        if self._stores:
+            store = self._stores.pop()
+            return succeed(store)
+        elif self._stores_created < self.max_stores:
+            return self._start_store().addCallback(self._cb_get)
+        else:
+            # Maybe all stores are consumed?
+            return self.adjust_size().addCallback(self._cb_get)
+
+
+    def _cb_get(self, ign):
+        """
+        If the previous operation added a store, return it, or return a pending
+        C{Deferred}.
+        """
+        if self._stores:
+            store = self._stores.pop()
+            return store
+        else:
+            # All stores are in use, wait
+            d = Deferred()
+            self._pending_get.append(d)
+            return d
+
+
+    def put(self, store):
+        """
+        Make a store available again.
+
+        This should be done explicitly to have the store back in the pool.
+        The recommended way to use the pool is:
+
+        >>> d1 = pool.get()
+
+        >>> # d1 fires with a store
+        >>> d2 = store.add(foo)
+        >>> d2.addCallback(doSomething).addErrback(manageErrors)
+        >>> d2.addBoth(lambda x: pool.put(store))
+        """
+        return store.rollback().addBoth(self._cb_put, store)
+
+
+    def _cb_put(self, passthrough, store):
+        """
+        Once the rollback has finished, the store is really available.
+        """
+        if self._pending_get:
+            # People are waiting, fire with the store
+            d = self._pending_get.pop(0)
+            d.callback(store)
+        else:
+            self._stores.append(store)
+        return passthrough
+
+
+    def transact(self, f, *args, **kwargs):
+        """
+        Call function C{f} with a L{Store} instance and arguments C{args} and
+        C{kwargs} in a transaction bound to the acquired store. If the
+        transaction succeeds, the store is committed. The store is returned to
+        this pool after the call to C{f} completes.
+
+        Note that the function C{f} must return an instance of L{Deferred}.
+
+        @param f: function to call in transaction
+        @param args: positional arguments to function C{f}
+        @param kwargs: keyword arguments to function C{f}
+        """
+        return self.get(
+            ).addCallback(self._cb_transact_start, f, args, kwargs)
+
+
+    def _cb_transact_start(self, store, f, args, kwargs):
+        """
+        Call transacted function with acquired store.
+        """
+        result = maybeDeferred(f, store, *args, **kwargs)
+        result.addCallback(self._cb_transact_success, store)
+        result.addBoth(self._cb_transact_stop, store)
+        return result
+
+
+    def _cb_transact_success(self, result, store):
+        """
+        Commit and pass through function result.
+        """
+        return store.commit().addCallback(lambda ignore: result)
+
+
+    def _cb_transact_stop(self, result, store):
+        """
+        Return the store back to the pool and pass through the result again.
+        """
+        return self.put(store).addCallback(lambda ignore: result)

=== added file 'storm/twisted/wrapper.py'
--- storm/twisted/wrapper.py 1970-01-01 00:00:00 +0000
+++ storm/twisted/wrapper.py 2009-12-27 11:31:14 +0000
@@ -0,0 +1,206 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Asynchronous wrapper around storm.
+"""
+
+from storm.store import Store
+from storm.references import Reference, ReferenceSet
+
+
+try:
+    from functools import partial
+except ImportError:
+    # For Python < 2.5
+    class partial(object):
+        def __init__(self, fn, *args, **kw):
+            self.fn = fn
+            self.args = args
+            self.kw = kw
+
+        def __call__(self, *args, **kw):
+            if kw and self.kw:
+                d = self.kw.copy()
+                d.update(kw)
+            else:
+                d = kw or self.kw
+            return self.fn(*(self.args + args), **d)
+
+
+
+class DeferredResult(object):
+    """
+    Proxy for a storm result, running the blocking methods in a thread and
+    returning C{Deferred}s.
+    """
+
+    def __init__(self, thread, result):
+        """
+        @param thread: the running thread of the store
+        @type thread: C{StoreThread}
+
+        @param result: the result instance to be wrapped.
+        @type result: C{storm.database.Result}
+        """
+        self.result = result
+        for methodName in ("get_one", "get_all"):
+            method = partial(thread.defer_to_thread,
+                             getattr(result, methodName))
+            setattr(self, methodName, method)
+
+
+
+class DeferredResultSet(object):
+    """
+    Wrapper for a L{storm.store.ResultSet}.
+    """
+
+    def __init__(self, thread, resultSet):
+        """
+        Create the wrapper with the given C{StoreThread} and the set to wrap.
+        """
+        self._thread = thread
+        self._resultSet = resultSet
+        for methodName in ("any", "one", "first", "last", "remove", "count",
+                           "max", "min", "avg", "sum", "set", "is_empty"):
+            method = partial(thread.defer_to_thread,
+                             getattr(resultSet, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("union", "difference", "intersection"):
+            method = partial(self._set_expr,
+                             getattr(resultSet, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("order_by", "config", "group_by", "having"):
+            setattr(self, methodName, getattr(self._resultSet, methodName))
+
+
+    def all(self):
+        """
+        Specific method to emulate C{__iter__}.
+        """
+        return self._thread.defer_to_thread(list, self._resultSet)
+
+
+    def values(self, *columns):
+        """
+        Wrapper around values that returns a list instead of an
+        iterator.
+        """
+        def _get_values():
+            return list(self._resultSet.values(*columns))
+        return self._thread.defer_to_thread(_get_values)
+
+
+    def _set_expr(self, method, other, all=False):
+        """
+        Wrap a set expression with a C{DeferredResultSet}.
+        """
+        return DeferredResultSet(self._thread, method(other, all))
+
+
+
+class DeferredReference(Reference):
+    """
+    A reference property but within a C{Deferred}.
+    """
+
+    def __get__(self, local, cls=None):
+        """
+        Wrapper around C{Reference.__get__}.
+        """
+        store = Store.of(local)
+        if store is None:
+            return None
+        _thread = store._deferredStore.thread
+        return _thread.defer_to_thread(Reference.__get__, self, local, cls)
+
+
+    def __set__(self, local, remote):
+        """
+        Wrapper around C{Reference.__set__}.
+        """
+        raise RuntimeError("Can't set a DeferredReference")
+
+
+
+class DeferredReferenceSet(ReferenceSet):
+    """
+    A C{ReferenceSet} but within a C{Deferred}.
+    """
+
+    def __get__(self, local, cls=None):
+        """
+        Wrapper around C{ReferenceSet.__get__}.
+        """
+        store = Store.of(local)
+        if store is None:
+            return None
+        _thread = store._deferredStore.thread
+        boundReference = ReferenceSet.__get__(self, local, cls)
+        return DeferredBoundReference(_thread, boundReference)
+
+
+
+class DeferredBoundReference(object):
+    """
+    Wrapper around C{BoundReferenceSet} and C{BoundIndirectReferenceSet}.
+    """
+
+    def __init__(self, thread, boundReference):
+        """
+        Create the reference with the given C{StoreThread} and the reference
+        to wrap.
+        """
+        self._thread = thread
+        self._boundReference = boundReference
+        for methodName in ("clear", "add", "remove", "any", "count", "one",
+                           "first", "last"):
+            method = partial(thread.defer_to_thread,
+                             getattr(boundReference, methodName))
+            setattr(self, methodName, method)
+        for methodName in ("order_by", "find"):
+            method = partial(self._defer_and_wrap_result,
+                             getattr(boundReference, methodName))
+            setattr(self, methodName, method)
+
+
+    def all(self):
+        """
+        Specific method to emulate C{__iter__}.
+        """
+        return self._thread.defer_to_thread(list, self._boundReference)
+
+
+    def values(self, *columns):
+        """
+        Emulate the values method.
+        """
+        def _get_values():
+            return list(self._boundReference.values(*columns))
+        return self._thread.defer_to_thread(_get_values)
+
+
+    def _defer_and_wrap_result(self, method, *args, **kwargs):
+        """
+        Helper for methods returning another C{ResultSet}.
+        """
+        return self._thread.defer_to_thread(method, *args, **kwargs
+            ).addCallback(lambda x: DeferredResultSet(self._thread, x))

=== added directory 'tests/twisted'
=== added file 'tests/twisted/__init__.py'
=== added file 'tests/twisted/base.py'
--- tests/twisted/base.py 1970-01-01 00:00:00 +0000
+++ tests/twisted/base.py 2009-12-27 11:31:15 +0000
@@ -0,0 +1,1286 @@
+#
+# Copyright (c) 2007 Canonical
+# Copyright (c) 2007 Thomas Herve <thomas@nimail.org>
+#
+# This file is part of Storm Object Relational Mapper.
+#
+# Storm is free software; you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation; either version 2.1 of
+# the License, or (at your option) any later version.
+#
+# Storm is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+"""
+Tests for twistorm.
+"""
+
+from storm.properties import Int, Unicode
+from storm.expr import Count
+from storm.references import Reference
+from storm.exceptions import OperationalError, ThreadSafetyError
+
+from storm.twisted.store import (
+    DeferredStore, StoreThread, StorePool, AlreadyStopped )
+from storm.twisted.wrapper import DeferredReference, DeferredReferenceSet
+
+from twisted.trial.unittest import TestCase
+from twisted.internet.defer import gatherResults, deferredGenerator
+from twisted.internet.defer import waitForDeferred, DeferredList
+from twisted.internet.defer import succeed, fail
+
+
+class Foo(object):
+    """
+    Test table.
+    """
+    __storm_table__ = "foo"
+    id = Int(primary=True)
+    title = Unicode()
+
+
+
+class Bar(object):
+    """
+    Test table referencing C{Foo}.
+    """
+    __storm_table__ = "bar"
+    id = Int(primary=True)
+    title = Unicode()
+    foo_id = Int()
+    foo = DeferredReference(foo_id, Foo.id)
+
+
+
+class FooRefSet(Foo):
+    """
+    A C{Foo} class with a C{DeferredReferenceSet} to get all the related bars.
+    """
+    bars = DeferredReferenceSet(Foo.id, Bar.foo_id)
+
+
+
+class FooRefSetOrderID(Foo):
+    """
+    A C{Foo} class with an ordered C{DeferredReferenceSet} to C{Bar}.
+    """
+    bars = DeferredReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id)
+
+
+
+class Egg(object):
+    """
+    Test table.
+    """
+    __storm_table__ = "egg"
+    id = Int(primary=True)
+    value = Int()
+
+
+
890+class DeferredStoreTest(object):
891+ """
892+ Tests for L{DeferredStore}.
893+ """
894+
895+ def setUp(self):
896+ """
897+ Create a test sqlite database, and insert some data.
898+ """
899+ self.create_database()
900+ connection = self.connection = self.database.connect()
901+ self.drop_tables()
902+ self.create_tables()
903+ connection.execute("INSERT INTO foo VALUES (10, 'Title 30')")
904+ connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 50')")
905+ connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 40')")
906+ connection.execute("INSERT INTO egg VALUES (1, 4)")
907+ connection.execute("INSERT INTO egg VALUES (2, 3)")
908+ connection.execute("INSERT INTO egg VALUES (3, 7)")
909+ connection.execute("INSERT INTO egg VALUES (4, 5)")
910+ connection.commit()
911+ self.store = DeferredStore(self.database)
912+ return self.store.start()
913+
914+
915+ def tearDown(self):
916+ """
917+ Kill the store (and its underlying thread).
918+ """
919+ def _stop(ign):
920+ return self.store.stop().addCallback(_drop)
921+ def _drop(ign):
922+ self.drop_tables()
923+ return self.store.rollback().addCallback(_stop)
924+
925+
926+ def create_database(self):
927+ raise NotImplementedError()
928+
929+
930+ def create_tables(self):
931+ raise NotImplementedError()
932+
933+
934+ def drop_tables(self):
935+ for table in ["foo", "bar", "egg"]:
936+ try:
937+ self.connection.execute("DROP TABLE %s" % table)
938+ self.connection.commit()
939+ except:
940+ self.connection.rollback()
941+
942+
943+ def test_multiple_start(self):
944+ """
945+ Check that start raises an exception when the store is already started.
946+ """
947+ self.assertRaises(RuntimeError, self.store.start)
948+
949+
950+ def test_get(self):
951+ """
952+ Try to get an object from the store and check its attributes.
953+ """
954+ def cb(result):
955+ self.assertEquals(result.title, u"Title 30")
956+ self.assertEquals(result.id, 10)
957+ return self.store.get(Foo, 10).addCallback(cb)
958+
959+
960+ def test_add(self):
961+ """
962+ Add an object to the store.
963+ """
964+ foo = Foo()
965+ foo.title = u"Great title"
966+ foo.id = 11
967+ def cb_add(ign):
968+ return self.store.get(Foo, 11).addCallback(cb_get)
969+ def cb_get(result):
970+ self.assertEquals(result.title, u"Great title")
971+ self.assertEquals(result.id, 11)
972+ return self.store.add(foo).addCallback(cb_add)
973+
974+
975+ def test_add_default_value(self):
976+ """
977+ When adding an object to the store, the default values from the
978+ database are retrieved and put into the object.
979+ """
980+ foo = Foo()
981+ foo.id = 11
982+ def cb_add(result):
983+ self.assertIdentical(result, None)
984+ return self.store.get(Foo, 11).addCallback(cb_get)
985+ def cb_get(result):
986+ self.assertEquals(result.title, u"Default Title")
987+ self.assertEquals(result.id, 11)
988+ return self.store.add(foo).addCallback(cb_add)
989+
990+
991+ def test_execute(self):
992+ """
993+ Test a direct execute on the store, and the C{get_one} method of
994+ C{DeferredResult}.
995+ """
996+ def cb_execute(result):
997+ return result.get_one().addCallback(cb_result)
998+ def cb_result(result):
999+ self.assertEquals(result, (u"Title 30",))
1000+ return self.store.execute("SELECT title FROM foo WHERE id=10"
1001+ ).addCallback(cb_execute)
1002+
1003+
1004+ def test_execute_all(self):
1005+ """
1006+ Test a direct execute on the store, and the C{get_all} method of
1007+ C{DeferredResult}.
1008+ """
1009+ def cb_execute(result):
1010+ return result.get_all().addCallback(cb_result)
1011+ def cb_result(result):
1012+ self.assertEquals(result, [(u"Title 50",), (u"Title 40",)])
1013+ return self.store.execute("SELECT title FROM bar"
1014+ ).addCallback(cb_execute)
1015+
1016+
1017+ def test_remove(self):
1018+ """
1019+ Try removing an object from the database.
1020+ """
1021+ def cb_get(result):
1022+ return self.store.remove(result).addCallback(cb_remove)
1023+ def cb_remove(ign):
1024+ return self.store.get(Foo, 10).addCallback(cb_get_after_remove)
1025+ def cb_get_after_remove(result):
1026+ self.assertIdentical(result, None)
1027+ return self.store.get(Foo, 10).addCallback(cb_get)
1028+
1029+
1030+ def test_find(self):
1031+ """
1032+ Try to find a list of objects using the store.
1033+ """
1034+ def cb_find(results):
1035+ return results.all().addCallback(cb_all)
1036+ def cb_all(results):
1037+ self.assertEquals(len(results), 2)
1038+ titles = [results[0].title, results[1].title]
1039+ titles.sort()
1040+ self.assertEquals(titles, [u"Title 40", u"Title 50"])
1041+ return self.store.find(Bar).addCallback(cb_find)
1042+
1043+
1044+ def test_find_first(self):
1045+ """
1046+ Try to get the first object matching a query.
1047+ """
1048+ def cb_find(results):
1049+ results.order_by(Bar.title)
1050+ return results.first().addCallback(cb_all)
1051+ def cb_all(result):
1052+ self.assertEquals(result.title, u"Title 40")
1053+ self.assertEquals(result.id, 11)
1054+ return self.store.find(Bar).addCallback(cb_find)
1055+
1056+
1057+ def test_find_last(self):
1058+ """
1059+ Try to get the last object matching a query.
1060+ """
1061+ def cb_find(results):
1062+ results.order_by(Bar.title)
1063+ return results.last().addCallback(cb_all)
1064+ def cb_all(result):
1065+ self.assertEquals(result.title, u"Title 50")
1066+ self.assertEquals(result.id, 10)
1067+ return self.store.find(Bar).addCallback(cb_find)
1068+
1069+
1070+ def test_find_any(self):
1071+ """
1072+ Try to get an object matching a query using the C{any} method.
1073+ """
1074+ def cb_find(results):
1075+ return results.any().addCallback(cb_all)
1076+ def cb_all(result):
1077+ self.assertEquals(result.title, u"Title 50")
1078+ self.assertEquals(result.id, 10)
1079+ return self.store.find(Bar).addCallback(cb_find)
1080+
1081+
1082+ def test_find_max(self):
1083+ """
1084+ Try to get the maximum of a value after a find.
1085+ """
1086+ def cb_find(results):
1087+ return results.max(Egg.value).addCallback(cb_all)
1088+ def cb_all(result):
1089+ self.assertEquals(result, 7)
1090+ return self.store.find(Egg).addCallback(cb_find)
1091+
1092+
1093+ def test_find_min(self):
1094+ """
1095+ Try to get the minimum of a value after a find.
1096+ """
1097+ def cb_find(results):
1098+ return results.min(Egg.value).addCallback(cb_all)
1099+ def cb_all(result):
1100+ self.assertEquals(result, 3)
1101+ return self.store.find(Egg).addCallback(cb_find)
1102+
1103+
1104+ def test_find_avg(self):
1105+ """
1106+ Try to get the average of a value after a find.
1107+ """
1108+ def cb_find(results):
1109+ return results.avg(Egg.value).addCallback(cb_all)
1110+ def cb_all(result):
1111+ self.assertEquals(result, 4.75)
1112+ return self.store.find(Egg).addCallback(cb_find)
1113+
1114+
1115+ def test_find_sum(self):
1116+ """
1117+ Try to get the sum of a value after a find.
1118+ """
1119+ def cb_find(results):
1120+ return results.sum(Egg.value).addCallback(cb_all)
1121+ def cb_all(result):
1122+ self.assertEquals(result, 19)
1123+ return self.store.find(Egg).addCallback(cb_find)
1124+
1125+
1126+ def test_find_count(self):
1127+ """
1128+ Try to get the count of a result after a find.
1129+ """
1130+ def cb_find(results):
1131+ return results.count().addCallback(cb_all)
1132+ def cb_all(result):
1133+ self.assertEquals(result, 2)
1134+ return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find)
1135+
1136+
1137+ def test_find_remove(self):
1138+ """
1139+ Remove the result of a find query.
1140+ """
1141+ def cb_find(results):
1142+ return results.remove().addCallback(cb_remove)
1143+ def cb_remove(ignore):
1144+ return self.store.find(Egg).addCallback(cb_find_after_remove)
1145+ def cb_find_after_remove(results):
1146+ return results.all().addCallback(cb_all)
1147+ def cb_all(results):
1148+ self.assertEquals(len(results), 2)
1149+ return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find)
1150+
1151+
1152+ def test_find_limit(self):
1153+ """
1154+ Put a limit on the number of results of a find.
1155+ """
1156+ def cb_find(results):
1157+ results.config(limit=3)
1158+ return results.all().addCallback(cb_all)
1159+ def cb_all(results):
1160+ self.assertEquals(len(results), 3)
1161+ return self.store.find(Egg).addCallback(cb_find)
1162+
1163+
1164+ def test_find_union(self):
1165+ """
1166+ Call C{union} on two different C{DeferredResultSet}s.
1167+ """
1168+ def cb_find(results):
1169+ result1, result2 = results
1170+ results = result1.union(result2._resultSet)
1171+ return results.all().addCallback(cb_all)
1172+ def cb_all(results):
1173+ self.assertEquals(len(results), 3)
1174+ d1 = self.store.find(Egg, Egg.value >= 5)
1175+ d2 = self.store.find(Egg, Egg.value == 3)
1176+ return gatherResults([d1, d2]).addCallback(cb_find)
1177+
1178+
1179+ def test_find_difference(self):
1180+ """
1181+ Call C{difference} on two different C{DeferredResultSet}s.
1182+ """
1183+ if self.__class__.__name__.startswith("MySQL"):
1184+ return
1185+ def cb_find(results):
1186+ result1, result2 = results
1187+ results = result1.difference(result2._resultSet)
1188+ return results.all().addCallback(cb_all)
1189+ def cb_all(results):
1190+ self.assertEquals(len(results), 1)
1191+ self.assertEquals(results[0].value, 5)
1192+ d1 = self.store.find(Egg, Egg.value >= 5)
1193+ d2 = self.store.find(Egg, Egg.value == 7)
1194+ return gatherResults([d1, d2]).addCallback(cb_find)
1195+
1196+
1197+ def test_find_intersection(self):
1198+ """
1199+ Call C{intersection} on two different C{DeferredResultSet}s.
1200+ """
1201+ if self.__class__.__name__.startswith("MySQL"):
1202+ return
1203+ def cb_find(results):
1204+ result1, result2 = results
1205+ results = result1.intersection(result2._resultSet)
1206+ return results.all().addCallback(cb_all)
1207+ def cb_all(results):
1208+ self.assertEquals(len(results), 1)
1209+ self.assertEquals(results[0].value, 7)
1210+ d1 = self.store.find(Egg, Egg.value >= 5)
1211+ d2 = self.store.find(Egg, Egg.value == 7)
1212+ return gatherResults([d1, d2]).addCallback(cb_find)
1213+
1214+
1215+ def test_find_values(self):
1216+ """
1217+ Filter the fields returned by a find using the values method.
1218+ """
1219+ def cb_find(results):
1220+ return results.values(Bar.title).addCallback(cb_all)
1221+ def cb_all(titles):
1222+ titles.sort()
1223+ self.assertEquals(titles, [u"Title 40", u"Title 50"])
1224+ return self.store.find(Bar).addCallback(cb_find)
1225+
1226+
1227+ def test_find_and_set(self):
1228+ """
1229+ The C{set} method of a C{ResultSet} should update the specified fields
1230+ in a thread.
1231+ """
1232+ def cb_find(results):
1233+ return results.set(title=u"Title").addCallback(cb_set, results)
1234+ def cb_set(ignore, results):
1235+ return results.values(Bar.title).addCallback(cb_all)
1236+ def cb_all(titles):
1237+ titles.sort()
1238+ self.assertEquals(titles, [u"Title", u"Title"])
1239+ return self.store.find(Bar).addCallback(cb_find)
1240+
1241+
1242+ def test_find_offset(self):
1243+ """
1244+ Put an offset on the results of a find.
1245+ """
1246+ def cb_find(results):
1247+ results.config(offset=2)
1248+ return results.all().addCallback(cb_all)
1249+ def cb_all(results):
1250+ self.assertEquals(len(results), 2)
1251+ return self.store.find(Egg).addCallback(cb_find)
1252+
1253+
1254+ def test_find_offset_limit(self):
1255+ """
1256+ Put an offset and a limit on the results of a find.
1257+ """
1258+ def cb_find(results):
1259+ results.config(offset=1, limit=2)
1260+ return results.all().addCallback(cb_all)
1261+ def cb_all(results):
1262+ self.assertEquals(len(results), 2)
1263+ return self.store.find(Egg).addCallback(cb_find)
1264+
1265+
1266+ @deferredGenerator
1267+ def test_find_defgen(self):
1268+ """
1269+ Do a find, add an object, then do another find, to ensure that the
1270+ connection remains in the dedicated thread.
1271+ """
1272+ d = self.store.find(Bar)
1273+ wfd = waitForDeferred(d)
1274+ yield wfd
1275+ results = wfd.getResult()
1276+ d = results.all()
1277+ wfd = waitForDeferred(d)
1278+ yield wfd
1279+ wfd.getResult()
1280+ foo = Foo()
1281+ foo.title = u"Great title"
1282+ foo.id = 11
1283+ d = self.store.add(foo)
1284+ wfd = waitForDeferred(d)
1285+ yield wfd
1286+ wfd.getResult()
1287+ d = self.store.find(Foo)
1288+ wfd = waitForDeferred(d)
1289+ yield wfd
1290+ results = wfd.getResult()
1291+
1292+
1293+ def test_find_order_by(self):
1294+ """
1295+ Try to find a list of objects using the store, then order the result
1296+ set.
1297+ """
1298+ def cb_find(results):
1299+ results.order_by(Bar.title)
1300+ return results.all().addCallback(cb_all)
1301+ def cb_all(results):
1302+ self.assertEquals(len(results), 2)
1303+ titles = [results[0].title, results[1].title]
1304+ self.assertEquals(titles, [u"Title 40", u"Title 50"])
1305+ return self.store.find(Bar).addCallback(cb_find)
1306+
1307+
1308+ def test_find_and_rollback(self):
1309+ """
1310+ Accessing an object outside of a transaction fails because the object
1311+ hasn't been resolved yet.
1312+ """
1313+ def cb_find(results):
1314+ results.order_by(Bar.title)
1315+ return results.all().addCallback(cb_all)
1316+ def cb_all(results):
1317+ return self.store.rollback().addCallback(cbRollback, results)
1318+ def cbRollback(ign, results):
1319+ self.assertEquals(len(results), 2)
1320+ self.assertRaises(RuntimeError, getattr, results[0], "title")
1321+ return self.store.find(Bar).addCallback(cb_find)
1322+
1323+
1324+ def test_find_is_empty(self):
1325+ """
1326+ C{DeferredResultSet.is_empty} returns a Deferred that fires with True or
1327+ False depending on whether the matched result set is empty.
1328+ """
1329+ def cb_find(results):
1330+ return results.is_empty().addCallback(self.assertEquals, False)
1331+
1332+ return self.store.find(Bar).addCallback(cb_find)
1333+
1334+
1335+ def test_find_group_by(self):
1336+ """
1337+ C{DeferredResultSet.group_by} is a simple wrapper around the group_by
1338+ method of the underlying result set.
1339+ """
1340+ def cb_find(results):
1341+ results.group_by(Bar.foo_id)
1342+ return results.all().addCallback(check)
1343+
1344+ def check(result):
1345+ self.assertEquals(result, [(2, 10)])
1346+
1347+ return self.store.find((Count(Bar.id), Bar.foo_id)
1348+ ).addCallback(cb_find)
1349+
1350+
1351+ def test_find_having(self):
1352+ """
1353+ C{DeferredResultSet.having} is a simple wrapper around the having method
1354+ of the underlying result set.
1355+ """
1356+ connection = self.database.connect()
1357+ connection.execute("INSERT INTO egg VALUES (5, 7)")
1358+ connection.commit()
1359+
1360+ def cb_find(results):
1361+ results.group_by(Egg.value)
1362+ results.having(Egg.value >= 5)
1363+ results.order_by(Egg.value)
1364+ return results.all().addCallback(check)
1365+
1366+ def check(result):
1367+ self.assertEquals(result, [(1, 5), (2, 7)])
1368+
1369+ return self.store.find((Count(Egg.id), Egg.value)
1370+ ).addCallback(cb_find)
1371+
1372+
1373+ def test_reference(self):
1374+ """
1375+ Trying to get a reference of an object using C{DeferredReference}.
1376+ """
1377+ def cb_getBar(result):
1378+ return result.foo.addCallback(cb_getFoo)
1379+ def cb_getFoo(fooResult):
1380+ return self.store.get(Foo, 10).addCallback(cb_getFooBar, fooResult)
1381+ def cb_getFooBar(result, fooResult):
1382+ self.assertIdentical(fooResult, result)
1383+ # The result should be valid too
1384+ self.assertEquals(fooResult.title, u"Title 30")
1385+ return self.store.get(Bar, 10).addCallback(cb_getBar)
1386+
1387+
1388+ def test_reference_setting(self):
1389+ """
1390+ Try to set a reference of an object.
1391+ """
1392+ connection = self.database.connect()
1393+ connection.execute("INSERT INTO foo VALUES (20, 'Title 20')")
1394+ connection.commit()
1395+ def cb_getBar(result):
1396+ return self.store.get(Foo, 20).addCallback(cb_getFooBar, result)
1397+ def cb_getFooBar(result, barResult):
1398+ self.assertRaises(RuntimeError, setattr, barResult, "foo", result)
1399+ return self.store.get(Bar, 10).addCallback(cb_getBar)
1400+
1401+
1402+ def test_reference_set_unordered(self):
1403+ """
1404+ Get a reference set and call various wrapped methods on it.
1405+ """
1406+ # find test
1407+ def cb_find(results):
1408+ return results.all().addCallback(cb_all)
1409+
1410+ def cb_all(results):
1411+ self.assertEquals(len(results), 2)
1412+ titles = [results[0].title, results[1].title]
1413+ titles.sort()
1414+ self.assertEquals(titles, [u"Title 40", u"Title 50"])
1415+
1416+ def cb_any(result):
1417+ self.assertTrue(result)
1418+
1419+ def cb_values(titles):
1420+ titles.sort()
1421+ self.assertEquals(titles, [u"Title 40", u"Title 50"])
1422+
1423+ def do_tests(results):
1424+ results = results.bars
1425+ dfrs = [
1426+ results.find().addCallback(cb_find),
1427+ results.any().addCallback(cb_any),
1428+ results.values(Bar.title).addCallback(cb_values),
1429+ ]
1430+ return DeferredList(dfrs)
1431+
1432+ return self.store.get(FooRefSet, 10).addCallback(do_tests)
1433+
1434+
1435+ def test_reference_set_ordered(self):
1436+ """
1437+ A DeferredReferenceSet has an order_by method which returns a Deferred
1438+ firing when the reference set is ordered.
1439+ """
1440+ def do_tests(result):
1441+ dfrs = [
1442+ result.first().addCallback(lambda t:
1443+ self.assertEquals(t.title, u"Title 40")),
1444+ result.last().addCallback(lambda t:
1445+ self.assertEquals(t.title, u"Title 50")),
1446+ ]
1447+ return DeferredList(dfrs)
1448+
1449+ def order(results):
1450+ dfr = results.bars.order_by("title")
1451+ return dfr.addCallback(do_tests)
1452+
1453+ return self.store.get(FooRefSet, 10).addCallback(order)
1454+
1455+
1456+ def test_reference_set_add_remove(self):
1457+ """
1458+ A DeferredReferenceSet has an add method which returns a Deferred firing
1459+ once the object has been added. Add an object to the reference set
1460+ asynchronously, then remove it.
1461+ """
1462+ def add_one(result):
1463+ bar = Bar()
1464+ bar.title = u"Yeah"
1465+ return result.bars.add(bar).addCallback(remove_one, result, bar)
1466+
1467+ def remove_one(add_result, result, bar):
1468+ return result.bars.remove(bar).addCallback(get_all, result)
1469+
1470+ def get_all(ignore, result):
1471+ return result.bars.all().addCallback(check)
1472+
1473+ def check(result):
1474+ self.assertEquals(len(result), 2)
1475+
1476+ return self.store.get(FooRefSet, 10).addCallback(add_one)
1477+
1478+
1479+ def test_reference_set_clear(self):
1480+ """
1481+ A DeferredReferenceSet has a clear method which removes all elements
1482+ from the reference set and fires the returned Deferred when done.
1483+ """
1484+ def first_cb(result):
1485+ refs = result.bars
1486+ return check_count(refs, 2).addCallback(clear_cb, refs)
1487+
1488+ def check_count(ref_set, num):
1489+ return ref_set.count().addCallback(self.assertEquals, num)
1490+
1491+ def clear_cb(result, refs):
1492+ return refs.clear().addCallback(lambda x: check_count(refs, 0))
1493+
1494+ return self.store.get(FooRefSet, 10).addCallback(first_cb)
1495+
1496+
1497+ def test_reference_set_one(self):
1498+ """
1499+ Call C{one} on a C{DeferredBoundReference}.
1500+ """
1501+ connection = self.database.connect()
1502+ connection.execute("INSERT INTO foo VALUES (11, 'Title 40')")
1503+ connection.execute("INSERT INTO bar VALUES (20, 11, 'Title 50')")
1504+ connection.commit()
1505+ def cb_get(result):
1506+ return result.bars.one().addCallback(cb_one)
1507+ def cb_one(result):
1508+ return self.store.get(Bar, 20).addCallback(check, result)
1509+ def check(result, expected):
1510+ self.assertIdentical(result, expected)
1511+ return self.store.get(FooRefSet, 11).addCallback(cb_get)
1512+
1513+
1514+ def test_reference_set_first(self):
1515+ """
1516+ Call C{first} on an ordered C{DeferredBoundReference}.
1517+ """
1518+ def cb_get(result):
1519+ return result.bars.first().addCallback(cb_one)
1520+ def cb_one(result):
1521+ return self.store.get(Bar, 10).addCallback(check, result)
1522+ def check(result, expected):
1523+ self.assertIdentical(result, expected)
1524+ return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get)
1525+
1526+
1527+ def test_reference_set_last(self):
1528+ """
1529+ Call C{last} on an ordered C{DeferredBoundReference}.
1530+ """
1531+ def cb_get(result):
1532+ return result.bars.last().addCallback(cb_one)
1533+ def cb_one(result):
1534+ return self.store.get(Bar, 11).addCallback(check, result)
1535+ def check(result, expected):
1536+ self.assertIdentical(result, expected)
1537+ return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get)
1538+
1539+
1540+ def test_commit(self):
1541+ """
1542+ Make some changes and commit them.
1543+ """
1544+ def cb_get(result):
1545+ return self.store.remove(result).addCallback(cb_remove)
1546+ def cb_remove(ign):
1547+ return self.store.commit().addCallback(cb_commit)
1548+ def cb_commit(ign):
1549+ # To be sure the data is no longer in the db, the best way is to
1550+ # connect to the db directly
1551+ connection = self.database.connect()
1552+ result = connection.execute("SELECT * FROM foo")
1553+ self.assertEquals(list(result), [])
1554+ return self.store.get(Foo, 10).addCallback(cb_get)
1555+
1556+
1557+ def test_rollback(self):
1558+ """
1559+ Make some changes and roll them back.
1560+ """
1561+ def cb_get(result):
1562+ return self.store.remove(result).addCallback(cb_remove)
1563+ def cb_remove(ign):
1564+ return self.store.rollback().addCallback(cbRollback)
1565+ def cbRollback(ign):
1566+ connection = self.database.connect()
1567+ result = connection.execute("SELECT * FROM foo")
1568+ self.assertEquals(list(result), [(10, u"Title 30")])
1569+ return self.store.get(Foo, 10).addCallback(cb_get)
1570+
1571+
1572+ def test_deferred_reference_multithread(self):
1573+ """
1574+ If a store is restarted, the objects in the cache should still be
1575+ usable; in particular, an object shouldn't store a reference to the
1576+ store thread, as it can change.
1577+ """
1578+ def test(ignore):
1579+ # get a bar and retrieve the deferred reference
1580+ def get_foo(bar):
1581+ return bar.foo
1582+
1583+ return self.store.find(Bar, Bar.id == 10).addCallback(lambda x:
1584+ x.one()).addCallback(get_foo)
1585+
1586+ def _restart_store(res):
1587+ def stopped(res):
1588+ self.store = DeferredStore(self.database)
1589+ return self.store.start()
1590+ return self.store.stop().addCallback(stopped)
1591+
1592+ def check(foo):
1593+ self.assertEquals(foo.id, 10)
1594+ self.assertEquals(foo.title, u"Title 30")
1595+
1596+ return test(None
1597+ ).addCallback(_restart_store
1598+ ).addCallback(test
1599+ ).addCallback(check)
1600+
1601+
1602+ def test_of(self):
1603+ """
1604+ The DeferredStore associated with an object is returned by the static
1605+ C{of} method.
1606+ """
1607+ def cb_get(result):
1608+ store = DeferredStore.of(result)
1609+ self.assertIdentical(self.store, store)
1610+
1611+ return self.store.get(Foo, 10).addCallback(cb_get)
1612+
1613+
1614+ def test_thread_check(self):
1615+ """
1616+ A L{ThreadSafetyError} is raised when attempting to do an unsafe
1617+ operation, like accessing a C{Reference} attribute via a
1618+ C{DeferredStore}.
1619+ """
1620+ class WrongBar(Bar):
1621+ foo_sync = Reference(Bar.foo_id, Foo.id)
1622+ d = self.store.find(WrongBar).addCallback(lambda result: result.all())
1623+
1624+ def check(results):
1625+ self.assertRaises(
1626+ ThreadSafetyError, getattr, results[0], "foo_sync")
1627+ return d.addCallback(check)
1628+
1629+
1630+
1631+class StoreThreadTestCase(TestCase):
1632+ """
1633+ Tests for L{StoreThread}.
1634+ """
1635+
1636+ def setUp(self):
1637+ """
1638+ Create an instance of C{StoreThread} and start it.
1639+ """
1640+ self.thread = StoreThread()
1641+ self.thread.start()
1642+
1643+
1644+ def tearDown(self):
1645+ """
1646+ Kill the running thread.
1647+ """
1648+ self.thread.stop()
1649+
1650+
1651+ def test_defer_after_stop(self):
1652+ """
1653+ Deferring calls after store is stopped raises C{AlreadyStopped}.
1654+ """
1655+ def cb_stop(r):
1656+ return self.assertFailure(self.thread.defer_to_thread(lambda: None),
1657+ AlreadyStopped)
1658+ return self.thread.stop().addCallback(cb_stop)
1659+
1660+
1661+ def test_callback(self):
1662+ """
1663+ Fire a simple function in a thread and check its result.
1664+ """
1665+ def testfunc():
1666+ return 1
1667+ return self.thread.defer_to_thread(testfunc
1668+ ).addCallback(self.assertEquals, 1)
1669+
1670+
1671+ def test_errback(self):
1672+ """
1673+ Raising an exception in a thread returns a failure.
1674+ """
1675+ def testfunc():
1676+ raise RuntimeError("Error!")
1677+ return self.assertFailure(self.thread.defer_to_thread(testfunc),
1678+ RuntimeError)
1679+
1680+
1681+
1682+class StorePoolTest(object):
1683+ """
1684+ Tests for L{StorePool}.
1685+ """
1686+
1687+ def setUp(self):
1688+ """
1689+ Build a database with data, and create a pool.
1690+ """
1691+ self.create_database()
1692+ connection = self.connection = self.database.connect()
1693+ self.drop_tables()
1694+ self.create_tables()
1695+ connection.execute("INSERT INTO foo VALUES (10, 'Title 30')")
1696+ connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 40')")
1697+ connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 50')")
1698+ connection.commit()
1699+ self.pool = StorePool(self.database, 2, 5)
1700+ return self.pool.start()
1701+
1702+
1703+ def tearDown(self):
1704+ """
1705+ Stop the pool.
1706+ """
1707+ def _drop(ign):
1708+ self.drop_tables()
1709+ return self.pool.stop().addCallback(_drop)
1710+
1711+
1712+ def drop_tables(self):
1713+ for table in ["foo", "bar"]:
1714+ try:
1715+ self.connection.execute("DROP TABLE %s" % table)
1716+ self.connection.commit()
1717+ except:
1718+ self.connection.rollback()
1719+
1720+
1721+ def test_already_started(self):
1722+ """
1723+ Check that the pool can't be restarted multiple times.
1724+ """
1725+ self.assertRaises(RuntimeError, self.pool.start)
1726+
1727+
1728+ def test_get(self):
1729+ """
1730+ C{get} returns different stores when they are available.
1731+ """
1732+ def cb_get1(store1):
1733+ return self.pool.get().addCallback(cb_get2, store1)
1734+ def cb_get2(store2, store1):
1735+ self.assertNotIdentical(store1, store2)
1736+ self.assertTrue(store1.started)
1737+ self.assertTrue(store2.started)
1738+ return self.pool.get().addCallback(cb_get1)
1739+
1740+
1741+ def test_get_not_started(self):
1742+ """
1743+ If no store is available, the pool should create a store.
1744+ """
1745+ def cb(ign):
1746+ self.assertEquals(self.pool._stores_created, 0)
1747+ return self.pool.adjust_size(0, 5).addCallback(cb_add)
1748+ def cb_add(ign):
1749+ return self.pool.get().addCallback(cb_get)
1750+ def cb_get(store):
1751+ self.assertTrue(store.started)
1752+ return self.pool.adjust_size(0, 0).addCallback(cb)
1753+
1754+
1755+ def test_waiting_for_store(self):
1756+ """
1757+ Test waiting for a store to become available.
1758+ """
1759+ def cb(ign):
1760+ return self.pool.get().addCallback(cb_get1)
1761+ def cb_get1(store1):
1762+ self.assertTrue(store1.started)
1763+ # Now we have a store, no store should be returned by the pool
1764+ # until we put it back
1765+ d1 = self.pool.get().addCallback(cb_get2, store1)
1766+ d2 = self.pool.put(store1)
1767+ return gatherResults([d1, d2])
1768+ def cb_get2(store2, store1):
1769+ self.assertIdentical(store1, store2)
1770+ return self.pool.adjust_size(1, 1).addCallback(cb)
1771+
1772+
1773+ def test_adjust_size_minmax(self):
1774+ """
1775+ Test sanity check on min/max - i.e. min <= max.
1776+ """
1777+ return self.assertFailure(self.pool.adjust_size(2, 1), ValueError)
1778+
1779+
1780+ def test_adjust_size_nonnegative(self):
1781+ """
1782+ Test sanity check for nonnegative min.
1783+ """
1784+ return self.assertFailure(self.pool.adjust_size(-1), ValueError)
1785+
1786+
1787+ @deferredGenerator
1788+ def test_concurrent_data(self):
1789+ """
1790+ Test that different stores have different states: if the first store
1791+ hasn't yet committed, the second one shouldn't get the new data.
1792+ """
1793+ foo = Foo()
1794+ foo.title = u"Great title"
1795+ foo.id = 11
1796+ d = self.pool.get()
1797+ wfd = waitForDeferred(d)
1798+ yield wfd
1799+ store1 = wfd.getResult()
1800+
1801+ d = self.pool.get()
1802+ wfd = waitForDeferred(d)
1803+ yield wfd
1804+ store2 = wfd.getResult()
1805+
1806+ d = store1.add(foo)
1807+ wfd = waitForDeferred(d)
1808+ yield wfd
1809+ wfd.getResult()
1810+
1811+ d = store1.get(Foo, 11)
1812+ wfd = waitForDeferred(d)
1813+ yield wfd
1814+ foo2 = wfd.getResult()
1815+
1816+ # The object is already in the store cache
1817+ self.assertIdentical(foo2, foo)
1818+
1819+ d = store2.get(Foo, 11)
1820+ wfd = waitForDeferred(d)
1821+ yield wfd
1822+ foo3 = wfd.getResult()
1823+
1824+ # The object isn't in the db yet
1825+ self.assertIdentical(foo3, None)
1826+
1827+ # Roll back, because even a SELECT opens a transaction
1828+ d = store2.rollback()
1829+ wfd = waitForDeferred(d)
1830+ yield wfd
1831+ wfd.getResult()
1832+
1833+ # Let's commit
1834+ d = store1.commit()
1835+ wfd = waitForDeferred(d)
1836+ yield wfd
1837+ wfd.getResult()
1838+
1839+ d = store2.get(Foo, 11)
1840+ wfd = waitForDeferred(d)
1841+ yield wfd
1842+ foo4 = wfd.getResult()
1843+
1844+ # The objects must be different
1845+ self.assertNotIdentical(foo4, foo)
1846+ # But the content must be the same
1847+ self.assertEquals(foo4.title, u"Great title")
1848+
1849+
1850+ def test_no_overflow(self):
1851+ """
1852+ Test that the pool does not allocate more connections than store_max.
1853+ """
1854+ ds = []
1855+ stores = set()
1856+
1857+ def cb_get(store):
1858+ stores.add(store)
1859+ return self.pool.put(store)
1860+
1861+ for i in range(10):
1862+ ds.append(self.pool.get().addCallback(cb_get))
1863+
1864+ def checkInstances(result):
1865+ self.assertEquals(len(stores), 5)
1866+
1867+ return gatherResults(ds).addCallback(checkInstances)
1868+
1869+
1870+ def test_start_failure(self):
1871+ """
1872+ If a store fails to start, the number of allocated connections doesn't
1873+ grow, so we're later able to start more stores.
1874+ """
1875+ ds = []
1876+ stores = set()
1877+
1878+ class DontStartStore(DeferredStore):
1879+ def start(self):
1880+ return fail(RuntimeError("oops"))
1881+
1882+ calls = []
1883+
1884+ def vicious_store_factory(database):
1885+ if not calls:
1886+ store = DontStartStore(database)
1887+ else:
1888+ store = DeferredStore(database)
1889+ calls.append(None)
1890+ return store
1891+
1892+ self.pool.store_factory = vicious_store_factory
1893+
1894+ def cb_get(store):
1895+ stores.add(store)
1896+ return self.pool.put(store)
1897+
1898+ errors = []
1899+
1900+ def save_errors(failure):
1901+ errors.append(failure)
1902+
1903+ for i in range(6):
1904+ ds.append(
1905+ self.pool.get().addCallback(cb_get).addErrback(save_errors))
1906+
1907+ def checkInstances(result):
1908+ self.assertEquals(len(stores), 5)
1909+ self.assertEquals(len(errors), 1)
1910+ errors[0].trap(RuntimeError)
1911+
1912+ return gatherResults(ds).addCallback(checkInstances)
1913+
1914+
1915+ def test_arguments(self):
1916+ """
1917+ Arguments are passed along to the transacted function when a store
1918+ is available.
1919+ """
1920+ def tx(store, a, b=None):
1921+ self.assertIsInstance(store, DeferredStore)
1922+ self.assertEquals(1, a)
1923+ self.assertEquals(2, b)
1924+ return succeed("ok")
1925+ return self.pool.transact(tx, 1, b=2)
1926+
1927+
1928+ def test_commit(self):
1929+ """
1930+ Changes made inside a successful transaction are committed.
1931+ """
1932+ @deferredGenerator
1933+ def check(result):
1934+ d = self.pool.get()
1935+ wfd = waitForDeferred(d)
1936+ yield wfd
1937+ store = wfd.getResult()
1938+ d = store.execute("SELECT * FROM foo ORDER BY id", [])
1939+ wfd = waitForDeferred(d)
1940+ yield wfd
1941+ dr = wfd.getResult()
1942+ d = dr.get_all()
1943+ wfd = waitForDeferred(d)
1944+ yield wfd
1945+ results = wfd.getResult()
1946+ self.assertEquals(2, len(results))
1947+ self.assertEquals(1, results[0][0])
1948+ self.assertEquals("test", results[0][1])
1949+
1950+ def tx(store):
1951+ d = store.execute("INSERT INTO foo(id, title) "
1952+ "VALUES (1, 'test')", noresult=True)
1953+ return d
1954+
1955+ return self.pool.transact(tx).addCallback(check)
1956+
1957+
1958+ def test_rollback(self):
1959+ """
1960+ Changes made inside a failed transaction are rolled back.
1961+ """
1962+ @deferredGenerator
1963+ def check(reason):
1964+ d = self.pool.get()
1965+ wfd = waitForDeferred(d)
1966+ yield wfd
1967+ store = wfd.getResult()
1968+ d = store.execute("SELECT * FROM foo", [])
1969+ wfd = waitForDeferred(d)
1970+ yield wfd
1971+ dr = wfd.getResult()
1972+ d = dr.get_all()
1973+ wfd = waitForDeferred(d)
1974+ yield wfd
1975+ results = wfd.getResult()
1976+ self.assertEquals(1, len(results))
1977+
1978+ @deferredGenerator
1979+ def tx(store):
1980+ d = store.execute("INSERT INTO foo(id, title) "
1981+ "VALUES (1, 'test')", [])
1982+ wfd = waitForDeferred(d)
1983+ yield wfd
1984+ wfd.getResult()
1985+ d = store.execute("INSERT INTO foo(id, title) "
1986+ "VALUES (1, 'test')", [])
1987+ wfd = waitForDeferred(d)
1988+ yield wfd
1989+ wfd.getResult()
1990+
1991+ return self.pool.transact(tx).addBoth(check)
1992+
1993+
1994+ def test_return_value(self):
1995+ """
1996+ The final return value matches the return value of a successful call
1997+ to the transacted function.
1998+ """
1999+ def check(result):
2000+ self.assertEquals("completed", result)
2001+
2002+ def cb(result):
2003+ return "completed"
2004+
2005+ def tx(store):
2006+ return store.execute("INSERT INTO foo(id, title) "
2007+ "VALUES (1, 'test')", []
2008+ ).addCallback(cb)
2009+
2010+ return self.pool.transact(tx).addCallback(check)
2011+
2012+
2013+ def test_poolsize_after_success(self):
2014+ """
2015+ After a successful transaction, the pool size is the same as before.
2016+ """
2017+ size = len(self.pool._stores)
2018+
2019+ def check(result):
2020+ self.assertEquals(size, len(self.pool._stores))
2021+
2022+ def tx(store):
2023+ d = store.execute("SELECT * from foo", [])
2024+ return d.addCallback(lambda result: result.get_all())
2025+
2026+ return self.pool.transact(tx).addCallback(check)
2027+
2028+
2029+ def test_poolsize_after_failure(self):
2030+ """
2031+ After a failed transaction, the pool size is restored to its initial value.
2032+ """
2033+ size = len(self.pool._stores)
2034+
2035+ def check(reason):
2036+ self.assertEquals(size, len(self.pool._stores))
2037+
2038+ def tx(store):
2039+ return store.execute("SELECT * from not_a_table", [])
2040+
2041+ d = self.assertFailure(self.pool.transact(tx), OperationalError)
2042+ return d.addCallback(check)
2043+
2044+
2045+ def test_failure_propagation(self):
2046+ """
2047+ A custom exception is propagated by a C{transact} call.
2048+ """
2049+ class MyException(Exception):
2050+ pass
2051+
2052+ def tx(store):
2053+ raise MyException("Bad things happened")
2054+
2055+ return self.assertFailure(self.pool.transact(tx), MyException)
2056+
2057+
2058+ def test_non_deferred_function(self):
2059+ """
2060+ C{transact} can handle functions that don't return a C{Deferred}.
2061+ """
2062+ def tx(store):
2063+ return "foo"
2064+ return self.pool.transact(tx).addCallback(self.assertEquals, "foo")
2065+
2066+
2067+ def test_rollback_failure(self):
2068+ """
2069+ If C{rollback} fails, the store is put back into the pool.
2070+ """
2071+
2072+ def cb_get(store):
2073+ store.rollback = lambda: fail(RuntimeError("oops"))
2074+ d = self.assertFailure(self.pool.put(store), RuntimeError)
2075+ return d.addCallback(get_all)
2076+
2077+ def get_all(ignore):
2078+ dl = []
2079+ for i in range(5):
2080+ dl.append(self.pool.get())
2081+ return gatherResults(dl).addCallback(check)
2082+
2083+ def check(stores):
2084+ # Reaching this point shows that the test succeeded: we didn't hang
2085+ # waiting for a store.
2086+ self.assertEquals(len(stores), 5)
2087+
2088+ return self.pool.get().addCallback(cb_get)
2089
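The `transact` tests above pin down a small contract: the transacted function's return value is passed through, exceptions propagate, uncommitted changes are rolled back on failure, and the store goes back to the pool even when `rollback` itself fails. The following is a minimal synchronous sketch of that contract. `FakeStore` and `SimpleStorePool` are hypothetical and are not the `StorePool` from `storm/twisted/store.py`; in particular, the real pool hands out stores through Deferreds (and waits when none are free), whereas `get()` here simply raises `IndexError` on an empty pool. An asynchronous version would typically wrap the call in `twisted.internet.defer.maybeDeferred`, which is one way to accept functions that do not return a Deferred.

```python
class FakeStore:
    """Hypothetical in-memory stand-in for a Storm store.

    Rows passed to execute() stay pending until commit(); rollback()
    discards them. All stores in a pool share one "committed" table list.
    """

    def __init__(self, table):
        self.table = table
        self.pending = []

    def execute(self, row):
        self.pending.append(row)

    def commit(self):
        self.table.extend(self.pending)
        self.pending = []

    def rollback(self):
        self.pending = []


class SimpleStorePool:
    """Hypothetical synchronous pool mirroring the behaviour tested above."""

    def __init__(self, size):
        self.table = []
        self._stores = [FakeStore(self.table) for _ in range(size)]

    def get(self):
        # Raises IndexError when empty (the Deferred-based pool would wait).
        return self._stores.pop()

    def put(self, store):
        # Roll back any uncommitted work, but return the store to the pool
        # even if rollback() raises (cf. test_rollback_failure).
        try:
            store.rollback()
        finally:
            self._stores.append(store)

    def transact(self, func):
        store = self.get()
        try:
            result = func(store)
            store.commit()
            return result
        finally:
            # Success or failure, the store goes back to the pool; after a
            # successful commit the rollback in put() has nothing to undo.
            self.put(store)


if __name__ == "__main__":
    pool = SimpleStorePool(size=5)

    def failing_tx(store):
        store.execute((1, "test"))
        raise ValueError("duplicate key")

    try:
        pool.transact(failing_tx)
    except ValueError:
        pass
    print(pool.table, len(pool._stores))  # [] 5
```

Putting the rollback inside `put()` means the "store always returns to the pool" guarantee lives in one place, which is what `test_rollback_failure` exercises by calling `put()` directly.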
2090=== added file 'tests/twisted/mysql.py'
2091--- tests/twisted/mysql.py 1970-01-01 00:00:00 +0000
2092+++ tests/twisted/mysql.py 2009-12-27 11:31:15 +0000
2093@@ -0,0 +1,62 @@
2094+import os
2095+
2096+from storm.database import create_database
2097+
2098+from tests.twisted.base import DeferredStoreTest, StorePoolTest
2099+
2100+from twisted.trial.unittest import TestCase
2101+
2102+
2103+class MySQLDeferredStoreTest(TestCase, DeferredStoreTest):
2104+
2105+ def setUp(self):
2106+ return DeferredStoreTest.setUp(self)
2107+
2108+ def tearDown(self):
2109+ return DeferredStoreTest.tearDown(self)
2110+
2111+ def is_supported(self):
2112+ return bool(os.environ.get("STORM_MYSQL_URI"))
2113+
2114+ def create_database(self):
2115+ self.database = create_database(os.environ["STORM_MYSQL_URI"])
2116+
2117+ def create_tables(self):
2118+ connection = self.connection
2119+ connection.execute("CREATE TABLE foo "
2120+ "(id INT PRIMARY KEY AUTO_INCREMENT,"
2121+ " title VARCHAR(50) DEFAULT 'Default Title')"
2122+ " ENGINE=InnoDB")
2123+ connection.execute("CREATE TABLE bar "
2124+ "(id INT PRIMARY KEY AUTO_INCREMENT,"
2125+ " foo_id INTEGER, title VARCHAR(50))"
2126+ " ENGINE=InnoDB")
2127+ connection.execute("CREATE TABLE egg "
2128+ "(id INT PRIMARY KEY AUTO_INCREMENT, value INTEGER)"
2129+ " ENGINE=InnoDB")
2130+
2131+
2132+class MySQLStorePoolTest(TestCase, StorePoolTest):
2133+
2134+ def setUp(self):
2135+ return StorePoolTest.setUp(self)
2136+
2137+ def tearDown(self):
2138+ return StorePoolTest.tearDown(self)
2139+
2140+ def is_supported(self):
2141+ return bool(os.environ.get("STORM_MYSQL_URI"))
2142+
2143+ def create_database(self):
2144+ self.database = create_database(os.environ["STORM_MYSQL_URI"])
2145+
2146+ def create_tables(self):
2147+ connection = self.connection
2148+ connection.execute("CREATE TABLE foo "
2149+ "(id INT PRIMARY KEY AUTO_INCREMENT,"
2150+ " title VARCHAR(50) DEFAULT 'Default Title')"
2151+ " ENGINE=InnoDB")
2152+ connection.execute("CREATE TABLE bar "
2153+ "(id INT PRIMARY KEY AUTO_INCREMENT,"
2154+ " foo_id INTEGER, title VARCHAR(50))"
2155+ " ENGINE=InnoDB")
2156
2157=== added file 'tests/twisted/postgres.py'
2158--- tests/twisted/postgres.py 1970-01-01 00:00:00 +0000
2159+++ tests/twisted/postgres.py 2009-12-27 11:31:15 +0000
2160@@ -0,0 +1,57 @@
2161+import os
2162+
2163+from storm.database import create_database
2164+
2165+from tests.twisted.base import DeferredStoreTest, StorePoolTest
2166+
2167+from twisted.trial.unittest import TestCase
2168+
2169+
2170+class PostgresDeferredStoreTest(TestCase, DeferredStoreTest):
2171+
2172+ def setUp(self):
2173+ return DeferredStoreTest.setUp(self)
2174+
2175+ def tearDown(self):
2176+ return DeferredStoreTest.tearDown(self)
2177+
2178+ def is_supported(self):
2179+ return bool(os.environ.get("STORM_POSTGRES_URI"))
2180+
2181+ def create_database(self):
2182+ self.database = create_database(os.environ["STORM_POSTGRES_URI"])
2183+
2184+ def create_tables(self):
2185+ connection = self.connection
2186+ connection.execute("CREATE TABLE foo "
2187+ "(id SERIAL PRIMARY KEY,"
2188+ " title VARCHAR DEFAULT 'Default Title')")
2189+ connection.execute("CREATE TABLE bar "
2190+ "(id SERIAL PRIMARY KEY,"
2191+ " foo_id INTEGER, title VARCHAR)")
2192+ connection.execute("CREATE TABLE egg "
2193+ "(id SERIAL PRIMARY KEY, value INTEGER)")
2194+
2195+
2196+class PostgresStorePoolTest(TestCase, StorePoolTest):
2197+
2198+ def setUp(self):
2199+ return StorePoolTest.setUp(self)
2200+
2201+ def tearDown(self):
2202+ return StorePoolTest.tearDown(self)
2203+
2204+ def is_supported(self):
2205+ return bool(os.environ.get("STORM_POSTGRES_URI"))
2206+
2207+ def create_database(self):
2208+ self.database = create_database(os.environ["STORM_POSTGRES_URI"])
2209+
2210+ def create_tables(self):
2211+ connection = self.connection
2212+ connection.execute("CREATE TABLE foo "
2213+ "(id SERIAL PRIMARY KEY,"
2214+ " title VARCHAR DEFAULT 'Default Title')")
2215+ connection.execute("CREATE TABLE bar "
2216+ "(id SERIAL PRIMARY KEY,"
2217+ " foo_id INTEGER, title VARCHAR)")
2218
2219=== added file 'tests/twisted/sqlite.py'
2220--- tests/twisted/sqlite.py 1970-01-01 00:00:00 +0000
2221+++ tests/twisted/sqlite.py 2009-12-27 11:31:15 +0000
2222@@ -0,0 +1,65 @@
2223+from storm.databases.sqlite import SQLite
2224+from storm.uri import URI
2225+
2226+from tests.twisted.base import DeferredStoreTest, StorePoolTest
2227+from tests.helper import TestHelper, MakePath
2228+
2229+from twisted.trial.unittest import TestCase
2230+
2231+
2232+class SQLiteDeferredStoreTest(TestCase, TestHelper, DeferredStoreTest):
2233+
2234+ helpers = [MakePath]
2235+
2236+ def setUp(self):
2237+ TestHelper.setUp(self)
2238+ return DeferredStoreTest.setUp(self)
2239+
2240+ def tearDown(self):
2241+ def cb(passthrough):
2242+ TestHelper.tearDown(self)
2243+ return passthrough
2244+ return DeferredStoreTest.tearDown(self).addBoth(cb)
2245+
2246+ def create_database(self):
2247+ self.database = SQLite(URI("sqlite:%s?synchronous=OFF" %
2248+ self.make_path()))
2249+
2250+ def create_tables(self):
2251+ connection = self.connection
2252+ connection.execute("CREATE TABLE foo "
2253+ "(id INTEGER PRIMARY KEY,"
2254+ " title VARCHAR DEFAULT 'Default Title')")
2255+ connection.execute("CREATE TABLE bar "
2256+ "(id INTEGER PRIMARY KEY,"
2257+ " foo_id INTEGER, title VARCHAR)")
2258+ connection.execute("CREATE TABLE egg "
2259+ "(id INTEGER PRIMARY KEY, value INTEGER)")
2260+
2261+
2262+class SQLiteStorePoolTest(TestCase, TestHelper, StorePoolTest):
2263+
2264+ helpers = [MakePath]
2265+
2266+ def setUp(self):
2267+ TestHelper.setUp(self)
2268+ return StorePoolTest.setUp(self)
2269+
2270+ def tearDown(self):
2271+ def cb(passthrough):
2272+ TestHelper.tearDown(self)
2273+ return passthrough
2274+ return StorePoolTest.tearDown(self).addBoth(cb)
2275+
2276+ def create_database(self):
2277+ self.database = SQLite(URI("sqlite:%s?synchronous=OFF" %
2278+ self.make_path()))
2279+
2280+ def create_tables(self):
2281+ connection = self.connection
2282+ connection.execute("CREATE TABLE foo "
2283+ "(id INTEGER PRIMARY KEY,"
2284+ " title VARCHAR DEFAULT 'Default Title')")
2285+ connection.execute("CREATE TABLE bar "
2286+ "(id INTEGER PRIMARY KEY,"
2287+ " foo_id INTEGER, title VARCHAR)")
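Each backend module above pairs `twisted.trial.unittest.TestCase` with a shared mixin and then re-declares `setUp`/`tearDown` to call the mixin's versions explicitly. Because `TestCase` comes first in the bases, its own `setUp` would otherwise be found first in the method resolution order and the mixin's would never run. The sketch below reproduces that pattern with the standard library's `unittest` so it runs without Twisted. `BackendTestMixin`, `EnvGatedBackendTest` and the `HYPOTHETICAL_STORM_URI` variable are invented for illustration, and skipping when `is_supported()` returns false is an assumption about what the shared base does, not code taken from `tests/twisted/base.py`.

```python
import os
import unittest


class BackendTestMixin:
    """Hypothetical shared mixin, in the spirit of DeferredStoreTest."""

    def is_supported(self):
        return True

    def setUp(self):
        # Assumption: the shared base skips when the backend is unavailable.
        if not self.is_supported():
            raise unittest.SkipTest("backend not configured")
        self.create_database()

    def create_database(self):
        raise NotImplementedError

    def test_database_created(self):
        self.assertTrue(self.database)


class EnvGatedBackendTest(unittest.TestCase, BackendTestMixin):
    # TestCase precedes the mixin in the bases, so TestCase.setUp would
    # shadow the mixin's; delegate explicitly, as the backend classes do.
    ENV_VAR = "HYPOTHETICAL_STORM_URI"

    def setUp(self):
        return BackendTestMixin.setUp(self)

    def is_supported(self):
        # Like the MySQL/PostgreSQL classes: run only if a URI is configured.
        return bool(os.environ.get(self.ENV_VAR))

    def create_database(self):
        self.database = os.environ[self.ENV_VAR]
```

Listing the mixin first (`class EnvGatedBackendTest(BackendTestMixin, unittest.TestCase)`) would make the explicit delegation unnecessary; the SQLite classes still need their own `setUp` because they must also call `TestHelper.setUp`.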