32
32
MultiAccountId ,
33
33
)
34
34
from websockets .asyncio .client import connect
35
- from websockets .exceptions import ConnectionClosed
35
+ from websockets .exceptions import ConnectionClosed , WebSocketException
36
36
37
37
from async_substrate_interface .const import SS58_FORMAT
38
38
from async_substrate_interface .errors import (
75
75
ResultHandler = Callable [[dict , Any ], Awaitable [tuple [dict , bool ]]]
76
76
77
77
logger = logging .getLogger ("async_substrate_interface" )
78
+ raw_websocket_logger = logging .getLogger ("raw_websocket" )
78
79
79
80
80
81
class AsyncExtrinsicReceipt :
@@ -505,6 +506,7 @@ def __init__(
505
506
max_connections = 100 ,
506
507
shutdown_timer = 5 ,
507
508
options : Optional [dict ] = None ,
509
+ _log_raw_websockets : bool = False ,
508
510
):
509
511
"""
510
512
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -532,6 +534,10 @@ def __init__(
532
534
self ._exit_task = None
533
535
self ._open_subscriptions = 0
534
536
self ._options = options if options else {}
537
+ self ._log_raw_websockets = _log_raw_websockets
538
+ self ._is_connecting = False
539
+ self ._is_closing = False
540
+
535
541
try :
536
542
now = asyncio .get_running_loop ().time ()
537
543
except RuntimeError :
@@ -556,38 +562,63 @@ async def __aenter__(self):
556
562
async def loop_time () -> float :
557
563
return asyncio .get_running_loop ().time ()
558
564
565
+ async def _cancel (self ):
566
+ try :
567
+ self ._receiving_task .cancel ()
568
+ await self ._receiving_task
569
+ await self .ws .close ()
570
+ except (
571
+ AttributeError ,
572
+ asyncio .CancelledError ,
573
+ WebSocketException ,
574
+ ):
575
+ pass
576
+ except Exception as e :
577
+ logger .warning (
578
+ f"{ e } encountered while trying to close websocket connection."
579
+ )
580
+
559
581
async def connect (self , force = False ):
560
- now = await self .loop_time ()
561
- self .last_received = now
562
- self .last_sent = now
563
- if self ._exit_task :
564
- self ._exit_task .cancel ()
565
- async with self ._lock :
566
- if not self ._initialized or force :
567
- try :
568
- self ._receiving_task .cancel ()
569
- await self ._receiving_task
570
- await self .ws .close ()
571
- except (AttributeError , asyncio .CancelledError ):
572
- pass
573
- self .ws = await asyncio .wait_for (
574
- connect (self .ws_url , ** self ._options ), timeout = 10
575
- )
576
- self ._receiving_task = asyncio .create_task (self ._start_receiving ())
577
- self ._initialized = True
582
+ self ._is_connecting = True
583
+ try :
584
+ now = await self .loop_time ()
585
+ self .last_received = now
586
+ self .last_sent = now
587
+ if self ._exit_task :
588
+ self ._exit_task .cancel ()
589
+ if not self ._is_closing :
590
+ if not self ._initialized or force :
591
+ try :
592
+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
593
+ except asyncio .TimeoutError :
594
+ pass
595
+
596
+ self .ws = await asyncio .wait_for (
597
+ connect (self .ws_url , ** self ._options ), timeout = 10.0
598
+ )
599
+ self ._receiving_task = asyncio .get_running_loop ().create_task (
600
+ self ._start_receiving ()
601
+ )
602
+ self ._initialized = True
603
+ finally :
604
+ self ._is_connecting = False
578
605
579
606
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
580
- async with self ._lock : # TODO is this actually what I want to happen?
581
- self ._in_use -= 1
582
- if self ._exit_task is not None :
583
- self ._exit_task .cancel ()
584
- try :
585
- await self ._exit_task
586
- except asyncio .CancelledError :
587
- pass
588
- if self ._in_use == 0 and self .ws is not None :
589
- self ._open_subscriptions = 0
590
- self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
607
+ self ._is_closing = True
608
+ try :
609
+ if not self ._is_connecting :
610
+ self ._in_use -= 1
611
+ if self ._exit_task is not None :
612
+ self ._exit_task .cancel ()
613
+ try :
614
+ await self ._exit_task
615
+ except asyncio .CancelledError :
616
+ pass
617
+ if self ._in_use == 0 and self .ws is not None :
618
+ self ._open_subscriptions = 0
619
+ self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
620
+ finally :
621
+ self ._is_closing = False
591
622
592
623
async def _exit_with_timer (self ):
593
624
"""
@@ -601,26 +632,24 @@ async def _exit_with_timer(self):
601
632
pass
602
633
603
634
async def shutdown (self ):
604
- async with self ._lock :
605
- try :
606
- self ._receiving_task .cancel ()
607
- await self ._receiving_task
608
- await self .ws .close ()
609
- except (AttributeError , asyncio .CancelledError ):
610
- pass
611
- self .ws = None
612
- self ._initialized = False
613
- self ._receiving_task = None
635
+ self ._is_closing = True
636
+ try :
637
+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
638
+ except asyncio .TimeoutError :
639
+ pass
640
+ self .ws = None
641
+ self ._initialized = False
642
+ self ._receiving_task = None
643
+ self ._is_closing = False
614
644
615
645
async def _recv (self ) -> None :
616
646
try :
617
647
# TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
618
- response = json .loads (await self .ws .recv (decode = False ))
648
+ recd = await self .ws .recv (decode = False )
649
+ if self ._log_raw_websockets :
650
+ raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd .decode ()} " )
651
+ response = json .loads (recd )
619
652
self .last_received = await self .loop_time ()
620
- async with self ._lock :
621
- # note that these 'subscriptions' are all waiting sent messages which have not received
622
- # responses, and thus are not the same as RPC 'subscriptions', which are unique
623
- self ._open_subscriptions -= 1
624
653
if "id" in response :
625
654
self ._received [response ["id" ]] = response
626
655
self ._in_use_ids .remove (response ["id" ])
@@ -640,8 +669,7 @@ async def _start_receiving(self):
640
669
except asyncio .CancelledError :
641
670
pass
642
671
except ConnectionClosed :
643
- async with self ._lock :
644
- await self .connect (force = True )
672
+ await self .connect (force = True )
645
673
646
674
async def send (self , payload : dict ) -> int :
647
675
"""
@@ -660,12 +688,14 @@ async def send(self, payload: dict) -> int:
660
688
# self._open_subscriptions += 1
661
689
await self .max_subscriptions .acquire ()
662
690
try :
663
- await self .ws .send (json .dumps ({** payload , ** {"id" : original_id }}))
691
+ to_send = {** payload , ** {"id" : original_id }}
692
+ if self ._log_raw_websockets :
693
+ raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
694
+ await self .ws .send (json .dumps (to_send ))
664
695
self .last_sent = await self .loop_time ()
665
696
return original_id
666
697
except (ConnectionClosed , ssl .SSLError , EOFError ):
667
- async with self ._lock :
668
- await self .connect (force = True )
698
+ await self .connect (force = True )
669
699
670
700
async def retrieve (self , item_id : int ) -> Optional [dict ]:
671
701
"""
@@ -699,6 +729,8 @@ def __init__(
699
729
max_retries : int = 5 ,
700
730
retry_timeout : float = 60.0 ,
701
731
_mock : bool = False ,
732
+ _log_raw_websockets : bool = False ,
733
+ ws_shutdown_timer : float = 5.0 ,
702
734
):
703
735
"""
704
736
The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -716,20 +748,25 @@ def __init__(
716
748
max_retries: number of times to retry RPC requests before giving up
717
749
retry_timeout: how to long wait since the last ping to retry the RPC request
718
750
_mock: whether to use mock version of the subtensor interface
751
+ _log_raw_websockets: whether to log raw websocket requests during RPC requests
752
+ ws_shutdown_timer: how long after the last connection your websocket should close
719
753
720
754
"""
721
755
self .max_retries = max_retries
722
756
self .retry_timeout = retry_timeout
723
757
self .chain_endpoint = url
724
758
self .url = url
725
759
self ._chain = chain_name
760
+ self ._log_raw_websockets = _log_raw_websockets
726
761
if not _mock :
727
762
self .ws = Websocket (
728
763
url ,
764
+ _log_raw_websockets = _log_raw_websockets ,
729
765
options = {
730
766
"max_size" : self .ws_max_size ,
731
767
"write_limit" : 2 ** 16 ,
732
768
},
769
+ shutdown_timer = ws_shutdown_timer ,
733
770
)
734
771
else :
735
772
self .ws = AsyncMock (spec = Websocket )
0 commit comments