1
1
from __future__ import unicode_literals
2
2
3
+ import contextlib
3
4
import msgpack
4
5
import os
5
6
import random
@@ -92,18 +93,17 @@ def receive(self, channels, block=False):
92
93
# Keep looping on the channel until
93
94
# we hit no messages or an unexpired one
94
95
while True :
95
- try :
96
- message , expires = self .message_store .pop_message (channel )
97
- message = msgpack .unpackb (message , encoding = "utf8" )
98
- if expires <= time .time ():
99
- continue
100
- # If there is a full channel name stored in the message, unpack it.
101
- if "__asgi_channel__" in message :
102
- channel = message ['__asgi_channel__' ]
103
- del message ['__asgi_channel__' ]
104
- return channel , message
105
- except ValueError :
96
+ message , expires = self .message_store .pop_message (channel )
97
+ if message is None :
106
98
break
99
+ message = msgpack .unpackb (message , encoding = "utf8" )
100
+ if expires <= time .time ():
101
+ continue
102
+ # If there is a full channel name stored in the message, unpack it.
103
+ if "__asgi_channel__" in message :
104
+ channel = message ['__asgi_channel__' ]
105
+ del message ['__asgi_channel__' ]
106
+ return channel , message
107
107
return None , None
108
108
109
109
def new_channel (self , pattern ):
@@ -191,7 +191,10 @@ def connection(self):
191
191
Caching, threadlocal connection accessor.
192
192
"""
193
193
if not hasattr (self ._locals , "connection" ):
194
- self ._locals .connection = sqlite3 .connect (self .db_path )
194
+ self ._locals .connection = sqlite3 .connect (
195
+ self .db_path ,
196
+ isolation_level = "IMMEDIATE" ,
197
+ )
195
198
self ._locals .connection .text_factory = str
196
199
return self ._locals .connection
197
200
@@ -221,23 +224,26 @@ def _reset(self):
221
224
delattr (self ._locals , "semaphore" )
222
225
223
226
def flush (self ):
224
- self .connection . cursor (). execute ( 'DELETE FROM {table_name}' . format ( table_name = self . table_name ))
225
- self .connection . commit ( )
227
+ with self .semaphore_manager ():
228
+ self ._execute ( 'DELETE FROM {table_name}' . format ( table_name = self . table_name ) )
226
229
227
230
def _execute (self , query , * args ):
231
+ cursor = self .connection .cursor ()
232
+ cursor .execute (query .format (table = self .table_name ), args )
233
+ self .connection .commit ()
234
+ return cursor .fetchall ()
235
+
236
+ @contextlib .contextmanager
237
+ def semaphore_manager (self ):
228
238
try :
229
239
self .semaphore .acquire (self .death_timeout )
230
240
except posix_ipc .BusyError :
231
241
self ._reset ()
232
242
self .semaphore .acquire (0 )
233
243
try :
234
- cursor = self .connection .cursor ()
235
- cursor .execute (query .format (table = self .table_name ), args )
236
- result = cursor .fetchall ()
237
- self .connection .commit ()
244
+ yield
238
245
finally :
239
246
self .semaphore .release ()
240
- return result
241
247
242
248
243
249
class MessageTable (SqliteTable ):
@@ -248,25 +254,43 @@ class MessageTable(SqliteTable):
248
254
'''
249
255
250
256
def get_messages (self , channel ):
251
- return self ._execute ('SELECT message, expiry FROM {table} WHERE channel=?' , channel ) or (None , None )
257
+ with self .semaphore_manager ():
258
+ return (
259
+ self ._execute ('SELECT message, expiry FROM {table} WHERE channel=?' , channel ) or
260
+ (None , None )
261
+ )
252
262
253
263
def get_message_count (self , channel ):
254
- return self ._execute ('SELECT COUNT(*) FROM {table} WHERE channel=?' , channel )[0 ][0 ]
264
+ with self .semaphore_manager ():
265
+ return self ._execute ('SELECT COUNT(*) FROM {table} WHERE channel=?' , channel )[0 ][0 ]
255
266
256
267
def add_message (self , message , expiry , channel ):
257
- self ._execute ('INSERT INTO {table} (channel, message, expiry) VALUES (?,?,?)' , channel , message , expiry )
268
+ with self .semaphore_manager ():
269
+ self ._execute ('INSERT INTO {table} (channel, message, expiry) VALUES (?,?,?)' , channel , message , expiry )
258
270
259
271
def pop_message (self , channel ):
260
- result = self ._execute ('SELECT id, message, expiry FROM {table} WHERE channel=? LIMIT 1' , channel )
261
- if not result :
262
- raise ValueError ('No message in channel' )
263
- result = result [0 ]
264
- self ._execute ('DELETE FROM {table} WHERE id=?' , result [0 ])
265
- return result [1 ], result [2 ]
272
+ """
273
+ Atomically reads and removes a message from the messages table.
274
+ """
275
+ with self .semaphore_manager ():
276
+ cursor = self .connection .cursor ()
277
+ cursor .execute ("BEGIN" )
278
+ cursor .execute (
279
+ 'SELECT id, message, expiry FROM {table} WHERE channel=? LIMIT 1' .format (table = self .table_name ),
280
+ (channel , )
281
+ )
282
+ result = cursor .fetchall ()
283
+ self .connection .commit ()
284
+ if not result :
285
+ return None , None
286
+ row = result [0 ]
287
+ cursor .execute ('DELETE FROM {table} WHERE id=?' .format (table = self .table_name ), (row [0 ], ))
288
+ self .connection .commit ()
289
+ return row [1 ], row [2 ]
266
290
267
291
def __contains__ (self , value ):
268
- result = self ._execute ( 'SELECT COUNT(*) FROM {table} WHERE channel=?' , value )[ 0 ][ 0 ]
269
- return bool (result )
292
+ with self .semaphore_manager ():
293
+ return bool (self . _execute ( 'SELECT COUNT(*) FROM {table} WHERE channel=?' , value )[ 0 ][ 0 ] )
270
294
271
295
272
296
class GroupTable (SqliteTable ):
@@ -277,14 +301,18 @@ class GroupTable(SqliteTable):
277
301
'''
278
302
279
303
def add_channel (self , group , channel , expiry ):
280
- self ._execute ('INSERT INTO {table} (channel, group_name, expiry) VALUES (?,?,?)' , channel , group , expiry )
304
+ with self .semaphore_manager ():
305
+ self ._execute ('INSERT INTO {table} (channel, group_name, expiry) VALUES (?,?,?)' , channel , group , expiry )
281
306
282
307
def discard_channel (self , group , channel ):
283
- self ._execute ('DELETE FROM {table} WHERE group_name=? AND channel=?' , group , channel )
308
+ with self .semaphore_manager ():
309
+ self ._execute ('DELETE FROM {table} WHERE group_name=? AND channel=?' , group , channel )
284
310
285
311
def _cleanup (self , group ):
286
- self ._execute ('DELETE FROM {table} WHERE group_name=? AND expiry<=?' , group , time .time ())
312
+ with self .semaphore_manager ():
313
+ self ._execute ('DELETE FROM {table} WHERE group_name=? AND expiry<=?' , group , time .time ())
287
314
288
315
def get_current_channels (self , group ):
289
- self ._cleanup (group )
290
- return self ._execute ('SELECT DISTINCT channel FROM {table} WHERE group_name=?' , group )
316
+ with self .semaphore_manager ():
317
+ self ._cleanup (group )
318
+ return self ._execute ('SELECT DISTINCT channel FROM {table} WHERE group_name=?' , group )
0 commit comments