Merge lp:~therve/storm/twisted-integration into lp:storm
- twisted-integration
- Merge into trunk
Status: | Work in progress |
---|---|
Proposed branch: | lp:~therve/storm/twisted-integration |
Merge into: | lp:storm |
Diff against target: |
2287 lines (+2201/-1) 9 files modified
storm/database.py (+18/-1) storm/exceptions.py (+4/-0) storm/twisted/__init__.py (+23/-0) storm/twisted/store.py (+480/-0) storm/twisted/wrapper.py (+206/-0) tests/twisted/base.py (+1286/-0) tests/twisted/mysql.py (+62/-0) tests/twisted/postgres.py (+57/-0) tests/twisted/sqlite.py (+65/-0) |
To merge this branch: | bzr merge lp:~therve/storm/twisted-integration |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Christopher Armstrong (community) | Abstain | ||
Review via email: mp+3733@code.launchpad.net |
Commit message
Description of the change
Gabriel (gabriel-rossetti) wrote : | # |
Christopher Armstrong wrote:
> Christopher Armstrong has proposed merging lp:~therve/storm/twisted-integration into lp:storm.
>
> Requested reviews:
> Storm Developers (storm)
>
I second that :-)
Gabriel
Christopher Armstrong (radix) wrote : | # |
[7] Expose reload to DeferredStore
[8] Maybe add deferredStore.
Drew Smathers (djfroofy) wrote : | # |
> [4] It may be useful to extend the core Store API to allow us to avoid accessing internal attributes in the wrapper. Like, for example, Store.disable_
That sounds like a great idea to me. Otherwise, having to refetch objects just to avoid lazy value errors is less than ideal.
- 215. By Thomas Herve
-
Merge from trunk.
- 216. By Thomas Herve
-
Move the tests to be able to run Postgres/MySQL tests as well
- 217. By Thomas Herve
-
Add postgres tests
- 218. By Thomas Herve
-
Add mysql, some tests failing.
- 219. By Thomas Herve
-
Use InnoDB to be able to use transactions.
- 220. By Thomas Herve
-
Merge from trunk.
- 221. By Thomas Herve
-
Force respect of the max_stores value, by incrementing the number of stores
once start_store is called, instead of doing it in the callback [f=373816]
- 222. By Thomas Herve
-
Use TestHelper/MakePath to manage sqlite file.
- 223. By Thomas Herve
-
Fix doc typos
- 224. By Thomas Herve
-
Merge from trunk
- 225. By Thomas Herve
-
Some tests cleanup
- 226. By Thomas Herve
-
Merge from trunk
- 227. By Thomas Herve
-
Cleanups
- 228. By Thomas Herve
-
Check thread in Connection to catch problems.
- 229. By Thomas Herve
-
Handle errors in rollback
- 230. By Thomas Herve
-
Typo, cleanup
- 231. By Thomas Herve
-
Make start/stop methods private
Unmerged revisions
- 231. By Thomas Herve
-
Make start/stop methods private
- 230. By Thomas Herve
-
Typo, cleanup
- 229. By Thomas Herve
-
Handle errors in rollback
- 228. By Thomas Herve
-
Check thread in Connection to catch problems.
- 227. By Thomas Herve
-
Cleanups
- 226. By Thomas Herve
-
Merge from trunk
- 225. By Thomas Herve
-
Some tests cleanup
- 224. By Thomas Herve
-
Merge from trunk
- 223. By Thomas Herve
-
Fix doc typos
- 222. By Thomas Herve
-
Use TestHelper/MakePath to manage sqlite file.
Preview Diff
1 | === modified file 'storm/database.py' | |||
2 | --- storm/database.py 2009-07-30 06:19:27 +0000 | |||
3 | +++ storm/database.py 2009-12-27 11:31:14 +0000 | |||
4 | @@ -24,12 +24,13 @@ | |||
5 | 24 | This is the common code for database support; specific databases are | 24 | This is the common code for database support; specific databases are |
6 | 25 | supported in modules in L{storm.databases}. | 25 | supported in modules in L{storm.databases}. |
7 | 26 | """ | 26 | """ |
8 | 27 | import threading | ||
9 | 27 | 28 | ||
10 | 28 | from storm.expr import Expr, State, compile | 29 | from storm.expr import Expr, State, compile |
11 | 29 | from storm.tracer import trace | 30 | from storm.tracer import trace |
12 | 30 | from storm.variables import Variable | 31 | from storm.variables import Variable |
13 | 31 | from storm.exceptions import ( | 32 | from storm.exceptions import ( |
15 | 32 | ClosedError, DatabaseError, DisconnectionError, Error) | 33 | ClosedError, DatabaseError, DisconnectionError, Error, ThreadSafetyError) |
16 | 33 | from storm.uri import URI | 34 | from storm.uri import URI |
17 | 34 | import storm | 35 | import storm |
18 | 35 | 36 | ||
19 | @@ -180,6 +181,20 @@ | |||
20 | 180 | self._database = database # Ensures deallocation order. | 181 | self._database = database # Ensures deallocation order. |
21 | 181 | self._event = event | 182 | self._event = event |
22 | 182 | self._raw_connection = self._database.raw_connect() | 183 | self._raw_connection = self._database.raw_connect() |
23 | 184 | self._thread = threading.currentThread() | ||
24 | 185 | |||
25 | 186 | def _check_thread(self): | ||
26 | 187 | """ | ||
27 | 188 | Check if the current thread is the same thread that created the | ||
28 | 189 | connection. This is a safety check that prevents breaking thread | ||
29 | 190 | boundaries which can create weird bugs. C{execute}, C{commit} and | ||
30 | 191 | C{rollback} should call it, directly or indirectly. | ||
31 | 192 | """ | ||
32 | 193 | thread = threading.currentThread() | ||
33 | 194 | if thread is not self._thread: | ||
34 | 195 | raise ThreadSafetyError( | ||
35 | 196 | "'%s' is not the connection thread '%s'" % | ||
36 | 197 | (thread, self._thread)) | ||
37 | 183 | 198 | ||
38 | 184 | def __del__(self): | 199 | def __del__(self): |
39 | 185 | """Close the connection.""" | 200 | """Close the connection.""" |
40 | @@ -240,6 +255,7 @@ | |||
41 | 240 | 255 | ||
42 | 241 | def rollback(self): | 256 | def rollback(self): |
43 | 242 | """Rollback the connection.""" | 257 | """Rollback the connection.""" |
44 | 258 | self._check_thread() | ||
45 | 243 | if self._state == STATE_CONNECTED: | 259 | if self._state == STATE_CONNECTED: |
46 | 244 | try: | 260 | try: |
47 | 245 | self._raw_connection.rollback() | 261 | self._raw_connection.rollback() |
48 | @@ -314,6 +330,7 @@ | |||
49 | 314 | If the connection is marked as dead, or if we can't reconnect, | 330 | If the connection is marked as dead, or if we can't reconnect, |
50 | 315 | then raise DisconnectionError. | 331 | then raise DisconnectionError. |
51 | 316 | """ | 332 | """ |
52 | 333 | self._check_thread() | ||
53 | 317 | if self._state == STATE_CONNECTED: | 334 | if self._state == STATE_CONNECTED: |
54 | 318 | return | 335 | return |
55 | 319 | elif self._state == STATE_DISCONNECTED: | 336 | elif self._state == STATE_DISCONNECTED: |
56 | 320 | 337 | ||
57 | === modified file 'storm/exceptions.py' | |||
58 | --- storm/exceptions.py 2009-05-15 08:43:21 +0000 | |||
59 | +++ storm/exceptions.py 2009-12-27 11:31:14 +0000 | |||
60 | @@ -82,6 +82,10 @@ | |||
61 | 82 | pass | 82 | pass |
62 | 83 | 83 | ||
63 | 84 | 84 | ||
64 | 85 | class ThreadSafetyError(StormError): | ||
65 | 86 | """Exception raised when cross-threads operations are attempted.""" | ||
66 | 87 | |||
67 | 88 | |||
68 | 85 | class Error(StormError): | 89 | class Error(StormError): |
69 | 86 | pass | 90 | pass |
70 | 87 | 91 | ||
71 | 88 | 92 | ||
72 | === added directory 'storm/twisted' | |||
73 | === added file 'storm/twisted/__init__.py' | |||
74 | --- storm/twisted/__init__.py 1970-01-01 00:00:00 +0000 | |||
75 | +++ storm/twisted/__init__.py 2009-12-27 11:31:14 +0000 | |||
76 | @@ -0,0 +1,23 @@ | |||
77 | 1 | # | ||
78 | 2 | # Copyright (c) 2007 Canonical | ||
79 | 3 | # Copyright (c) 2007 Thomas Herve <thomas@nimail.org> | ||
80 | 4 | # | ||
81 | 5 | # This file is part of Storm Object Relational Mapper. | ||
82 | 6 | # | ||
83 | 7 | # Storm is free software; you can redistribute it and/or modify | ||
84 | 8 | # it under the terms of the GNU Lesser General Public License as | ||
85 | 9 | # published by the Free Software Foundation; either version 2.1 of | ||
86 | 10 | # the License, or (at your option) any later version. | ||
87 | 11 | # | ||
88 | 12 | # Storm is distributed in the hope that it will be useful, | ||
89 | 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
90 | 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
91 | 15 | # GNU Lesser General Public License for more details. | ||
92 | 16 | # | ||
93 | 17 | # You should have received a copy of the GNU Lesser General Public License | ||
94 | 18 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
95 | 19 | # | ||
96 | 20 | |||
97 | 21 | """ | ||
98 | 22 | Asynchronous wrapper around storm to be used within a Twisted application. | ||
99 | 23 | """ | ||
100 | 0 | 24 | ||
101 | === added file 'storm/twisted/store.py' | |||
102 | --- storm/twisted/store.py 1970-01-01 00:00:00 +0000 | |||
103 | +++ storm/twisted/store.py 2009-12-27 11:31:14 +0000 | |||
104 | @@ -0,0 +1,480 @@ | |||
105 | 1 | # | ||
106 | 2 | # Copyright (c) 2007 Canonical | ||
107 | 3 | # Copyright (c) 2007 Thomas Herve <thomas@nimail.org> | ||
108 | 4 | # | ||
109 | 5 | # This file is part of Storm Object Relational Mapper. | ||
110 | 6 | # | ||
111 | 7 | # Storm is free software; you can redistribute it and/or modify | ||
112 | 8 | # it under the terms of the GNU Lesser General Public License as | ||
113 | 9 | # published by the Free Software Foundation; either version 2.1 of | ||
114 | 10 | # the License, or (at your option) any later version. | ||
115 | 11 | # | ||
116 | 12 | # Storm is distributed in the hope that it will be useful, | ||
117 | 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
118 | 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
119 | 15 | # GNU Lesser General Public License for more details. | ||
120 | 16 | # | ||
121 | 17 | # You should have received a copy of the GNU Lesser General Public License | ||
122 | 18 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
123 | 19 | # | ||
124 | 20 | |||
125 | 21 | """ | ||
126 | 22 | Store wrapper and custom thread runner. | ||
127 | 23 | """ | ||
128 | 24 | |||
129 | 25 | from threading import Thread | ||
130 | 26 | from Queue import Queue | ||
131 | 27 | |||
132 | 28 | from storm.store import Store, AutoReload | ||
133 | 29 | from storm.info import get_obj_info | ||
134 | 30 | from storm.twisted.wrapper import partial, DeferredResult, DeferredResultSet | ||
135 | 31 | |||
136 | 32 | from twisted.internet.defer import Deferred, deferredGenerator, maybeDeferred | ||
137 | 33 | from twisted.internet.defer import waitForDeferred, succeed, fail | ||
138 | 34 | from twisted.python.failure import Failure | ||
139 | 35 | |||
140 | 36 | |||
141 | 37 | |||
142 | 38 | class AlreadyStopped(Exception): | ||
143 | 39 | """ | ||
144 | 40 | Except raised when a store is stopped multiple time. | ||
145 | 41 | """ | ||
146 | 42 | |||
147 | 43 | |||
148 | 44 | |||
149 | 45 | class StoreThread(Thread): | ||
150 | 46 | """ | ||
151 | 47 | A thread class that wraps methods calls and fires deferred in the reactor | ||
152 | 48 | thread. | ||
153 | 49 | """ | ||
154 | 50 | STOP = object() | ||
155 | 51 | |||
156 | 52 | def __init__(self): | ||
157 | 53 | """ | ||
158 | 54 | Initialize the thread, and create a L{Queue} to stack jobs. | ||
159 | 55 | """ | ||
160 | 56 | Thread.__init__(self) | ||
161 | 57 | self.setDaemon(True) | ||
162 | 58 | self._queue = Queue() | ||
163 | 59 | self._stop_deferred = None | ||
164 | 60 | self.stopped = False | ||
165 | 61 | |||
166 | 62 | |||
167 | 63 | def defer_to_thread(self, f, *args, **kwargs): | ||
168 | 64 | """ | ||
169 | 65 | Run the given function in the thread, wrapping the result with a | ||
170 | 66 | L{Deferred}. | ||
171 | 67 | |||
172 | 68 | @return: a deferred whose result will be the result of the function. | ||
173 | 69 | @rtype: C{Deferred} | ||
174 | 70 | """ | ||
175 | 71 | if self.stopped: | ||
176 | 72 | # to prevent having pending calls after the thread got stopped | ||
177 | 73 | return fail(AlreadyStopped(f)) | ||
178 | 74 | d = Deferred() | ||
179 | 75 | self._queue.put((d, f, args, kwargs)) | ||
180 | 76 | return d | ||
181 | 77 | |||
182 | 78 | |||
183 | 79 | def run(self): | ||
184 | 80 | """ | ||
185 | 81 | Main execution loop: retrieve jobs from the queue and run them. | ||
186 | 82 | """ | ||
187 | 83 | from twisted.internet import reactor | ||
188 | 84 | o = self._queue.get() | ||
189 | 85 | while o is not self.STOP: | ||
190 | 86 | d, f, args, kwargs = o | ||
191 | 87 | try: | ||
192 | 88 | result = f(*args, **kwargs) | ||
193 | 89 | except: | ||
194 | 90 | f = Failure() | ||
195 | 91 | reactor.callFromThread(d.errback, f) | ||
196 | 92 | else: | ||
197 | 93 | reactor.callFromThread(d.callback, result) | ||
198 | 94 | o = self._queue.get() | ||
199 | 95 | reactor.callFromThread(self._stop_deferred.callback, None) | ||
200 | 96 | |||
201 | 97 | |||
202 | 98 | def stop(self): | ||
203 | 99 | """ | ||
204 | 100 | Stop the thread. | ||
205 | 101 | """ | ||
206 | 102 | if self.stopped: | ||
207 | 103 | return self._stop_deferred | ||
208 | 104 | self._stop_deferred = Deferred() | ||
209 | 105 | self._queue.put(self.STOP) | ||
210 | 106 | self.stopped = True | ||
211 | 107 | return self._stop_deferred | ||
212 | 108 | |||
213 | 109 | |||
214 | 110 | |||
215 | 111 | class DeferredStore(object): | ||
216 | 112 | """ | ||
217 | 113 | A wrapper around L{Store} to have async operations. | ||
218 | 114 | """ | ||
219 | 115 | store = None | ||
220 | 116 | |||
221 | 117 | def __init__(self, database): | ||
222 | 118 | """ | ||
223 | 119 | @param database: instance of database providing connection, used to | ||
224 | 120 | instantiate the store later. | ||
225 | 121 | @type database: L{storm.database.Database} | ||
226 | 122 | """ | ||
227 | 123 | self.thread = StoreThread() | ||
228 | 124 | self.database = database | ||
229 | 125 | self.started = False | ||
230 | 126 | |||
231 | 127 | |||
232 | 128 | def start(self): | ||
233 | 129 | """ | ||
234 | 130 | Start the store. | ||
235 | 131 | |||
236 | 132 | @return: a deferred that will fire once the store is started. | ||
237 | 133 | """ | ||
238 | 134 | if not self.started: | ||
239 | 135 | self.started = True | ||
240 | 136 | self.thread.start() | ||
241 | 137 | # Add a event trigger to be sure that the thread is stopped | ||
242 | 138 | from twisted.internet import reactor | ||
243 | 139 | reactor.addSystemEventTrigger( | ||
244 | 140 | "before", "shutdown", self.stop) | ||
245 | 141 | return self.thread.defer_to_thread(Store, self.database | ||
246 | 142 | ).addCallback(self._got_store) | ||
247 | 143 | else: | ||
248 | 144 | raise RuntimeError("Already started") | ||
249 | 145 | |||
250 | 146 | |||
251 | 147 | def _got_store(self, store): | ||
252 | 148 | """ | ||
253 | 149 | Internal method called when the store is created, initializing most of | ||
254 | 150 | the API methods. | ||
255 | 151 | """ | ||
256 | 152 | self.store = store | ||
257 | 153 | # Maybe not ? | ||
258 | 154 | self.store._deferredStore = self | ||
259 | 155 | for methodName in ("commit", "flush", "remove", "reload", | ||
260 | 156 | "rollback"): | ||
261 | 157 | method = partial(self.thread.defer_to_thread, | ||
262 | 158 | getattr(self.store, methodName)) | ||
263 | 159 | setattr(self, methodName, method) | ||
264 | 160 | |||
265 | 161 | self._do_resolve_lazy_value = self.store._resolve_lazy_value | ||
266 | 162 | self.store._resolve_lazy_value = self._resolve_lazy_value | ||
267 | 163 | |||
268 | 164 | |||
269 | 165 | def get(self, cls, key): | ||
270 | 166 | def _get(): | ||
271 | 167 | obj = self.store.get(cls, key) | ||
272 | 168 | if obj is not None: | ||
273 | 169 | obj_info = get_obj_info(obj) | ||
274 | 170 | self._do_resolve_lazy_value(obj_info, None, AutoReload) | ||
275 | 171 | return obj | ||
276 | 172 | return self.thread.defer_to_thread(_get) | ||
277 | 173 | |||
278 | 174 | |||
279 | 175 | def add(self, obj): | ||
280 | 176 | """ | ||
281 | 177 | Specific add method that doesn't return any result, to not make think | ||
282 | 178 | that it's something usable. | ||
283 | 179 | """ | ||
284 | 180 | def _add(): | ||
285 | 181 | self.store.add(obj) | ||
286 | 182 | return self.thread.defer_to_thread(_add) | ||
287 | 183 | |||
288 | 184 | |||
289 | 185 | def execute(self, *args, **kwargs): | ||
290 | 186 | """ | ||
291 | 187 | Wrapper around C{execute} to have a C{DeferredResult} instead of the | ||
292 | 188 | standard L{storm.database.Result} object. | ||
293 | 189 | """ | ||
294 | 190 | if self.store is None: | ||
295 | 191 | raise RuntimeError("Store not started") | ||
296 | 192 | return self.thread.defer_to_thread( | ||
297 | 193 | self.store.execute, *args, **kwargs | ||
298 | 194 | ).addCallback(self._cb_execute) | ||
299 | 195 | |||
300 | 196 | |||
301 | 197 | def _cb_execute(self, result): | ||
302 | 198 | """ | ||
303 | 199 | Wrap the result with a C{DeferredResult}. | ||
304 | 200 | """ | ||
305 | 201 | if result is not None: | ||
306 | 202 | return DeferredResult(self.thread, result) | ||
307 | 203 | |||
308 | 204 | |||
309 | 205 | def find(self, *args, **kwargs): | ||
310 | 206 | """ | ||
311 | 207 | Wrapper around C{find}. | ||
312 | 208 | """ | ||
313 | 209 | if self.store is None: | ||
314 | 210 | raise RuntimeError("Store not started") | ||
315 | 211 | return self.thread.defer_to_thread( | ||
316 | 212 | self.store.find, *args, **kwargs | ||
317 | 213 | ).addCallback(self._cb_find) | ||
318 | 214 | |||
319 | 215 | |||
320 | 216 | def _cb_find(self, resultSet): | ||
321 | 217 | """ | ||
322 | 218 | Wrap the result set with a C{DeferredResultSet}. | ||
323 | 219 | """ | ||
324 | 220 | return DeferredResultSet(self.thread, resultSet) | ||
325 | 221 | |||
326 | 222 | |||
327 | 223 | def stop(self): | ||
328 | 224 | """ | ||
329 | 225 | Stop the store. | ||
330 | 226 | """ | ||
331 | 227 | if self.thread.stopped: | ||
332 | 228 | return succeed(None) | ||
333 | 229 | def close(): | ||
334 | 230 | self.store.rollback() | ||
335 | 231 | self.store.close() | ||
336 | 232 | return self.thread.defer_to_thread(close | ||
337 | 233 | ).addCallback(lambda ign: self.thread.stop()) | ||
338 | 234 | |||
339 | 235 | |||
340 | 236 | def _resolve_lazy_value(self, *args): | ||
341 | 237 | raise RuntimeError( | ||
342 | 238 | "Resolving lazy values with the Twisted wrapper is not possible " | ||
343 | 239 | "right now! Please refetch your object using " | ||
344 | 240 | "store.get/store.find") | ||
345 | 241 | |||
346 | 242 | |||
347 | 243 | @staticmethod | ||
348 | 244 | def of(obj): | ||
349 | 245 | """ | ||
350 | 246 | Get the DeferredStore object is associated with | ||
351 | 247 | |||
352 | 248 | If the given object has not been associated with a DeferredStore, | ||
353 | 249 | return None. | ||
354 | 250 | """ | ||
355 | 251 | store = Store.of(obj) | ||
356 | 252 | if not store: | ||
357 | 253 | return | ||
358 | 254 | return getattr(store, '_deferredStore', None) | ||
359 | 255 | |||
360 | 256 | |||
361 | 257 | |||
362 | 258 | class StorePool(object): | ||
363 | 259 | """ | ||
364 | 260 | A pool of started stores, maintaining persistent connections. | ||
365 | 261 | """ | ||
366 | 262 | started = False | ||
367 | 263 | store_factory = DeferredStore | ||
368 | 264 | |||
369 | 265 | def __init__(self, database, min_stores=0, max_stores=10): | ||
370 | 266 | """ | ||
371 | 267 | @param database: instance of database providing connection, used to | ||
372 | 268 | instantiate the store later. | ||
373 | 269 | @type database: L{storm.database.Database} | ||
374 | 270 | |||
375 | 271 | @param min_stores: initial number of stores. | ||
376 | 272 | @type min_stores: C{int} | ||
377 | 273 | |||
378 | 274 | @param max_stores: maximum number of stores. | ||
379 | 275 | @type max_stores: C{int} | ||
380 | 276 | """ | ||
381 | 277 | self.database = database | ||
382 | 278 | self.min_stores = min_stores | ||
383 | 279 | self.max_stores = max_stores | ||
384 | 280 | self._stores = [] | ||
385 | 281 | self._stores_created = 0 | ||
386 | 282 | self._pending_get = [] | ||
387 | 283 | self._store_refs = [] | ||
388 | 284 | |||
389 | 285 | |||
390 | 286 | def start(self): | ||
391 | 287 | """ | ||
392 | 288 | Start the pool. | ||
393 | 289 | """ | ||
394 | 290 | if self.started: | ||
395 | 291 | raise RuntimeError("Already started") | ||
396 | 292 | self.started = True | ||
397 | 293 | return self.adjust_size() | ||
398 | 294 | |||
399 | 295 | |||
400 | 296 | def stop(self): | ||
401 | 297 | """ | ||
402 | 298 | Stop the pool: this is not a total stop, it just try to kill the | ||
403 | 299 | current available stores. | ||
404 | 300 | """ | ||
405 | 301 | return self.adjust_size(0, 0, self._store_refs) | ||
406 | 302 | |||
407 | 303 | |||
408 | 304 | def _start_store(self): | ||
409 | 305 | """ | ||
410 | 306 | Create a new store. | ||
411 | 307 | """ | ||
412 | 308 | store = self.store_factory(self.database) | ||
413 | 309 | # Increment here, so that other simultaneous calls don't make the | ||
414 | 310 | # number of connections pass the maximum | ||
415 | 311 | self._stores_created += 1 | ||
416 | 312 | return store.start( | ||
417 | 313 | ).addCallback(self._cb_start_store, store | ||
418 | 314 | ).addErrback(self._eb_start_store) | ||
419 | 315 | |||
420 | 316 | |||
421 | 317 | def _cb_start_store(self, ign, store): | ||
422 | 318 | """ | ||
423 | 319 | Add the created store to the list of available stores. | ||
424 | 320 | """ | ||
425 | 321 | self._stores.append(store) | ||
426 | 322 | self._store_refs.append(store) | ||
427 | 323 | |||
428 | 324 | |||
429 | 325 | def _eb_start_store(self, failure): | ||
430 | 326 | """ | ||
431 | 327 | Reduce the amount of created stores, and let the failure propagate. | ||
432 | 328 | """ | ||
433 | 329 | self._stores_created -= 1 | ||
434 | 330 | return failure | ||
435 | 331 | |||
436 | 332 | |||
437 | 333 | def _stop_store(self, stores=None): | ||
438 | 334 | """ | ||
439 | 335 | Stop a store and remove it from the available stores. | ||
440 | 336 | """ | ||
441 | 337 | if stores is None: | ||
442 | 338 | stores = self._stores | ||
443 | 339 | self._stores_created -= 1 | ||
444 | 340 | store = stores.pop() | ||
445 | 341 | return store.stop() | ||
446 | 342 | |||
447 | 343 | |||
448 | 344 | @deferredGenerator | ||
449 | 345 | def adjust_size(self, min_stores=None, max_stores=None, stores=None): | ||
450 | 346 | """ | ||
451 | 347 | Change the number of available stores, shrinking or raising as | ||
452 | 348 | necessary. | ||
453 | 349 | """ | ||
454 | 350 | if min_stores is None: | ||
455 | 351 | min_stores = self.min_stores | ||
456 | 352 | if max_stores is None: | ||
457 | 353 | max_stores = self.max_stores | ||
458 | 354 | if stores is None: | ||
459 | 355 | stores = self._stores | ||
460 | 356 | |||
461 | 357 | if min_stores < 0: | ||
462 | 358 | raise ValueError('minimum is negative') | ||
463 | 359 | if min_stores > max_stores: | ||
464 | 360 | raise ValueError('minimum is greater than maximum') | ||
465 | 361 | |||
466 | 362 | self.min_stores = min_stores | ||
467 | 363 | self.max_stores = max_stores | ||
468 | 364 | if not self.started: | ||
469 | 365 | return | ||
470 | 366 | |||
471 | 367 | # Kill of some stores if we have too many. | ||
472 | 368 | while self._stores_created > self.max_stores and stores: | ||
473 | 369 | wfd = waitForDeferred(self._stop_store(stores)) | ||
474 | 370 | yield wfd | ||
475 | 371 | wfd.getResult() | ||
476 | 372 | # Start some stores if we have too few. | ||
477 | 373 | while self._stores_created < self.min_stores: | ||
478 | 374 | wfd = waitForDeferred(self._start_store()) | ||
479 | 375 | yield wfd | ||
480 | 376 | wfd.getResult() | ||
481 | 377 | |||
482 | 378 | |||
483 | 379 | def get(self): | ||
484 | 380 | """ | ||
485 | 381 | Return a started store from the pool, or start a new one if necessary. | ||
486 | 382 | A store retrieve by this way should be put back using the put | ||
487 | 383 | method, or it won't be used anymore. | ||
488 | 384 | """ | ||
489 | 385 | if not self.started: | ||
490 | 386 | raise RuntimeError("Not started") | ||
491 | 387 | if self._stores: | ||
492 | 388 | store = self._stores.pop() | ||
493 | 389 | return succeed(store) | ||
494 | 390 | elif self._stores_created < self.max_stores: | ||
495 | 391 | return self._start_store().addCallback(self._cb_get) | ||
496 | 392 | else: | ||
497 | 393 | # Maybe all stores are consumed? | ||
498 | 394 | return self.adjust_size().addCallback(self._cb_get) | ||
499 | 395 | |||
500 | 396 | |||
501 | 397 | def _cb_get(self, ign): | ||
502 | 398 | """ | ||
503 | 399 | If the previous operation added a store, return it, or return a pending | ||
504 | 400 | C{Deferred}. | ||
505 | 401 | """ | ||
506 | 402 | if self._stores: | ||
507 | 403 | store = self._stores.pop() | ||
508 | 404 | return store | ||
509 | 405 | else: | ||
510 | 406 | # All stores are in used, wait | ||
511 | 407 | d = Deferred() | ||
512 | 408 | self._pending_get.append(d) | ||
513 | 409 | return d | ||
514 | 410 | |||
515 | 411 | |||
516 | 412 | def put(self, store): | ||
517 | 413 | """ | ||
518 | 414 | Make a store available again. | ||
519 | 415 | |||
520 | 416 | This should be done explicitely to have the store back in the pool. | ||
521 | 417 | The good way to use the pool is this: | ||
522 | 418 | |||
523 | 419 | >>> d1 = pool.get() | ||
524 | 420 | |||
525 | 421 | >>> # d1 callback with a store | ||
526 | 422 | >>> d2 = store.add(foo) | ||
527 | 423 | >>> d2.addCallback(doSomething).addErrback(manageErrors) | ||
528 | 424 | >>> d2.addBoth(lambda x: pool.put(store)) | ||
529 | 425 | """ | ||
530 | 426 | return store.rollback().addBoth(self._cb_put, store) | ||
531 | 427 | |||
532 | 428 | |||
533 | 429 | def _cb_put(self, passthrough, store): | ||
534 | 430 | """ | ||
535 | 431 | Once the rollback has finished, the store is really available. | ||
536 | 432 | """ | ||
537 | 433 | if self._pending_get: | ||
538 | 434 | # People are waiting, fire with the store | ||
539 | 435 | d = self._pending_get.pop(0) | ||
540 | 436 | d.callback(store) | ||
541 | 437 | else: | ||
542 | 438 | self._stores.append(store) | ||
543 | 439 | return passthrough | ||
544 | 440 | |||
545 | 441 | |||
546 | 442 | def transact(self, f, *args, **kwargs): | ||
547 | 443 | """ | ||
548 | 444 | Call function C{f} with a L{Store} instance and arguments C{args} and | ||
549 | 445 | C{kwargs} in transaction bound to the acquired store. If transaction | ||
550 | 446 | succeeds, store will be commited. Store is returned to this pool after | ||
551 | 447 | call to C{f} completes. | ||
552 | 448 | |||
553 | 449 | Note that the function C{f} must return an instance of L{Deferred}. | ||
554 | 450 | |||
555 | 451 | @param f: function to call in transaction | ||
556 | 452 | @param args: positional arguments to function C{f} | ||
557 | 453 | @param kwargs: keyword arguments to function C{f} | ||
558 | 454 | """ | ||
559 | 455 | return self.get( | ||
560 | 456 | ).addCallback(self._cb_transact_start, f, args, kwargs) | ||
561 | 457 | |||
562 | 458 | |||
563 | 459 | def _cb_transact_start(self, store, f, args, kwargs): | ||
564 | 460 | """ | ||
565 | 461 | Call transacted function with acquired store. | ||
566 | 462 | """ | ||
567 | 463 | result = maybeDeferred(f, store, *args, **kwargs) | ||
568 | 464 | result.addCallback(self._cb_transact_success, store) | ||
569 | 465 | result.addBoth(self._cb_transact_stop, store) | ||
570 | 466 | return result | ||
571 | 467 | |||
572 | 468 | |||
573 | 469 | def _cb_transact_success(self, result, store): | ||
574 | 470 | """ | ||
575 | 471 | Commit and pass through function result. | ||
576 | 472 | """ | ||
577 | 473 | return store.commit().addCallback(lambda ignore: result) | ||
578 | 474 | |||
579 | 475 | |||
580 | 476 | def _cb_transact_stop(self, result, store): | ||
581 | 477 | """ | ||
582 | 478 | Return the store back to the pool and pass through the result again. | ||
583 | 479 | """ | ||
584 | 480 | return self.put(store).addCallback(lambda ignore: result) | ||
585 | 0 | 481 | ||
586 | === added file 'storm/twisted/wrapper.py' | |||
587 | --- storm/twisted/wrapper.py 1970-01-01 00:00:00 +0000 | |||
588 | +++ storm/twisted/wrapper.py 2009-12-27 11:31:14 +0000 | |||
589 | @@ -0,0 +1,206 @@ | |||
590 | 1 | # | ||
591 | 2 | # Copyright (c) 2007 Canonical | ||
592 | 3 | # Copyright (c) 2007 Thomas Herve <thomas@nimail.org> | ||
593 | 4 | # | ||
594 | 5 | # This file is part of Storm Object Relational Mapper. | ||
595 | 6 | # | ||
596 | 7 | # Storm is free software; you can redistribute it and/or modify | ||
597 | 8 | # it under the terms of the GNU Lesser General Public License as | ||
598 | 9 | # published by the Free Software Foundation; either version 2.1 of | ||
599 | 10 | # the License, or (at your option) any later version. | ||
600 | 11 | # | ||
601 | 12 | # Storm is distributed in the hope that it will be useful, | ||
602 | 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
603 | 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
604 | 15 | # GNU Lesser General Public License for more details. | ||
605 | 16 | # | ||
606 | 17 | # You should have received a copy of the GNU Lesser General Public License | ||
607 | 18 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
608 | 19 | # | ||
609 | 20 | |||
610 | 21 | """ | ||
611 | 22 | Asynchronous wrapper around storm. | ||
612 | 23 | """ | ||
613 | 24 | |||
614 | 25 | from storm.store import Store | ||
615 | 26 | from storm.references import Reference, ReferenceSet | ||
616 | 27 | |||
617 | 28 | |||
618 | 29 | try: | ||
619 | 30 | from functools import partial | ||
620 | 31 | except ImportError: | ||
621 | 32 | # For Python < 2.5 | ||
622 | 33 | class partial(object): | ||
623 | 34 | def __init__(self, fn, *args, **kw): | ||
624 | 35 | self.fn = fn | ||
625 | 36 | self.args = args | ||
626 | 37 | self.kw = kw | ||
627 | 38 | |||
628 | 39 | def __call__(self, *args, **kw): | ||
629 | 40 | if kw and self.kw: | ||
630 | 41 | d = self.kw.copy() | ||
631 | 42 | d.update(kw) | ||
632 | 43 | else: | ||
633 | 44 | d = kw or self.kw | ||
634 | 45 | return self.fn(*(self.args + args), **d) | ||
635 | 46 | |||
636 | 47 | |||
637 | 48 | |||
638 | 49 | class DeferredResult(object): | ||
639 | 50 | """ | ||
640 | 51 | Proxy for a storm result, running the blocking methods in a thread and | ||
641 | 52 | returning C{Deferred}s. | ||
642 | 53 | """ | ||
643 | 54 | |||
644 | 55 | def __init__(self, thread, result): | ||
645 | 56 | """ | ||
646 | 57 | @param thread: the running thread of the store | ||
647 | 58 | @type thread: C{StoreThread} | ||
648 | 59 | |||
649 | 60 | @param result: the result instance to be wrapped. | ||
650 | 61 | @type result: C{storm.database.Result} | ||
651 | 62 | """ | ||
652 | 63 | self.result = result | ||
653 | 64 | for methodName in ("get_one", "get_all"): | ||
654 | 65 | method = partial(thread.defer_to_thread, | ||
655 | 66 | getattr(result, methodName)) | ||
656 | 67 | setattr(self, methodName, method) | ||
657 | 68 | |||
658 | 69 | |||
659 | 70 | |||
class DeferredResultSet(object):
    """
    Wrapper for a L{storm.store.ResultSet}.

    Methods that query the database are run in the store thread and return
    C{Deferred}s; set expressions return a new C{DeferredResultSet}; query
    modifiers are forwarded unchanged to the wrapped result set.
    """

    def __init__(self, thread, resultSet):
        """
        Create the results with given C{StoreThread} and the set to wrap.

        @param thread: the running thread of the store.
        @type thread: C{StoreThread}

        @param resultSet: the result set to be wrapped.
        @type resultSet: L{storm.store.ResultSet}
        """
        self._thread = thread
        self._resultSet = resultSet
        # These run in the store thread and return Deferreds.
        for methodName in ("any", "one", "first", "last", "remove", "count",
                           "max", "min", "avg", "sum", "set", "is_empty"):
            method = partial(thread.defer_to_thread,
                             getattr(resultSet, methodName))
            setattr(self, methodName, method)
        # These are called directly and their result is wrapped again.
        for methodName in ("union", "difference", "intersection"):
            method = partial(self._set_expr,
                             getattr(resultSet, methodName))
            setattr(self, methodName, method)
        # These are exposed as the wrapped result set's own bound methods.
        for methodName in ("order_by", "config", "group_by", "having"):
            setattr(self, methodName, getattr(self._resultSet, methodName))


    def all(self):
        """
        Specific method to emulate C{__iter__}.

        @return: a C{Deferred} firing with a list of all the results.
        """
        return self._thread.defer_to_thread(list, self._resultSet)


    def values(self, *columns):
        """
        Wrapper around values that removes the iterator feature to return a
        list instead.

        @return: a C{Deferred} firing with the list of values.
        """
        def _get_values():
            return list(self._resultSet.values(*columns))
        return self._thread.defer_to_thread(_get_values)


    def _set_expr(self, method, other, all=False):
        """
        Wrap a set expression with a C{DeferredResultSet}.

        @param method: the bound set method (C{union}, C{difference} or
            C{intersection}) of the wrapped result set.
        @param other: the other operand, either a plain result set or a
            C{DeferredResultSet}, which is unwrapped first.
        @param all: passed through to C{method}.
        """
        if isinstance(other, DeferredResultSet):
            # The underlying set method only knows about raw result sets.
            other = other._resultSet
        return DeferredResultSet(self._thread, method(other, all))
706 | 117 | |||
707 | 118 | |||
708 | 119 | |||
class DeferredReference(Reference):
    """
    A reference property but within a C{Deferred}: reading it on a stored
    instance returns a C{Deferred} firing with the referenced object.
    """

    def __get__(self, local, cls=None):
        """
        Wrapper around C{Reference.__get__}.

        Class-level access (C{local} is C{None}) is delegated unchanged to
        C{Reference.__get__}. For an instance that is not in a store, C{None}
        is returned. Otherwise the lookup runs in the store thread and a
        C{Deferred} is returned.
        """
        if local is None:
            # Accessed on the class rather than an instance: there is no store
            # to look up, so let Reference handle it as a normal descriptor.
            return Reference.__get__(self, local, cls)
        store = Store.of(local)
        if store is None:
            return None
        _thread = store._deferredStore.thread
        return _thread.defer_to_thread(Reference.__get__, self, local, cls)


    def __set__(self, local, remote):
        """
        Setting a C{DeferredReference} is not supported.

        @raise RuntimeError: always.
        """
        raise RuntimeError("Can't set a DeferredReference")
730 | 141 | |||
731 | 142 | |||
732 | 143 | |||
class DeferredReferenceSet(ReferenceSet):
    """
    A C{ReferenceSet} but within a C{Deferred}: reading it on a stored
    instance returns a L{DeferredBoundReference}.
    """

    def __get__(self, local, cls=None):
        """
        Wrapper around C{ReferenceSet.__get__}.

        Class-level access (C{local} is C{None}) is delegated unchanged to
        C{ReferenceSet.__get__}. For an instance that is not in a store,
        C{None} is returned. Otherwise the bound reference set is wrapped in a
        L{DeferredBoundReference} that uses the store's thread.
        """
        if local is None:
            # Accessed on the class rather than an instance: there is no store
            # to look up, so let ReferenceSet handle it as a normal descriptor.
            return ReferenceSet.__get__(self, local, cls)
        store = Store.of(local)
        if store is None:
            return None
        _thread = store._deferredStore.thread
        boundReference = ReferenceSet.__get__(self, local, cls)
        return DeferredBoundReference(_thread, boundReference)
748 | 159 | |||
749 | 160 | |||
750 | 161 | |||
class DeferredBoundReference(object):
    """
    C{Deferred}-returning facade over a C{BoundReferenceSet} or a
    C{BoundIndirectReferenceSet}.
    """

    def __init__(self, thread, boundReference):
        """
        @param thread: the C{StoreThread} used to run blocking calls.
        @param boundReference: the bound reference set being wrapped.
        """
        self._thread = thread
        self._boundReference = boundReference
        # Each entry pairs a wrapped method name with the callable that runs it.
        plan = [(name, thread.defer_to_thread)
                for name in ("clear", "add", "remove", "any", "count", "one",
                             "first", "last")]
        plan += [(name, self._defer_and_wrap_result)
                 for name in ("order_by", "find")]
        for name, runner in plan:
            setattr(self, name, partial(runner, getattr(boundReference, name)))


    def all(self):
        """
        Fetch every referenced object as a list; stands in for C{__iter__}.
        """
        return self._thread.defer_to_thread(list, self._boundReference)


    def values(self, *columns):
        """
        Deferred counterpart of C{values}, yielding a list rather than an
        iterator.
        """
        return self._thread.defer_to_thread(
            lambda: list(self._boundReference.values(*columns)))


    def _defer_and_wrap_result(self, method, *args, **kwargs):
        """
        Run C{method} in the store thread and wrap the result set it produces
        in a C{DeferredResultSet}.
        """
        def wrap(resultSet):
            return DeferredResultSet(self._thread, resultSet)
        deferred = self._thread.defer_to_thread(method, *args, **kwargs)
        return deferred.addCallback(wrap)
796 | 0 | 207 | ||
797 | === added directory 'tests/twisted' | |||
798 | === added file 'tests/twisted/__init__.py' | |||
799 | === added file 'tests/twisted/base.py' | |||
800 | --- tests/twisted/base.py 1970-01-01 00:00:00 +0000 | |||
801 | +++ tests/twisted/base.py 2009-12-27 11:31:15 +0000 | |||
802 | @@ -0,0 +1,1286 @@ | |||
803 | 1 | # | ||
804 | 2 | # Copyright (c) 2007 Canonical | ||
805 | 3 | # Copyright (c) 2007 Thomas Herve <thomas@nimail.org> | ||
806 | 4 | # | ||
807 | 5 | # This file is part of Storm Object Relational Mapper. | ||
808 | 6 | # | ||
809 | 7 | # Storm is free software; you can redistribute it and/or modify | ||
810 | 8 | # it under the terms of the GNU Lesser General Public License as | ||
811 | 9 | # published by the Free Software Foundation; either version 2.1 of | ||
812 | 10 | # the License, or (at your option) any later version. | ||
813 | 11 | # | ||
814 | 12 | # Storm is distributed in the hope that it will be useful, | ||
815 | 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
816 | 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
817 | 15 | # GNU Lesser General Public License for more details. | ||
818 | 16 | # | ||
819 | 17 | # You should have received a copy of the GNU Lesser General Public License | ||
820 | 18 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
821 | 19 | # | ||
822 | 20 | |||
823 | 21 | """ | ||
824 | 22 | Test for twistorm. | ||
825 | 23 | """ | ||
826 | 24 | |||
827 | 25 | from storm.properties import Int, Unicode | ||
828 | 26 | from storm.expr import Count | ||
829 | 27 | from storm.references import Reference | ||
830 | 28 | from storm.exceptions import OperationalError, ThreadSafetyError | ||
831 | 29 | |||
832 | 30 | from storm.twisted.store import ( | ||
833 | 31 | DeferredStore, StoreThread, StorePool, AlreadyStopped ) | ||
834 | 32 | from storm.twisted.wrapper import DeferredReference, DeferredReferenceSet | ||
835 | 33 | |||
836 | 34 | from twisted.trial.unittest import TestCase | ||
837 | 35 | from twisted.internet.defer import gatherResults, deferredGenerator | ||
838 | 36 | from twisted.internet.defer import waitForDeferred, DeferredList | ||
839 | 37 | from twisted.internet.defer import succeed, fail | ||
840 | 38 | |||
841 | 39 | |||
class Foo(object):
    """
    Model for the C{foo} test table: integer primary key plus a title.
    """
    __storm_table__ = "foo"
    id = Int(primary=True)
    title = Unicode()
849 | 47 | |||
850 | 48 | |||
851 | 49 | |||
class Bar(object):
    """
    Model for the C{bar} test table, pointing at a C{Foo} via C{foo_id}.
    """
    __storm_table__ = "bar"
    id = Int(primary=True)
    title = Unicode()
    foo_id = Int()
    # Reading bar.foo gives a Deferred firing with the related Foo.
    foo = DeferredReference(foo_id, Foo.id)
861 | 59 | |||
862 | 60 | |||
863 | 61 | |||
class FooRefSet(Foo):
    """
    C{Foo} variant exposing its related C{Bar} rows through an unordered
    C{DeferredReferenceSet}.
    """
    bars = DeferredReferenceSet(Foo.id, Bar.foo_id)
869 | 67 | |||
870 | 68 | |||
871 | 69 | |||
class FooRefSetOrderID(Foo):
    """
    C{Foo} variant whose C{DeferredReferenceSet} to C{Bar} is ordered by
    C{Bar.id}.
    """
    bars = DeferredReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id)
877 | 75 | |||
878 | 76 | |||
879 | 77 | |||
class Egg(object):
    """
    Model for the C{egg} test table: integer primary key plus an integer
    value used by the aggregate tests.
    """
    __storm_table__ = "egg"
    id = Int(primary=True)
    value = Int()
887 | 85 | |||
888 | 86 | |||
889 | 87 | |||
890 | 88 | class DeferredStoreTest(object): | ||
891 | 89 | """ | ||
892 | 90 | Tests for L{DeferredStore}. | ||
893 | 91 | """ | ||
894 | 92 | |||
895 | 93 | def setUp(self): | ||
896 | 94 | """ | ||
897 | 95 | Create a test sqlite database, and insert some data. | ||
898 | 96 | """ | ||
899 | 97 | self.create_database() | ||
900 | 98 | connection = self.connection = self.database.connect() | ||
901 | 99 | self.drop_tables() | ||
902 | 100 | self.create_tables() | ||
903 | 101 | connection.execute("INSERT INTO foo VALUES (10, 'Title 30')") | ||
904 | 102 | connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 50')") | ||
905 | 103 | connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 40')") | ||
906 | 104 | connection.execute("INSERT INTO egg VALUES (1, 4)") | ||
907 | 105 | connection.execute("INSERT INTO egg VALUES (2, 3)") | ||
908 | 106 | connection.execute("INSERT INTO egg VALUES (3, 7)") | ||
909 | 107 | connection.execute("INSERT INTO egg VALUES (4, 5)") | ||
910 | 108 | connection.commit() | ||
911 | 109 | self.store = DeferredStore(self.database) | ||
912 | 110 | return self.store.start() | ||
913 | 111 | |||
914 | 112 | |||
915 | 113 | def tearDown(self): | ||
916 | 114 | """ | ||
917 | 115 | Kill the store (and its underlying thread). | ||
918 | 116 | """ | ||
919 | 117 | def _stop(ign): | ||
920 | 118 | return self.store.stop().addCallback(_drop) | ||
921 | 119 | def _drop(ign): | ||
922 | 120 | self.drop_tables() | ||
923 | 121 | return self.store.rollback().addCallback(_stop) | ||
924 | 122 | |||
925 | 123 | |||
926 | 124 | def create_database(self): | ||
927 | 125 | raise NotImplementedError() | ||
928 | 126 | |||
929 | 127 | |||
930 | 128 | def create_tables(self): | ||
931 | 129 | raise NotImplementedError() | ||
932 | 130 | |||
933 | 131 | |||
934 | 132 | def drop_tables(self): | ||
935 | 133 | for table in ["foo", "bar", "egg"]: | ||
936 | 134 | try: | ||
937 | 135 | self.connection.execute("DROP TABLE %s" % table) | ||
938 | 136 | self.connection.commit() | ||
939 | 137 | except: | ||
940 | 138 | self.connection.rollback() | ||
941 | 139 | |||
942 | 140 | |||
943 | 141 | def test_multiple_start(self): | ||
944 | 142 | """ | ||
945 | 143 | Check that start raises an exception when the store is already started. | ||
946 | 144 | """ | ||
947 | 145 | self.assertRaises(RuntimeError, self.store.start) | ||
948 | 146 | |||
949 | 147 | |||
950 | 148 | def test_get(self): | ||
951 | 149 | """ | ||
952 | 150 | Try to get an object from the store and check its attributes. | ||
953 | 151 | """ | ||
954 | 152 | def cb(result): | ||
955 | 153 | self.assertEquals(result.title, u"Title 30") | ||
956 | 154 | self.assertEquals(result.id, 10) | ||
957 | 155 | return self.store.get(Foo, 10).addCallback(cb) | ||
958 | 156 | |||
959 | 157 | |||
960 | 158 | def test_add(self): | ||
961 | 159 | """ | ||
962 | 160 | Add an object to the store. | ||
963 | 161 | """ | ||
964 | 162 | foo = Foo() | ||
965 | 163 | foo.title = u"Great title" | ||
966 | 164 | foo.id = 11 | ||
967 | 165 | def cb_add(ign): | ||
968 | 166 | return self.store.get(Foo, 11).addCallback(cb_get) | ||
969 | 167 | def cb_get(result): | ||
970 | 168 | self.assertEquals(result.title, u"Great title") | ||
971 | 169 | self.assertEquals(result.id, 11) | ||
972 | 170 | return self.store.add(foo).addCallback(cb_add) | ||
973 | 171 | |||
974 | 172 | |||
975 | 173 | def test_add_default_value(self): | ||
976 | 174 | """ | ||
977 | 175 | When adding an object to the store, the default values from the | ||
978 | 176 | database are retrieved and put into the object. | ||
979 | 177 | """ | ||
980 | 178 | foo = Foo() | ||
981 | 179 | foo.id = 11 | ||
982 | 180 | def cb_add(result): | ||
983 | 181 | self.assertIdentical(result, None) | ||
984 | 182 | return self.store.get(Foo, 11).addCallback(cb_get) | ||
985 | 183 | def cb_get(result): | ||
986 | 184 | self.assertEquals(result.title, u"Default Title") | ||
987 | 185 | self.assertEquals(result.id, 11) | ||
988 | 186 | return self.store.add(foo).addCallback(cb_add) | ||
989 | 187 | |||
990 | 188 | |||
991 | 189 | def test_execute(self): | ||
992 | 190 | """ | ||
993 | 191 | Test a direct execute on the store, and the C{get_one} method of | ||
994 | 192 | C{DeferredResult}. | ||
995 | 193 | """ | ||
996 | 194 | def cb_execute(result): | ||
997 | 195 | return result.get_one().addCallback(cb_result) | ||
998 | 196 | def cb_result(result): | ||
999 | 197 | self.assertEquals(result, (u"Title 30",)) | ||
1000 | 198 | return self.store.execute("SELECT title FROM foo WHERE id=10" | ||
1001 | 199 | ).addCallback(cb_execute) | ||
1002 | 200 | |||
1003 | 201 | |||
1004 | 202 | def test_execute_all(self): | ||
1005 | 203 | """ | ||
1006 | 204 | Test a direct execute on the store, and the C{all} method of | ||
1007 | 205 | C{DeferredResult}. | ||
1008 | 206 | """ | ||
1009 | 207 | def cb_execute(result): | ||
1010 | 208 | return result.get_all().addCallback(cb_result) | ||
1011 | 209 | def cb_result(result): | ||
1012 | 210 | self.assertEquals(result, [(u"Title 50",), (u"Title 40",)]) | ||
1013 | 211 | return self.store.execute("SELECT title FROM bar" | ||
1014 | 212 | ).addCallback(cb_execute) | ||
1015 | 213 | |||
1016 | 214 | |||
1017 | 215 | def test_remove(self): | ||
1018 | 216 | """ | ||
1019 | 217 | Trying removing an object from the database. | ||
1020 | 218 | """ | ||
1021 | 219 | def cb_get(result): | ||
1022 | 220 | return self.store.remove(result).addCallback(cb_remove) | ||
1023 | 221 | def cb_remove(ign): | ||
1024 | 222 | return self.store.get(Foo, 10).addCallback(cb_get_after_remove) | ||
1025 | 223 | def cb_get_after_remove(result): | ||
1026 | 224 | self.assertIdentical(result, None) | ||
1027 | 225 | return self.store.get(Foo, 10).addCallback(cb_get) | ||
1028 | 226 | |||
1029 | 227 | |||
1030 | 228 | def test_find(self): | ||
1031 | 229 | """ | ||
1032 | 230 | Try to find a list of objects using the store. | ||
1033 | 231 | """ | ||
1034 | 232 | def cb_find(results): | ||
1035 | 233 | return results.all().addCallback(cb_all) | ||
1036 | 234 | def cb_all(results): | ||
1037 | 235 | self.assertEquals(len(results), 2) | ||
1038 | 236 | titles = [results[0].title, results[1].title] | ||
1039 | 237 | titles.sort() | ||
1040 | 238 | self.assertEquals(titles, [u"Title 40", u"Title 50"]) | ||
1041 | 239 | return self.store.find(Bar).addCallback(cb_find) | ||
1042 | 240 | |||
1043 | 241 | |||
1044 | 242 | def test_find_first(self): | ||
1045 | 243 | """ | ||
1046 | 244 | Try to get the first object matching a query. | ||
1047 | 245 | """ | ||
1048 | 246 | def cb_find(results): | ||
1049 | 247 | results.order_by(Bar.title) | ||
1050 | 248 | return results.first().addCallback(cb_all) | ||
1051 | 249 | def cb_all(result): | ||
1052 | 250 | self.assertEquals(result.title, u"Title 40") | ||
1053 | 251 | self.assertEquals(result.id, 11) | ||
1054 | 252 | return self.store.find(Bar).addCallback(cb_find) | ||
1055 | 253 | |||
1056 | 254 | |||
1057 | 255 | def test_find_last(self): | ||
1058 | 256 | """ | ||
1059 | 257 | Try to get the last object matching a query. | ||
1060 | 258 | """ | ||
1061 | 259 | def cb_find(results): | ||
1062 | 260 | results.order_by(Bar.title) | ||
1063 | 261 | return results.last().addCallback(cb_all) | ||
1064 | 262 | def cb_all(result): | ||
1065 | 263 | self.assertEquals(result.title, u"Title 50") | ||
1066 | 264 | self.assertEquals(result.id, 10) | ||
1067 | 265 | return self.store.find(Bar).addCallback(cb_find) | ||
1068 | 266 | |||
1069 | 267 | |||
1070 | 268 | def test_find_any(self): | ||
1071 | 269 | """ | ||
1072 | 270 | Try to get an object matching a query using the C{any} method. | ||
1073 | 271 | """ | ||
1074 | 272 | def cb_find(results): | ||
1075 | 273 | return results.any().addCallback(cb_all) | ||
1076 | 274 | def cb_all(result): | ||
1077 | 275 | self.assertEquals(result.title, u"Title 50") | ||
1078 | 276 | self.assertEquals(result.id, 10) | ||
1079 | 277 | return self.store.find(Bar).addCallback(cb_find) | ||
1080 | 278 | |||
1081 | 279 | |||
1082 | 280 | def test_find_max(self): | ||
1083 | 281 | """ | ||
1084 | 282 | Try to get the maximum of a value after a find. | ||
1085 | 283 | """ | ||
1086 | 284 | def cb_find(results): | ||
1087 | 285 | return results.max(Egg.value).addCallback(cb_all) | ||
1088 | 286 | def cb_all(result): | ||
1089 | 287 | self.assertEquals(result, 7) | ||
1090 | 288 | return self.store.find(Egg).addCallback(cb_find) | ||
1091 | 289 | |||
1092 | 290 | |||
1093 | 291 | def test_find_min(self): | ||
1094 | 292 | """ | ||
1095 | 293 | Try to get the minimum of a value after a find. | ||
1096 | 294 | """ | ||
1097 | 295 | def cb_find(results): | ||
1098 | 296 | return results.min(Egg.value).addCallback(cb_all) | ||
1099 | 297 | def cb_all(result): | ||
1100 | 298 | self.assertEquals(result, 3) | ||
1101 | 299 | return self.store.find(Egg).addCallback(cb_find) | ||
1102 | 300 | |||
1103 | 301 | |||
1104 | 302 | def test_find_avg(self): | ||
1105 | 303 | """ | ||
1106 | 304 | Try to get the average of a value after a find. | ||
1107 | 305 | """ | ||
1108 | 306 | def cb_find(results): | ||
1109 | 307 | return results.avg(Egg.value).addCallback(cb_all) | ||
1110 | 308 | def cb_all(result): | ||
1111 | 309 | self.assertEquals(result, 4.75) | ||
1112 | 310 | return self.store.find(Egg).addCallback(cb_find) | ||
1113 | 311 | |||
1114 | 312 | |||
1115 | 313 | def test_find_sum(self): | ||
1116 | 314 | """ | ||
1117 | 315 | Try to get the sum of a value after a find. | ||
1118 | 316 | """ | ||
1119 | 317 | def cb_find(results): | ||
1120 | 318 | return results.sum(Egg.value).addCallback(cb_all) | ||
1121 | 319 | def cb_all(result): | ||
1122 | 320 | self.assertEquals(result, 19) | ||
1123 | 321 | return self.store.find(Egg).addCallback(cb_find) | ||
1124 | 322 | |||
1125 | 323 | |||
1126 | 324 | def test_find_count(self): | ||
1127 | 325 | """ | ||
1128 | 326 | Try to get the count of a result after a find. | ||
1129 | 327 | """ | ||
1130 | 328 | def cb_find(results): | ||
1131 | 329 | return results.count().addCallback(cb_all) | ||
1132 | 330 | def cb_all(result): | ||
1133 | 331 | self.assertEquals(result, 2) | ||
1134 | 332 | return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find) | ||
1135 | 333 | |||
1136 | 334 | |||
1137 | 335 | def test_find_remove(self): | ||
1138 | 336 | """ | ||
1139 | 337 | Remove the result of a find query. | ||
1140 | 338 | """ | ||
1141 | 339 | def cb_find(results): | ||
1142 | 340 | return results.remove().addCallback(cb_remove) | ||
1143 | 341 | def cb_remove(ignore): | ||
1144 | 342 | return self.store.find(Egg).addCallback(cb_find_after_remove) | ||
1145 | 343 | def cb_find_after_remove(results): | ||
1146 | 344 | return results.all().addCallback(cb_all) | ||
1147 | 345 | def cb_all(results): | ||
1148 | 346 | self.assertEquals(len(results), 2) | ||
1149 | 347 | return self.store.find(Egg, Egg.value >= 5).addCallback(cb_find) | ||
1150 | 348 | |||
1151 | 349 | |||
1152 | 350 | def test_find_limit(self): | ||
1153 | 351 | """ | ||
1154 | 352 | Put a limit on the number of results of a find. | ||
1155 | 353 | """ | ||
1156 | 354 | def cb_find(results): | ||
1157 | 355 | results.config(limit=3) | ||
1158 | 356 | return results.all().addCallback(cb_all) | ||
1159 | 357 | def cb_all(results): | ||
1160 | 358 | self.assertEquals(len(results), 3) | ||
1161 | 359 | return self.store.find(Egg).addCallback(cb_find) | ||
1162 | 360 | |||
1163 | 361 | |||
1164 | 362 | def test_find_union(self): | ||
1165 | 363 | """ | ||
1166 | 364 | Call C{union} on 2 differents C{DeferredResultSet}. | ||
1167 | 365 | """ | ||
1168 | 366 | def cb_find(results): | ||
1169 | 367 | result1, result2 = results | ||
1170 | 368 | results = result1.union(result2._resultSet) | ||
1171 | 369 | return results.all().addCallback(cb_all) | ||
1172 | 370 | def cb_all(results): | ||
1173 | 371 | self.assertEquals(len(results), 3) | ||
1174 | 372 | d1 = self.store.find(Egg, Egg.value >= 5) | ||
1175 | 373 | d2 = self.store.find(Egg, Egg.value == 3) | ||
1176 | 374 | return gatherResults([d1, d2]).addCallback(cb_find) | ||
1177 | 375 | |||
1178 | 376 | |||
1179 | 377 | def test_find_difference(self): | ||
1180 | 378 | """ | ||
1181 | 379 | Call C{union} on 2 differents C{DeferredResultSet}. | ||
1182 | 380 | """ | ||
1183 | 381 | if self.__class__.__name__.startswith("MySQL"): | ||
1184 | 382 | return | ||
1185 | 383 | def cb_find(results): | ||
1186 | 384 | result1, result2 = results | ||
1187 | 385 | results = result1.difference(result2._resultSet) | ||
1188 | 386 | return results.all().addCallback(cb_all) | ||
1189 | 387 | def cb_all(results): | ||
1190 | 388 | self.assertEquals(len(results), 1) | ||
1191 | 389 | self.assertEquals(results[0].value, 5) | ||
1192 | 390 | d1 = self.store.find(Egg, Egg.value >= 5) | ||
1193 | 391 | d2 = self.store.find(Egg, Egg.value == 7) | ||
1194 | 392 | return gatherResults([d1, d2]).addCallback(cb_find) | ||
1195 | 393 | |||
1196 | 394 | |||
1197 | 395 | def test_find_intersection(self): | ||
1198 | 396 | """ | ||
1199 | 397 | Call C{intersection} on 2 differents C{DeferredResultSet}. | ||
1200 | 398 | """ | ||
1201 | 399 | if self.__class__.__name__.startswith("MySQL"): | ||
1202 | 400 | return | ||
1203 | 401 | def cb_find(results): | ||
1204 | 402 | result1, result2 = results | ||
1205 | 403 | results = result1.intersection(result2._resultSet) | ||
1206 | 404 | return results.all().addCallback(cb_all) | ||
1207 | 405 | def cb_all(results): | ||
1208 | 406 | self.assertEquals(len(results), 1) | ||
1209 | 407 | self.assertEquals(results[0].value, 7) | ||
1210 | 408 | d1 = self.store.find(Egg, Egg.value >= 5) | ||
1211 | 409 | d2 = self.store.find(Egg, Egg.value == 7) | ||
1212 | 410 | return gatherResults([d1, d2]).addCallback(cb_find) | ||
1213 | 411 | |||
1214 | 412 | |||
1215 | 413 | def test_find_values(self): | ||
1216 | 414 | """ | ||
1217 | 415 | Filter the fields returned by a find using the values method. | ||
1218 | 416 | """ | ||
1219 | 417 | def cb_find(results): | ||
1220 | 418 | return results.values(Bar.title).addCallback(cb_all) | ||
1221 | 419 | def cb_all(titles): | ||
1222 | 420 | titles.sort() | ||
1223 | 421 | self.assertEquals(titles, [u"Title 40", u"Title 50"]) | ||
1224 | 422 | return self.store.find(Bar).addCallback(cb_find) | ||
1225 | 423 | |||
1226 | 424 | |||
1227 | 425 | def test_find_and_set(self): | ||
1228 | 426 | """ | ||
1229 | 427 | The C{set} method of a C{ResultSet} should update the specified fields | ||
1230 | 428 | in a thread. | ||
1231 | 429 | """ | ||
1232 | 430 | def cb_find(results): | ||
1233 | 431 | return results.set(title=u"Title").addCallback(cb_set, results) | ||
1234 | 432 | def cb_set(ignore, results): | ||
1235 | 433 | return results.values(Bar.title).addCallback(cb_all) | ||
1236 | 434 | def cb_all(titles): | ||
1237 | 435 | titles.sort() | ||
1238 | 436 | self.assertEquals(titles, [u"Title", u"Title"]) | ||
1239 | 437 | return self.store.find(Bar).addCallback(cb_find) | ||
1240 | 438 | |||
1241 | 439 | |||
1242 | 440 | def test_find_offset(self): | ||
1243 | 441 | """ | ||
1244 | 442 | Put an offset on the number of results of a find. | ||
1245 | 443 | """ | ||
1246 | 444 | def cb_find(results): | ||
1247 | 445 | results.config(offset=2) | ||
1248 | 446 | return results.all().addCallback(cb_all) | ||
1249 | 447 | def cb_all(results): | ||
1250 | 448 | self.assertEquals(len(results), 2) | ||
1251 | 449 | return self.store.find(Egg).addCallback(cb_find) | ||
1252 | 450 | |||
1253 | 451 | |||
1254 | 452 | def test_find_offset_limit(self): | ||
1255 | 453 | """ | ||
1256 | 454 | Put an offset and limit in the number of results of a find. | ||
1257 | 455 | """ | ||
1258 | 456 | def cb_find(results): | ||
1259 | 457 | results.config(offset=1, limit=2) | ||
1260 | 458 | return results.all().addCallback(cb_all) | ||
1261 | 459 | def cb_all(results): | ||
1262 | 460 | self.assertEquals(len(results), 2) | ||
1263 | 461 | return self.store.find(Egg).addCallback(cb_find) | ||
1264 | 462 | |||
1265 | 463 | |||
1266 | 464 | @deferredGenerator | ||
1267 | 465 | def test_find_defgen(self): | ||
1268 | 466 | """ | ||
1269 | 467 | Do a find, add an object, then do another find: this to ensure that the | ||
1270 | 468 | connection remains in the dedicated thread. | ||
1271 | 469 | """ | ||
1272 | 470 | d = self.store.find(Bar) | ||
1273 | 471 | wfd = waitForDeferred(d) | ||
1274 | 472 | yield wfd | ||
1275 | 473 | results = wfd.getResult() | ||
1276 | 474 | d = results.all() | ||
1277 | 475 | wfd = waitForDeferred(d) | ||
1278 | 476 | yield wfd | ||
1279 | 477 | wfd.getResult() | ||
1280 | 478 | foo = Foo() | ||
1281 | 479 | foo.title = u"Great title" | ||
1282 | 480 | foo.id = 11 | ||
1283 | 481 | d = self.store.add(foo) | ||
1284 | 482 | wfd = waitForDeferred(d) | ||
1285 | 483 | yield wfd | ||
1286 | 484 | wfd.getResult() | ||
1287 | 485 | d = self.store.find(Foo) | ||
1288 | 486 | wfd = waitForDeferred(d) | ||
1289 | 487 | yield wfd | ||
1290 | 488 | results = wfd.getResult() | ||
1291 | 489 | |||
1292 | 490 | |||
1293 | 491 | def test_find_order_by(self): | ||
1294 | 492 | """ | ||
1295 | 493 | Try to find a list of objects using the store, then order the result | ||
1296 | 494 | set. | ||
1297 | 495 | """ | ||
1298 | 496 | def cb_find(results): | ||
1299 | 497 | results.order_by(Bar.title) | ||
1300 | 498 | return results.all().addCallback(cb_all) | ||
1301 | 499 | def cb_all(results): | ||
1302 | 500 | self.assertEquals(len(results), 2) | ||
1303 | 501 | titles = [results[0].title, results[1].title] | ||
1304 | 502 | self.assertEquals(titles, [u"Title 40", u"Title 50"]) | ||
1305 | 503 | return self.store.find(Bar).addCallback(cb_find) | ||
1306 | 504 | |||
1307 | 505 | |||
1308 | 506 | def test_find_and_rollback(self): | ||
1309 | 507 | """ | ||
1310 | 508 | Accessing an object outside of a transaction fails because the object | ||
1311 | 509 | hasn't been resolved yet. | ||
1312 | 510 | """ | ||
1313 | 511 | def cb_find(results): | ||
1314 | 512 | results.order_by(Bar.title) | ||
1315 | 513 | return results.all().addCallback(cb_all) | ||
1316 | 514 | def cb_all(results): | ||
1317 | 515 | return self.store.rollback().addCallback(cbRollback, results) | ||
1318 | 516 | def cbRollback(ign, results): | ||
1319 | 517 | self.assertEquals(len(results), 2) | ||
1320 | 518 | self.assertRaises(RuntimeError, getattr, results[0], "title") | ||
1321 | 519 | return self.store.find(Bar).addCallback(cb_find) | ||
1322 | 520 | |||
1323 | 521 | |||
1324 | 522 | def test_find_is_empty(self): | ||
1325 | 523 | """ | ||
1326 | 524 | DeferredReference.is_empty returns a Deferred that fires with True or | ||
1327 | 525 | False depending if the matched result set is empty or not. | ||
1328 | 526 | """ | ||
1329 | 527 | def cb_find(results): | ||
1330 | 528 | return results.is_empty().addCallback(self.assertEquals, False) | ||
1331 | 529 | |||
1332 | 530 | return self.store.find(Bar).addCallback(cb_find) | ||
1333 | 531 | |||
1334 | 532 | |||
1335 | 533 | def test_find_group_by(self): | ||
1336 | 534 | """ | ||
1337 | 535 | DeferredReference.group_by is a simple wrapper to the group_by method | ||
1338 | 536 | of the reference set. | ||
1339 | 537 | """ | ||
1340 | 538 | def cb_find(results): | ||
1341 | 539 | results.group_by(Bar.foo_id) | ||
1342 | 540 | return results.all().addCallback(check) | ||
1343 | 541 | |||
1344 | 542 | def check(result): | ||
1345 | 543 | self.assertEquals(result, [(2, 10)]) | ||
1346 | 544 | |||
1347 | 545 | return self.store.find((Count(Bar.id), Bar.foo_id) | ||
1348 | 546 | ).addCallback(cb_find) | ||
1349 | 547 | |||
1350 | 548 | |||
1351 | 549 | def test_find_having(self): | ||
1352 | 550 | """ | ||
1353 | 551 | DeferredReference.having is a simple wrapper to the having method of | ||
1354 | 552 | the reference set. | ||
1355 | 553 | """ | ||
1356 | 554 | connection = self.database.connect() | ||
1357 | 555 | connection.execute("INSERT INTO egg VALUES (5, 7)") | ||
1358 | 556 | connection.commit() | ||
1359 | 557 | |||
1360 | 558 | def cb_find(results): | ||
1361 | 559 | results.group_by(Egg.value) | ||
1362 | 560 | results.having(Egg.value >= 5) | ||
1363 | 561 | results.order_by(Egg.value) | ||
1364 | 562 | return results.all().addCallback(check) | ||
1365 | 563 | |||
1366 | 564 | def check(result): | ||
1367 | 565 | self.assertEquals(result, [(1, 5), (2, 7)]) | ||
1368 | 566 | |||
1369 | 567 | return self.store.find((Count(Egg.id), Egg.value) | ||
1370 | 568 | ).addCallback(cb_find) | ||
1371 | 569 | |||
1372 | 570 | |||
1373 | 571 | def test_reference(self): | ||
1374 | 572 | """ | ||
1375 | 573 | Trying to get a reference of an object using C{DeferredReference}. | ||
1376 | 574 | """ | ||
1377 | 575 | def cb_getBar(result): | ||
1378 | 576 | return result.foo.addCallback(cb_getFoo) | ||
1379 | 577 | def cb_getFoo(fooResult): | ||
1380 | 578 | return self.store.get(Foo, 10).addCallback(cb_getFooBar, fooResult) | ||
1381 | 579 | def cb_getFooBar(result, fooResult): | ||
1382 | 580 | self.assertIdentical(fooResult, result) | ||
1383 | 581 | # The result should be valid too | ||
1384 | 582 | self.assertEquals(fooResult.title, u"Title 30") | ||
1385 | 583 | return self.store.get(Bar, 10).addCallback(cb_getBar) | ||
1386 | 584 | |||
1387 | 585 | |||
1388 | 586 | def test_reference_setting(self): | ||
1389 | 587 | """ | ||
1390 | 588 | Try to set a reference of an object. | ||
1391 | 589 | """ | ||
1392 | 590 | connection = self.database.connect() | ||
1393 | 591 | connection.execute("INSERT INTO foo VALUES (20, 'Title 20')") | ||
1394 | 592 | connection.commit() | ||
1395 | 593 | def cb_getBar(result): | ||
1396 | 594 | return self.store.get(Foo, 20).addCallback(cb_getFooBar, result) | ||
1397 | 595 | def cb_getFooBar(result, barResult): | ||
1398 | 596 | self.assertRaises(RuntimeError, setattr, barResult, "foo", result) | ||
1399 | 597 | return self.store.get(Bar, 10).addCallback(cb_getBar) | ||
1400 | 598 | |||
1401 | 599 | |||
1402 | 600 | def test_reference_set_unordered(self): | ||
1403 | 601 | """ | ||
1404 | 602 | Get a reference set and call various wrapped methods on it. | ||
1405 | 603 | """ | ||
1406 | 604 | # find test | ||
1407 | 605 | def cb_find(results): | ||
1408 | 606 | return results.all().addCallback(cb_all) | ||
1409 | 607 | |||
1410 | 608 | def cb_all(results): | ||
1411 | 609 | self.assertEquals(len(results), 2) | ||
1412 | 610 | titles = [results[0].title, results[1].title] | ||
1413 | 611 | titles.sort() | ||
1414 | 612 | self.assertEquals(titles, [u"Title 40", u"Title 50"]) | ||
1415 | 613 | |||
1416 | 614 | def cb_any(result): | ||
1417 | 615 | self.assertTrue(result) | ||
1418 | 616 | |||
1419 | 617 | def cb_values(titles): | ||
1420 | 618 | titles.sort() | ||
1421 | 619 | self.assertEquals(titles, [u"Title 40", u"Title 50"]) | ||
1422 | 620 | |||
1423 | 621 | def do_tests(results): | ||
1424 | 622 | results = results.bars | ||
1425 | 623 | dfrs = [ | ||
1426 | 624 | results.find().addCallback(cb_find), | ||
1427 | 625 | results.any().addCallback(cb_any), | ||
1428 | 626 | results.values(Bar.title).addCallback(cb_values), | ||
1429 | 627 | ] | ||
1430 | 628 | return DeferredList(dfrs) | ||
1431 | 629 | |||
1432 | 630 | return self.store.get(FooRefSet, 10).addCallback(do_tests) | ||
1433 | 631 | |||
1434 | 632 | |||
1435 | 633 | def test_reference_set_ordered(self): | ||
1436 | 634 | """ | ||
1437 | 635 | A DeferredReferenceSet has a order_by method which returns a Deferred | ||
1438 | 636 | firing when the reference set is ordered. | ||
1439 | 637 | """ | ||
1440 | 638 | def do_tests(result): | ||
1441 | 639 | dfrs = [ | ||
1442 | 640 | result.first().addCallback(lambda t: | ||
1443 | 641 | self.assertEquals(t.title, u"Title 40")), | ||
1444 | 642 | result.last().addCallback(lambda t: | ||
1445 | 643 | self.assertEquals(t.title, u"Title 50")), | ||
1446 | 644 | ] | ||
1447 | 645 | return DeferredList(dfrs) | ||
1448 | 646 | |||
1449 | 647 | def order(results): | ||
1450 | 648 | dfr = results.bars.order_by("title") | ||
1451 | 649 | return dfr.addCallback(do_tests) | ||
1452 | 650 | |||
1453 | 651 | return self.store.get(FooRefSet, 10).addCallback(order) | ||
1454 | 652 | |||
1455 | 653 | |||
1456 | 654 | def test_reference_set_add_remove(self): | ||
1457 | 655 | """ | ||
1458 | 656 | A DeferredReferenceSet has a add method with returns a Deferred once | ||
1459 | 657 | the object has been added. | ||
1460 | 658 | Try to add things from the reference set async. | ||
1461 | 659 | """ | ||
1462 | 660 | def add_one(result): | ||
1463 | 661 | bar = Bar() | ||
1464 | 662 | bar.title = u"Yeah" | ||
1465 | 663 | return result.bars.add(bar).addCallback(remove_one, result, bar) | ||
1466 | 664 | |||
1467 | 665 | def remove_one(add_result, result, bar): | ||
1468 | 666 | return result.bars.remove(bar).addCallback(get_all, result) | ||
1469 | 667 | |||
1470 | 668 | def get_all(ignore, result): | ||
1471 | 669 | return result.bars.all().addCallback(check) | ||
1472 | 670 | |||
1473 | 671 | def check(result): | ||
1474 | 672 | self.assertEquals(len(result), 2) | ||
1475 | 673 | |||
1476 | 674 | return self.store.get(FooRefSet, 10).addCallback(add_one) | ||
1477 | 675 | |||
1478 | 676 | |||
1479 | 677 | def test_reference_set_clear(self): | ||
1480 | 678 | """ | ||
1481 | 679 | A DeferredReferenceSet has a clear method which removes all elements | ||
1482 | 680 | from the reference set and fires the returned Deferred when done. | ||
1483 | 681 | """ | ||
1484 | 682 | def first_cb(result): | ||
1485 | 683 | refs = result.bars | ||
1486 | 684 | return check_count(refs, 2).addCallback(clear_cb, refs) | ||
1487 | 685 | |||
1488 | 686 | def check_count(ref_set, num): | ||
1489 | 687 | return ref_set.count().addCallback(self.assertEquals, num) | ||
1490 | 688 | |||
1491 | 689 | def clear_cb(result, refs): | ||
1492 | 690 | return refs.clear().addCallback(lambda x: check_count(refs, 0)) | ||
1493 | 691 | |||
1494 | 692 | return self.store.get(FooRefSet, 10).addCallback(first_cb) | ||
1495 | 693 | |||
1496 | 694 | |||
1497 | 695 | def test_reference_set_one(self): | ||
1498 | 696 | """ | ||
1499 | 697 | Call C{one} on a C{DeferredBoundReference}. | ||
1500 | 698 | """ | ||
1501 | 699 | connection = self.database.connect() | ||
1502 | 700 | connection.execute("INSERT INTO foo VALUES (11, 'Title 40')") | ||
1503 | 701 | connection.execute("INSERT INTO bar VALUES (20, 11, 'Title 50')") | ||
1504 | 702 | connection.commit() | ||
1505 | 703 | def cb_get(result): | ||
1506 | 704 | return result.bars.one().addCallback(cb_one) | ||
1507 | 705 | def cb_one(result): | ||
1508 | 706 | return self.store.get(Bar, 20).addCallback(check, result) | ||
1509 | 707 | def check(result, expected): | ||
1510 | 708 | self.assertIdentical(result, expected) | ||
1511 | 709 | return self.store.get(FooRefSet, 11).addCallback(cb_get) | ||
1512 | 710 | |||
1513 | 711 | |||
1514 | 712 | def test_reference_set_first(self): | ||
1515 | 713 | """ | ||
1516 | 714 | Call C{first} on an ordered C{DeferredBoundReference}. | ||
1517 | 715 | """ | ||
1518 | 716 | def cb_get(result): | ||
1519 | 717 | return result.bars.first().addCallback(cb_one) | ||
1520 | 718 | def cb_one(result): | ||
1521 | 719 | return self.store.get(Bar, 10).addCallback(check, result) | ||
1522 | 720 | def check(result, expected): | ||
1523 | 721 | self.assertIdentical(result, expected) | ||
1524 | 722 | return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get) | ||
1525 | 723 | |||
1526 | 724 | |||
1527 | 725 | def test_reference_set_last(self): | ||
1528 | 726 | """ | ||
1529 | 727 | Call C{last} on an ordered C{DeferredBoundReference}. | ||
1530 | 728 | """ | ||
1531 | 729 | def cb_get(result): | ||
1532 | 730 | return result.bars.last().addCallback(cb_one) | ||
1533 | 731 | def cb_one(result): | ||
1534 | 732 | return self.store.get(Bar, 11).addCallback(check, result) | ||
1535 | 733 | def check(result, expected): | ||
1536 | 734 | self.assertIdentical(result, expected) | ||
1537 | 735 | return self.store.get(FooRefSetOrderID, 10).addCallback(cb_get) | ||
1538 | 736 | |||
1539 | 737 | |||
1540 | 738 | def test_commit(self): | ||
1541 | 739 | """ | ||
1542 | 740 | Make some changes and commit them. | ||
1543 | 741 | """ | ||
1544 | 742 | def cb_get(result): | ||
1545 | 743 | return self.store.remove(result).addCallback(cb_remove) | ||
1546 | 744 | def cb_remove(ign): | ||
1547 | 745 | return self.store.commit().addCallback(cb_commit) | ||
1548 | 746 | def cb_commit(ign): | ||
1549 | 747 | # To be sure the data is no more in the db, the best is to | ||
1550 | 748 | # directly connect to the db | ||
1551 | 749 | connection = self.database.connect() | ||
1552 | 750 | result = connection.execute("SELECT * FROM foo") | ||
1553 | 751 | self.assertEquals(list(result), []) | ||
1554 | 752 | return self.store.get(Foo, 10).addCallback(cb_get) | ||
1555 | 753 | |||
1556 | 754 | |||
1557 | 755 | def test_rollback(self): | ||
1558 | 756 | """ | ||
1559 | 757 | Make and some changes and rollback them. | ||
1560 | 758 | """ | ||
1561 | 759 | def cb_get(result): | ||
1562 | 760 | return self.store.remove(result).addCallback(cb_remove) | ||
1563 | 761 | def cb_remove(ign): | ||
1564 | 762 | return self.store.rollback().addCallback(cbRollback) | ||
1565 | 763 | def cbRollback(ign): | ||
1566 | 764 | connection = self.database.connect() | ||
1567 | 765 | result = connection.execute("SELECT * FROM foo") | ||
1568 | 766 | self.assertEquals(list(result), [(10, u"Title 30")]) | ||
1569 | 767 | return self.store.get(Foo, 10).addCallback(cb_get) | ||
1570 | 768 | |||
1571 | 769 | |||
1572 | 770 | def test_deferred_reference_multithread(self): | ||
1573 | 771 | """ | ||
1574 | 772 | If a store is restarted, the objects in the cache should still be | ||
1575 | 773 | usable, in particular an object shouldn't not store a reference to the | ||
1576 | 774 | store thread, as it can change. | ||
1577 | 775 | """ | ||
1578 | 776 | def test(ignore): | ||
1579 | 777 | # get a bar and retrieve the deferred reference | ||
1580 | 778 | def get_foo(bar): | ||
1581 | 779 | return bar.foo | ||
1582 | 780 | |||
1583 | 781 | return self.store.find(Bar, Bar.id == 10).addCallback(lambda x: | ||
1584 | 782 | x.one()).addCallback(get_foo) | ||
1585 | 783 | |||
1586 | 784 | def _restart_store(res): | ||
1587 | 785 | def stopped(res): | ||
1588 | 786 | self.store = DeferredStore(self.database) | ||
1589 | 787 | return self.store.start() | ||
1590 | 788 | return self.store.stop().addCallback(stopped) | ||
1591 | 789 | |||
1592 | 790 | def check(foo): | ||
1593 | 791 | self.assertEquals(foo.id, 10) | ||
1594 | 792 | self.assertEquals(foo.title, u"Title 30") | ||
1595 | 793 | |||
1596 | 794 | return test(None | ||
1597 | 795 | ).addCallback(_restart_store | ||
1598 | 796 | ).addCallback(test | ||
1599 | 797 | ).addCallback(check) | ||
1600 | 798 | |||
1601 | 799 | |||
1602 | 800 | def test_of(self): | ||
1603 | 801 | """ | ||
1604 | 802 | The DeferredStore associated with an object is returned by the static | ||
1605 | 803 | C{of} method. | ||
1606 | 804 | """ | ||
1607 | 805 | def cb_get(result): | ||
1608 | 806 | store = DeferredStore.of(result) | ||
1609 | 807 | self.assertIdentical(self.store, store) | ||
1610 | 808 | |||
1611 | 809 | return self.store.get(Foo, 10).addCallback(cb_get) | ||
1612 | 810 | |||
1613 | 811 | |||
1614 | 812 | def test_thread_check(self): | ||
1615 | 813 | """ | ||
1616 | 814 | A L{ThreadSafetyError} is raised when attempting to do an unsafe | ||
1617 | 815 | operation, like accessing a C{Reference} attribute via a | ||
1618 | 816 | C{DeferredStore}. | ||
1619 | 817 | """ | ||
1620 | 818 | class WrongBar(Bar): | ||
1621 | 819 | foo_sync = Reference(Bar.foo_id, Foo.id) | ||
1622 | 820 | d = self.store.find(WrongBar).addCallback(lambda result: result.all()) | ||
1623 | 821 | |||
1624 | 822 | def check(results): | ||
1625 | 823 | self.assertRaises( | ||
1626 | 824 | ThreadSafetyError, getattr, results[0], "foo_sync") | ||
1627 | 825 | return d.addCallback(check) | ||
1628 | 826 | |||
1629 | 827 | |||
1630 | 828 | |||
class StoreThreadTestCase(TestCase):
    """
    Tests for L{StoreThread}.
    """

    def setUp(self):
        """
        Create an instance of C{StoreThread} and start it.
        """
        self.thread = StoreThread()
        self.thread.start()


    def tearDown(self):
        """
        Kill the running thread.
        """
        # NOTE(review): whatever stop() returns is not waited on here.
        # test_defer_after_stop already stops the thread, so confirm that a
        # second stop() succeeds before returning its Deferred to trial.
        self.thread.stop()


    def test_defer_after_stop(self):
        """
        Deferring calls after the thread is stopped fails with
        C{AlreadyStopped}.
        """
        def cb_stop(ignored):
            # Return the assertFailure Deferred so trial waits on it and
            # reports a missing AlreadyStopped failure. Discarding it could
            # let the test pass unnoticed. The callable takes no arguments,
            # matching how defer_to_thread invokes functions elsewhere.
            return self.assertFailure(
                self.thread.defer_to_thread(lambda: None), AlreadyStopped)
        return self.thread.stop().addCallback(cb_stop)


    def test_callback(self):
        """
        Fire a simple function in a thread and check its result.
        """
        def testfunc():
            return 1
        return self.thread.defer_to_thread(testfunc
            ).addCallback(self.assertEquals, 1)


    def test_errback(self):
        """
        Raising an exception in a thread returns a failure.
        """
        def testfunc():
            raise RuntimeError("Error!")
        return self.assertFailure(self.thread.defer_to_thread(testfunc),
                                  RuntimeError)
1680 | 878 | |||
1681 | 879 | |||
class StorePoolTest(object):
    """
    Tests for L{StorePool}.

    This is a mixin: concrete subclasses also inherit from trial's
    C{TestCase} and provide C{create_database} and C{create_tables}.
    """

    def setUp(self):
        """
        Build a database with data, and create and start a pool.
        """
        self.create_database()
        connection = self.connection = self.database.connect()
        self.drop_tables()
        self.create_tables()
        connection.execute("INSERT INTO foo VALUES (10, 'Title 30')")
        connection.execute("INSERT INTO bar VALUES (10, 10, 'Title 40')")
        connection.execute("INSERT INTO bar VALUES (11, 10, 'Title 50')")
        connection.commit()
        # Presumably min=2 / max=5 stores; test_no_overflow relies on the
        # maximum being 5.
        self.pool = StorePool(self.database, 2, 5)
        return self.pool.start()


    def tearDown(self):
        """
        Stop the pool, drop the tables and close the setup connection.
        """
        def _drop(ign):
            self.drop_tables()
            self.connection.close()
        return self.pool.stop().addCallback(_drop)


    def drop_tables(self):
        """
        Drop the test tables, ignoring errors (e.g. a table that does not
        exist yet) and rolling back after each failure.
        """
        for table in ["foo", "bar"]:
            try:
                self.connection.execute("DROP TABLE %s" % table)
                self.connection.commit()
            except Exception:
                # Narrower than a bare except, so KeyboardInterrupt and
                # SystemExit still propagate.
                self.connection.rollback()


    def test_already_started(self):
        """
        Check that the pool can't be restarted multiple times.
        """
        self.assertRaises(RuntimeError, self.pool.start)


    def test_get(self):
        """
        get should return different stores if available.
        """
        def cb_get1(store1):
            return self.pool.get().addCallback(cb_get2, store1)
        def cb_get2(store2, store1):
            self.assertNotIdentical(store1, store2)
            self.assertTrue(store1.started)
            self.assertTrue(store2.started)
        return self.pool.get().addCallback(cb_get1)


    def test_get_not_started(self):
        """
        If no store is available, the pool should create a store.
        """
        def cb(ign):
            self.assertEquals(self.pool._stores_created, 0)
            return self.pool.adjust_size(0, 5).addCallback(cb_add)
        def cb_add(ign):
            return self.pool.get().addCallback(cb_get)
        def cb_get(store):
            self.assertTrue(store.started)
        return self.pool.adjust_size(0, 0).addCallback(cb)


    def test_waiting_for_store(self):
        """
        Test waiting for a store availability.
        """
        def cb(ign):
            return self.pool.get().addCallback(cb_get1)
        def cb_get1(store1):
            self.assertTrue(store1.started)
            # Now we have a store, no store should be returned by the pool
            # until we put it back
            d1 = self.pool.get().addCallback(cb_get2, store1)
            d2 = self.pool.put(store1)
            return gatherResults([d1, d2])
        def cb_get2(store2, store1):
            self.assertIdentical(store1, store2)
        return self.pool.adjust_size(1, 1).addCallback(cb)


    def test_adjust_size_minmax(self):
        """
        Test sanity check on min/max - i.e. min <= max.
        """
        return self.assertFailure(self.pool.adjust_size(2, 1), ValueError)


    def test_adjust_size_nonnegative(self):
        """
        Test sanity check for nonnegative min.
        """
        return self.assertFailure(self.pool.adjust_size(-1), ValueError)


    @deferredGenerator
    def test_concurrent_data(self):
        """
        Test that different stores have different states: if the first store
        hasn't yet committed, the second one shouldn't get the new data.
        """
        foo = Foo()
        foo.title = u"Great title"
        foo.id = 11
        d = self.pool.get()
        wfd = waitForDeferred(d)
        yield wfd
        store1 = wfd.getResult()

        d = self.pool.get()
        wfd = waitForDeferred(d)
        yield wfd
        store2 = wfd.getResult()

        d = store1.add(foo)
        wfd = waitForDeferred(d)
        yield wfd
        wfd.getResult()

        d = store1.get(Foo, 11)
        wfd = waitForDeferred(d)
        yield wfd
        foo2 = wfd.getResult()

        # The object is already in the store cache
        self.assertIdentical(foo2, foo)

        d = store2.get(Foo, 11)
        wfd = waitForDeferred(d)
        yield wfd
        foo3 = wfd.getResult()

        # The object isn't in the db yet
        self.assertIdentical(foo3, None)

        # Let's rollback, because even select open a transaction
        d = store2.rollback()
        wfd = waitForDeferred(d)
        yield wfd
        wfd.getResult()

        # Let's commit
        d = store1.commit()
        wfd = waitForDeferred(d)
        yield wfd
        wfd.getResult()

        d = store2.get(Foo, 11)
        wfd = waitForDeferred(d)
        yield wfd
        foo4 = wfd.getResult()

        # The objects must be different
        self.assertNotIdentical(foo4, foo)
        # But the content must be the same
        self.assertEquals(foo4.title, u"Great title")


    def test_no_overflow(self):
        """
        Test that pool does not allocate more connections than store_max.
        """
        ds = []
        stores = set()

        def cb_get(store):
            stores.add(store)
            return self.pool.put(store)

        for i in range(10):
            ds.append(self.pool.get().addCallback(cb_get))

        def checkInstances(result):
            self.assertEquals(len(stores), 5)

        return gatherResults(ds).addCallback(checkInstances)


    def test_start_failure(self):
        """
        If a store failed to start, the number of allocated connections doesn't
        grow, so we're later able to start more stores.
        """
        ds = []
        stores = set()

        class DontStartStore(DeferredStore):
            def start(self):
                return fail(RuntimeError("oops"))

        calls = []

        def vicious_store_factory(database):
            # Only the very first store created fails to start.
            if not calls:
                store = DontStartStore(database)
            else:
                store = DeferredStore(database)
            calls.append(None)
            return store

        self.pool.store_factory = vicious_store_factory

        def cb_get(store):
            stores.add(store)
            return self.pool.put(store)

        errors = []

        def save_errors(failure):
            errors.append(failure)

        for i in range(6):
            ds.append(
                self.pool.get().addCallback(cb_get).addErrback(save_errors))

        def checkInstances(result):
            self.assertEquals(len(stores), 5)
            self.assertEquals(len(errors), 1)
            errors[0].trap(RuntimeError)

        return gatherResults(ds).addCallback(checkInstances)


    def test_arguments(self):
        """
        Arguments are passed along to transacted method when store
        is available.
        """
        def tx(store, a, b=None):
            self.assertIsInstance(store, DeferredStore)
            self.assertEquals(1, a)
            self.assertEquals(2, b)
            return succeed("ok")
        return self.pool.transact(tx, 1, b=2)


    def test_commit(self):
        """
        Changes made inside a successful transaction are committed.
        """
        @deferredGenerator
        def check(result):
            d = self.pool.get()
            wfd = waitForDeferred(d)
            yield wfd
            store = wfd.getResult()
            d = store.execute("SELECT * FROM foo ORDER BY id", [])
            wfd = waitForDeferred(d)
            yield wfd
            dr = wfd.getResult()
            d = dr.get_all()
            wfd = waitForDeferred(d)
            yield wfd
            results = wfd.getResult()
            self.assertEquals(2, len(results))
            self.assertEquals(1, results[0][0])
            self.assertEquals("test", results[0][1])

        def tx(store):
            d = store.execute("INSERT INTO foo(id, title) "
                              "VALUES (1, 'test')", noresult=True)
            return d

        return self.pool.transact(tx).addCallback(check)


    def test_rollback(self):
        """
        Changes made inside a failed transaction are rolled back.
        """
        @deferredGenerator
        def check(reason):
            d = self.pool.get()
            wfd = waitForDeferred(d)
            yield wfd
            store = wfd.getResult()
            d = store.execute("SELECT * FROM foo", [])
            wfd = waitForDeferred(d)
            yield wfd
            dr = wfd.getResult()
            d = dr.get_all()
            wfd = waitForDeferred(d)
            yield wfd
            results = wfd.getResult()
            self.assertEquals(1, len(results))

        @deferredGenerator
        def tx(store):
            # Inserting the same primary key twice makes the transaction
            # fail on the second statement.
            d = store.execute("INSERT INTO foo(id, title) "
                              "VALUES (1, 'test')", [])
            wfd = waitForDeferred(d)
            yield wfd
            wfd.getResult()
            d = store.execute("INSERT INTO foo(id, title) "
                              "VALUES (1, 'test')", [])
            wfd = waitForDeferred(d)
            yield wfd
            wfd.getResult()

        return self.pool.transact(tx).addBoth(check)


    def test_return_value(self):
        """
        Final return value should match return value from successful call to
        transacted function.
        """
        def check(result):
            self.assertEquals("completed", result)

        def cb(result):
            return "completed"

        def tx(store):
            return store.execute("INSERT INTO foo(id, title) "
                                 "VALUES (1, 'test')", []
                                 ).addCallback(cb)

        return self.pool.transact(tx).addCallback(check)


    def test_poolsize_after_success(self):
        """
        After successful transaction, pool size should be same size as before.
        """
        size = len(self.pool._stores)

        def check(result):
            self.assertEquals(size, len(self.pool._stores))

        def tx(store):
            d = store.execute("SELECT * from foo", [])
            return d.addCallback(lambda result: result.get_all())

        return self.pool.transact(tx).addCallback(check)


    def test_poolsize_after_failure(self):
        """
        After failed transaction, pool size is restored to the initial value.
        """
        # Querying a missing table raises OperationalError with SQLite but
        # ProgrammingError with PostgreSQL and MySQL, so accept either.
        from storm.exceptions import ProgrammingError

        size = len(self.pool._stores)

        def check(reason):
            self.assertEquals(size, len(self.pool._stores))

        def tx(store):
            return store.execute("SELECT * from not_a_table", [])

        d = self.assertFailure(self.pool.transact(tx), OperationalError,
                               ProgrammingError)
        return d.addCallback(check)


    def test_failure_propagation(self):
        """
        A custom exception is propagated by a C{transact} call.
        """
        class MyException(Exception):
            pass

        def tx(store):
            raise MyException("Bad things happened")

        return self.assertFailure(self.pool.transact(tx), MyException)


    def test_non_deferred_function(self):
        """
        C{transact} can handle functions that don't return a C{Deferred}.
        """
        def tx(store):
            return "foo"
        return self.pool.transact(tx).addCallback(self.assertEquals, "foo")


    def test_rollback_failure(self):
        """
        If C{rollback} fails, the store is put back into the pool.
        """

        def cb_get(store):
            store.rollback = lambda: fail(RuntimeError("oops"))
            d = self.assertFailure(self.pool.put(store), RuntimeError)
            return d.addCallback(get_all)

        def get_all(ignore):
            dl = []
            for i in range(5):
                dl.append(self.pool.get())
            return gatherResults(dl).addCallback(check)

        def check(stores):
            # The fact that we're here shows that the test succeeds, because
            # we didn't hang waiting for a store
            self.assertEquals(len(stores), 5)

        return self.pool.get().addCallback(cb_get)
1974 | 1172 | yield wfd | ||
1975 | 1173 | results = wfd.getResult() | ||
1976 | 1174 | self.assertEquals(1, len(results)) | ||
1977 | 1175 | |||
1978 | 1176 | @deferredGenerator | ||
1979 | 1177 | def tx(store): | ||
1980 | 1178 | d = store.execute("INSERT INTO foo(id, title) " | ||
1981 | 1179 | "VALUES (1, 'test')", []) | ||
1982 | 1180 | wfd = waitForDeferred(d) | ||
1983 | 1181 | yield wfd | ||
1984 | 1182 | wfd.getResult() | ||
1985 | 1183 | d = store.execute("INSERT INTO foo(id, title) " | ||
1986 | 1184 | "VALUES (1, 'test')", []) | ||
1987 | 1185 | wfd = waitForDeferred(d) | ||
1988 | 1186 | yield wfd | ||
1989 | 1187 | wfd.getResult() | ||
1990 | 1188 | |||
1991 | 1189 | return self.pool.transact(tx).addBoth(check) | ||
1992 | 1190 | |||
1993 | 1191 | |||
1994 | 1192 | def test_return_value(self): | ||
1995 | 1193 | """ | ||
1996 | 1194 | Final return value should match return value from successful call to | ||
1997 | 1195 | transacted function. | ||
1998 | 1196 | """ | ||
1999 | 1197 | def check(result): | ||
2000 | 1198 | self.assertEquals("completed", result) | ||
2001 | 1199 | |||
2002 | 1200 | def cb(result): | ||
2003 | 1201 | return "completed" | ||
2004 | 1202 | |||
2005 | 1203 | def tx(store): | ||
2006 | 1204 | return store.execute("INSERT INTO foo(id, title) " | ||
2007 | 1205 | "VALUES (1, 'test')", [] | ||
2008 | 1206 | ).addCallback(cb) | ||
2009 | 1207 | |||
2010 | 1208 | return self.pool.transact(tx).addCallback(check) | ||
2011 | 1209 | |||
2012 | 1210 | |||
2013 | 1211 | def test_poolsize_after_success(self): | ||
2014 | 1212 | """ | ||
2015 | 1213 | After successful transaction, pool size should be same size as before. | ||
2016 | 1214 | """ | ||
2017 | 1215 | size = len(self.pool._stores) | ||
2018 | 1216 | |||
2019 | 1217 | def check(result): | ||
2020 | 1218 | self.assertEquals(size, len(self.pool._stores)) | ||
2021 | 1219 | |||
2022 | 1220 | def tx(store): | ||
2023 | 1221 | d = store.execute("SELECT * from foo", []) | ||
2024 | 1222 | return d.addCallback(lambda result: result.get_all()) | ||
2025 | 1223 | |||
2026 | 1224 | return self.pool.transact(tx).addCallback(check) | ||
2027 | 1225 | |||
2028 | 1226 | |||
2029 | 1227 | def test_poolsize_after_failure(self): | ||
2030 | 1228 | """ | ||
2031 | 1229 | After failed transaction, pool size is restored to the initial value. | ||
2032 | 1230 | """ | ||
2033 | 1231 | size = len(self.pool._stores) | ||
2034 | 1232 | |||
2035 | 1233 | def check(reason): | ||
2036 | 1234 | self.assertEquals(size, len(self.pool._stores)) | ||
2037 | 1235 | |||
2038 | 1236 | def tx(store): | ||
2039 | 1237 | return store.execute("SELECT * from not_a_table", []) | ||
2040 | 1238 | |||
2041 | 1239 | d = self.assertFailure(self.pool.transact(tx), OperationalError) | ||
2042 | 1240 | return d.addCallback(check) | ||
2043 | 1241 | |||
2044 | 1242 | |||
2045 | 1243 | def test_failure_propagation(self): | ||
2046 | 1244 | """ | ||
2047 | 1245 | A custom exception is propagated by a C{transact} call. | ||
2048 | 1246 | """ | ||
2049 | 1247 | class MyException(Exception): | ||
2050 | 1248 | pass | ||
2051 | 1249 | |||
2052 | 1250 | def tx(store): | ||
2053 | 1251 | raise MyException("Bad things happened") | ||
2054 | 1252 | |||
2055 | 1253 | return self.assertFailure(self.pool.transact(tx), MyException) | ||
2056 | 1254 | |||
2057 | 1255 | |||
2058 | 1256 | def test_non_deferred_function(self): | ||
2059 | 1257 | """ | ||
2060 | 1258 | C{transact} can handle functions that don't return a C{Deferred}. | ||
2061 | 1259 | """ | ||
2062 | 1260 | def tx(store): | ||
2063 | 1261 | return "foo" | ||
2064 | 1262 | return self.pool.transact(tx).addCallback(self.assertEquals, "foo") | ||
2065 | 1263 | |||
2066 | 1264 | |||
2067 | 1265 | def test_rollback_failure(self): | ||
2068 | 1266 | """ | ||
2069 | 1267 | If C{rollback} fails, the store is put back into the pool. | ||
2070 | 1268 | """ | ||
2071 | 1269 | |||
2072 | 1270 | def cb_get(store): | ||
2073 | 1271 | store.rollback = lambda: fail(RuntimeError("oops")) | ||
2074 | 1272 | d = self.assertFailure(self.pool.put(store), RuntimeError) | ||
2075 | 1273 | return d.addCallback(get_all) | ||
2076 | 1274 | |||
2077 | 1275 | def get_all(ignore): | ||
2078 | 1276 | dl = [] | ||
2079 | 1277 | for i in range(5): | ||
2080 | 1278 | dl.append(self.pool.get()) | ||
2081 | 1279 | return gatherResults(dl).addCallback(check) | ||
2082 | 1280 | |||
2083 | 1281 | def check(stores): | ||
2084 | 1282 | # The fact that we're here show that the test succeeds, because we | ||
2085 | 1283 | # didn't hang waiting for a store | ||
2086 | 1284 | self.assertEquals(len(stores), 5) | ||
2087 | 1285 | |||
2088 | 1286 | return self.pool.get().addCallback(cb_get) | ||
2089 | 0 | 1287 | ||
2090 | === added file 'tests/twisted/mysql.py' | |||
2091 | --- tests/twisted/mysql.py 1970-01-01 00:00:00 +0000 | |||
2092 | +++ tests/twisted/mysql.py 2009-12-27 11:31:15 +0000 | |||
2093 | @@ -0,0 +1,62 @@ | |||
2094 | 1 | import os | ||
2095 | 2 | |||
2096 | 3 | from storm.database import create_database | ||
2097 | 4 | |||
2098 | 5 | from tests.twisted.base import DeferredStoreTest, StorePoolTest | ||
2099 | 6 | |||
2100 | 7 | from twisted.trial.unittest import TestCase | ||
2101 | 8 | |||
2102 | 9 | |||
2103 | 10 | class MySQLDeferredStoreTest(TestCase, DeferredStoreTest): | ||
2104 | 11 | |||
2105 | 12 | def setUp(self): | ||
2106 | 13 | return DeferredStoreTest.setUp(self) | ||
2107 | 14 | |||
2108 | 15 | def tearDown(self): | ||
2109 | 16 | return DeferredStoreTest.tearDown(self) | ||
2110 | 17 | |||
2111 | 18 | def is_supported(self): | ||
2112 | 19 | return bool(os.environ.get("STORM_MYSQL_URI")) | ||
2113 | 20 | |||
2114 | 21 | def create_database(self): | ||
2115 | 22 | self.database = create_database(os.environ["STORM_MYSQL_URI"]) | ||
2116 | 23 | |||
2117 | 24 | def create_tables(self): | ||
2118 | 25 | connection = self.connection | ||
2119 | 26 | connection.execute("CREATE TABLE foo " | ||
2120 | 27 | "(id INT PRIMARY KEY AUTO_INCREMENT," | ||
2121 | 28 | " title VARCHAR(50) DEFAULT 'Default Title')" | ||
2122 | 29 | " ENGINE=InnoDB") | ||
2123 | 30 | connection.execute("CREATE TABLE bar " | ||
2124 | 31 | "(id INT PRIMARY KEY AUTO_INCREMENT," | ||
2125 | 32 | " foo_id INTEGER, title VARCHAR(50))" | ||
2126 | 33 | " ENGINE=InnoDB") | ||
2127 | 34 | connection.execute("CREATE TABLE egg " | ||
2128 | 35 | "(id INT PRIMARY KEY AUTO_INCREMENT, value INTEGER)" | ||
2129 | 36 | " ENGINE=InnoDB") | ||
2130 | 37 | |||
2131 | 38 | |||
2132 | 39 | class MySQLStorePoolTest(TestCase, StorePoolTest): | ||
2133 | 40 | |||
2134 | 41 | def setUp(self): | ||
2135 | 42 | return StorePoolTest.setUp(self) | ||
2136 | 43 | |||
2137 | 44 | def tearDown(self): | ||
2138 | 45 | return StorePoolTest.tearDown(self) | ||
2139 | 46 | |||
2140 | 47 | def is_supported(self): | ||
2141 | 48 | return bool(os.environ.get("STORM_MYSQL_URI")) | ||
2142 | 49 | |||
2143 | 50 | def create_database(self): | ||
2144 | 51 | self.database = create_database(os.environ["STORM_MYSQL_URI"]) | ||
2145 | 52 | |||
2146 | 53 | def create_tables(self): | ||
2147 | 54 | connection = self.connection | ||
2148 | 55 | connection.execute("CREATE TABLE foo " | ||
2149 | 56 | "(id INT PRIMARY KEY AUTO_INCREMENT," | ||
2150 | 57 | " title VARCHAR(50) DEFAULT 'Default Title')" | ||
2151 | 58 | " ENGINE=InnoDB") | ||
2152 | 59 | connection.execute("CREATE TABLE bar " | ||
2153 | 60 | "(id INT PRIMARY KEY AUTO_INCREMENT," | ||
2154 | 61 | " foo_id INTEGER, title VARCHAR(50))" | ||
2155 | 62 | " ENGINE=InnoDB") | ||
2156 | 0 | 63 | ||
2157 | === added file 'tests/twisted/postgres.py' | |||
2158 | --- tests/twisted/postgres.py 1970-01-01 00:00:00 +0000 | |||
2159 | +++ tests/twisted/postgres.py 2009-12-27 11:31:15 +0000 | |||
2160 | @@ -0,0 +1,57 @@ | |||
2161 | 1 | import os | ||
2162 | 2 | |||
2163 | 3 | from storm.database import create_database | ||
2164 | 4 | |||
2165 | 5 | from tests.twisted.base import DeferredStoreTest, StorePoolTest | ||
2166 | 6 | |||
2167 | 7 | from twisted.trial.unittest import TestCase | ||
2168 | 8 | |||
2169 | 9 | |||
2170 | 10 | class PostgresDeferredStoreTest(TestCase, DeferredStoreTest): | ||
2171 | 11 | |||
2172 | 12 | def setUp(self): | ||
2173 | 13 | return DeferredStoreTest.setUp(self) | ||
2174 | 14 | |||
2175 | 15 | def tearDown(self): | ||
2176 | 16 | return DeferredStoreTest.tearDown(self) | ||
2177 | 17 | |||
2178 | 18 | def is_supported(self): | ||
2179 | 19 | return bool(os.environ.get("STORM_POSTGRES_URI")) | ||
2180 | 20 | |||
2181 | 21 | def create_database(self): | ||
2182 | 22 | self.database = create_database(os.environ["STORM_POSTGRES_URI"]) | ||
2183 | 23 | |||
2184 | 24 | def create_tables(self): | ||
2185 | 25 | connection = self.connection | ||
2186 | 26 | connection.execute("CREATE TABLE foo " | ||
2187 | 27 | "(id SERIAL PRIMARY KEY," | ||
2188 | 28 | " title VARCHAR DEFAULT 'Default Title')") | ||
2189 | 29 | connection.execute("CREATE TABLE bar " | ||
2190 | 30 | "(id SERIAL PRIMARY KEY," | ||
2191 | 31 | " foo_id INTEGER, title VARCHAR)") | ||
2192 | 32 | connection.execute("CREATE TABLE egg " | ||
2193 | 33 | "(id SERIAL PRIMARY KEY, value INTEGER)") | ||
2194 | 34 | |||
2195 | 35 | |||
2196 | 36 | class PostgresStorePoolTest(TestCase, StorePoolTest): | ||
2197 | 37 | |||
2198 | 38 | def setUp(self): | ||
2199 | 39 | return StorePoolTest.setUp(self) | ||
2200 | 40 | |||
2201 | 41 | def tearDown(self): | ||
2202 | 42 | return StorePoolTest.tearDown(self) | ||
2203 | 43 | |||
2204 | 44 | def is_supported(self): | ||
2205 | 45 | return bool(os.environ.get("STORM_POSTGRES_URI")) | ||
2206 | 46 | |||
2207 | 47 | def create_database(self): | ||
2208 | 48 | self.database = create_database(os.environ["STORM_POSTGRES_URI"]) | ||
2209 | 49 | |||
2210 | 50 | def create_tables(self): | ||
2211 | 51 | connection = self.connection | ||
2212 | 52 | connection.execute("CREATE TABLE foo " | ||
2213 | 53 | "(id SERIAL PRIMARY KEY," | ||
2214 | 54 | " title VARCHAR DEFAULT 'Default Title')") | ||
2215 | 55 | connection.execute("CREATE TABLE bar " | ||
2216 | 56 | "(id SERIAL PRIMARY KEY," | ||
2217 | 57 | " foo_id INTEGER, title VARCHAR)") | ||
2218 | 0 | 58 | ||
2219 | === added file 'tests/twisted/sqlite.py' | |||
2220 | --- tests/twisted/sqlite.py 1970-01-01 00:00:00 +0000 | |||
2221 | +++ tests/twisted/sqlite.py 2009-12-27 11:31:15 +0000 | |||
2222 | @@ -0,0 +1,65 @@ | |||
2223 | 1 | from storm.databases.sqlite import SQLite | ||
2224 | 2 | from storm.uri import URI | ||
2225 | 3 | |||
2226 | 4 | from tests.twisted.base import DeferredStoreTest, StorePoolTest | ||
2227 | 5 | from tests.helper import TestHelper, MakePath | ||
2228 | 6 | |||
2229 | 7 | from twisted.trial.unittest import TestCase | ||
2230 | 8 | |||
2231 | 9 | |||
2232 | 10 | class SQLiteDeferredStoreTest(TestCase, TestHelper, DeferredStoreTest): | ||
2233 | 11 | |||
2234 | 12 | helpers = [MakePath] | ||
2235 | 13 | |||
2236 | 14 | def setUp(self): | ||
2237 | 15 | TestHelper.setUp(self) | ||
2238 | 16 | return DeferredStoreTest.setUp(self) | ||
2239 | 17 | |||
2240 | 18 | def tearDown(self): | ||
2241 | 19 | def cb(passthrough): | ||
2242 | 20 | TestHelper.tearDown(self) | ||
2243 | 21 | return passthrough | ||
2244 | 22 | return DeferredStoreTest.tearDown(self).addBoth(cb) | ||
2245 | 23 | |||
2246 | 24 | def create_database(self): | ||
2247 | 25 | self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % | ||
2248 | 26 | self.make_path())) | ||
2249 | 27 | |||
2250 | 28 | def create_tables(self): | ||
2251 | 29 | connection = self.connection | ||
2252 | 30 | connection.execute("CREATE TABLE foo " | ||
2253 | 31 | "(id INTEGER PRIMARY KEY," | ||
2254 | 32 | " title VARCHAR DEFAULT 'Default Title')") | ||
2255 | 33 | connection.execute("CREATE TABLE bar " | ||
2256 | 34 | "(id INTEGER PRIMARY KEY," | ||
2257 | 35 | " foo_id INTEGER, title VARCHAR)") | ||
2258 | 36 | connection.execute("CREATE TABLE egg " | ||
2259 | 37 | "(id INTEGER PRIMARY KEY, value INTEGER)") | ||
2260 | 38 | |||
2261 | 39 | |||
2262 | 40 | class SQLiteStorePoolTest(TestCase, TestHelper, StorePoolTest): | ||
2263 | 41 | |||
2264 | 42 | helpers = [MakePath] | ||
2265 | 43 | |||
2266 | 44 | def setUp(self): | ||
2267 | 45 | TestHelper.setUp(self) | ||
2268 | 46 | return StorePoolTest.setUp(self) | ||
2269 | 47 | |||
2270 | 48 | def tearDown(self): | ||
2271 | 49 | def cb(passthrough): | ||
2272 | 50 | TestHelper.tearDown(self) | ||
2273 | 51 | return passthrough | ||
2274 | 52 | return StorePoolTest.tearDown(self).addBoth(cb) | ||
2275 | 53 | |||
2276 | 54 | def create_database(self): | ||
2277 | 55 | self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % | ||
2278 | 56 | self.make_path())) | ||
2279 | 57 | |||
2280 | 58 | def create_tables(self): | ||
2281 | 59 | connection = self.connection | ||
2282 | 60 | connection.execute("CREATE TABLE foo " | ||
2283 | 61 | "(id INTEGER PRIMARY KEY," | ||
2284 | 62 | " title VARCHAR DEFAULT 'Default Title')") | ||
2285 | 63 | connection.execute("CREATE TABLE bar " | ||
2286 | 64 | "(id INTEGER PRIMARY KEY," | ||
2287 | 65 | " foo_id INTEGER, title VARCHAR)") |
[1] Why is setDaemon(True) used here?
[2] Can you use underscore-separated names in DeferredStore?
[3] Add documentation for get() explaining that it loads all attributes.
[4] It may be useful to extend the core Store API so the wrapper does not need to access internal attributes, for example with something like Store.disable_lazy_loading(). I'm not sure.
[5] Maybe add a global mapping from Storm stores to deferred stores?
[6] Make objects that came from a store that was later returned to the pool "broken", so that any subsequent attribute access on them fails.