@@ -475,24 +475,32 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
475475 ):
476476 raise error
477477
478- # COMMAND EXECUTION AND PROTOCOL PARSING
479- async def execute_command (self , * args , ** options ):
480- """Execute a command and return a parsed response"""
481- await self .initialize ()
482- pool = self .connection_pool
483- command_name = args [0 ]
484- conn = self .connection or await pool .get_connection (command_name , ** options )
485-
478+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
486479 try :
487480 return await conn .retry .call_with_retry (
488481 lambda : self ._send_command_parse_response (
489- conn , command_name , * args , ** options
482+ conn , args [ 0 ] , * args , ** options
490483 ),
491484 lambda error : self ._disconnect_raise (conn , error ),
492485 )
486+ except asyncio .CancelledError :
487+ await conn .disconnect (nowait = True )
488+ raise
493489 finally :
494490 if not self .connection :
495- await pool .release (conn )
491+ await self .connection_pool .release (conn )
492+
493+ # COMMAND EXECUTION AND PROTOCOL PARSING
494+ async def execute_command (self , * args , ** options ):
495+ """Execute a command and return a parsed response"""
496+ await self .initialize ()
497+ pool = self .connection_pool
498+ command_name = args [0 ]
499+ conn = self .connection or await pool .get_connection (command_name , ** options )
500+
501+ return await asyncio .shield (
502+ self ._try_send_command_parse_response (conn , * args , ** options )
503+ )
496504
497505 async def parse_response (
498506 self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -726,10 +734,18 @@ async def _disconnect_raise_connect(self, conn, error):
726734 is not a TimeoutError. Otherwise, try to reconnect
727735 """
728736 await conn .disconnect ()
737+
729738 if not (conn .retry_on_timeout and isinstance (error , TimeoutError )):
730739 raise error
731740 await conn .connect ()
732741
742+ async def _try_execute (self , conn , command , * arg , ** kwargs ):
743+ try :
744+ return await command (* arg , ** kwargs )
745+ except asyncio .CancelledError :
746+ await conn .disconnect ()
747+ raise
748+
733749 async def _execute (self , conn , command , * args , ** kwargs ):
734750 """
735751 Connect manually upon disconnection. If the Redis server is down,
@@ -738,9 +754,11 @@ async def _execute(self, conn, command, *args, **kwargs):
738754 called by the # connection to resubscribe us to any channels and
739755 patterns we were previously listening to
740756 """
741- return await conn .retry .call_with_retry (
742- lambda : command (* args , ** kwargs ),
743- lambda error : self ._disconnect_raise_connect (conn , error ),
757+ return await asyncio .shield (
758+ conn .retry .call_with_retry (
759+ lambda : self ._try_execute (conn , command , * args , ** kwargs ),
760+ lambda error : self ._disconnect_raise_connect (conn , error ),
761+ )
744762 )
745763
746764 async def parse_response (self , block : bool = True , timeout : float = 0 ):
@@ -1140,6 +1158,18 @@ async def _disconnect_reset_raise(self, conn, error):
11401158 await self .reset ()
11411159 raise
11421160
1161+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
1162+ try :
1163+ return await conn .retry .call_with_retry (
1164+ lambda : self ._send_command_parse_response (
1165+ conn , args [0 ], * args , ** options
1166+ ),
1167+ lambda error : self ._disconnect_reset_raise (conn , error ),
1168+ )
1169+ except asyncio .CancelledError :
1170+ await conn .disconnect ()
1171+ raise
1172+
11431173 async def immediate_execute_command (self , * args , ** options ):
11441174 """
11451175 Execute a command immediately, but don't auto-retry on a
@@ -1155,13 +1185,13 @@ async def immediate_execute_command(self, *args, **options):
11551185 command_name , self .shard_hint
11561186 )
11571187 self .connection = conn
1158-
1159- return await conn . retry . call_with_retry (
1160- lambda : self ._send_command_parse_response (
1161- conn , command_name , * args , ** options
1162- ),
1163- lambda error : self . _disconnect_reset_raise ( conn , error ),
1164- )
1188+ try :
1189+ return await asyncio . shield (
1190+ self ._try_send_command_parse_response ( conn , * args , ** options )
1191+ )
1192+ except asyncio . CancelledError :
1193+ await conn . disconnect ()
1194+ raise
11651195
11661196 def pipeline_execute_command (self , * args , ** options ):
11671197 """
@@ -1328,6 +1358,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13281358 await self .reset ()
13291359 raise
13301360
1361+ async def _try_execute (self , conn , execute , stack , raise_on_error ):
1362+ try :
1363+ return await conn .retry .call_with_retry (
1364+ lambda : execute (conn , stack , raise_on_error ),
1365+ lambda error : self ._disconnect_raise_reset (conn , error ),
1366+ )
1367+ except asyncio .CancelledError :
1368+ # not supposed to be possible, yet here we are
1369+ await conn .disconnect (nowait = True )
1370+ raise
1371+ finally :
1372+ await self .reset ()
1373+
13311374 async def execute (self , raise_on_error : bool = True ):
13321375 """Execute all the commands in the current pipeline"""
13331376 stack = self .command_stack
@@ -1350,15 +1393,10 @@ async def execute(self, raise_on_error: bool = True):
13501393
13511394 try :
13521395 return await asyncio .shield (
1353- conn .retry .call_with_retry (
1354- lambda : execute (conn , stack , raise_on_error ),
1355- lambda error : self ._disconnect_raise_reset (conn , error ),
1356- )
1396+ self ._try_execute (conn , execute , stack , raise_on_error )
13571397 )
1358- except asyncio .CancelledError :
1359- # not supposed to be possible, yet here we are
1360- await conn .disconnect (nowait = True )
1361- raise
1398+ except RuntimeError :
1399+ await self .reset ()
13621400 finally :
13631401 await self .reset ()
13641402
0 commit comments