Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions bittensor/utils/async_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosed

from bittensor.utils import hex_to_bytes

if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection

Expand Down Expand Up @@ -464,6 +466,9 @@ def __init__(self, chain, runtime_config, metadata, type_registry):
self.runtime_config = runtime_config
self.metadata = metadata

def __str__(self):
return f"Runtime: {self.chain} | {self.config}"

@property
def implements_scaleinfo(self) -> bool:
"""
Expand Down Expand Up @@ -647,15 +652,12 @@ async def __aenter__(self):
self._exit_task.cancel()
if not self._initialized:
self._initialized = True
await self._connect()
self.ws = await asyncio.wait_for(
connect(self.ws_url, **self._options), timeout=10
)
self._receiving_task = asyncio.create_task(self._start_receiving())
return self

async def _connect(self):
self.ws = await asyncio.wait_for(
connect(self.ws_url, **self._options), timeout=10
)

async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self._lock:
self._in_use -= 1
Expand Down Expand Up @@ -696,7 +698,7 @@ async def shutdown(self):

async def _recv(self) -> None:
try:
response = json.loads(await cast(ClientConnection, self.ws).recv())
response = json.loads(await self.ws.recv())
async with self._lock:
self._open_subscriptions -= 1
if "id" in response:
Expand Down Expand Up @@ -770,11 +772,13 @@ def __init__(
"""
self.chain_endpoint = chain_endpoint
self.__chain = chain_name
options = {
"max_size": 2**32,
"write_limit": 2**16,
}
self.ws = Websocket(chain_endpoint, options=options)
self.ws = Websocket(
chain_endpoint,
options={
"max_size": 2**32,
"write_limit": 2**16,
},
)
Comment on lines +775 to +781
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe consider using try... expect... wrapper around Websocket(...) to be able to catch the error to be good

self._lock = asyncio.Lock()
self.last_block_hash: Optional[str] = None
self.config = {
Expand Down Expand Up @@ -896,9 +900,10 @@ async def init_runtime(

async def get_runtime(block_hash, block_id) -> Runtime:
# Check if runtime state already set to current block
if (block_hash and block_hash == self.last_block_hash) or (
block_id and block_id == self.block_id
):
if (
(block_hash and block_hash == self.last_block_hash)
or (block_id and block_id == self.block_id)
) and self.metadata is not None:
return Runtime(
self.chain,
self.runtime_config,
Expand Down Expand Up @@ -944,9 +949,11 @@ async def get_runtime(block_hash, block_id) -> Runtime:
raise SubstrateRequestException(
f"No runtime information for block '{block_hash}'"
)

# Check if runtime state already set to current block
if runtime_info.get("specVersion") == self.runtime_version:
if (
runtime_info.get("specVersion") == self.runtime_version
and self.metadata is not None
):
return Runtime(
self.chain,
self.runtime_config,
Expand All @@ -961,16 +968,19 @@ async def get_runtime(block_hash, block_id) -> Runtime:
if self.runtime_version in self.__metadata_cache:
# Get metadata from cache
# self.debug_message('Retrieved metadata for {} from memory'.format(self.runtime_version))
self.metadata = self.__metadata_cache[self.runtime_version]
metadata = self.metadata = self.__metadata_cache[
self.runtime_version
]
else:
self.metadata = await self.get_block_metadata(
metadata = self.metadata = await self.get_block_metadata(
block_hash=runtime_block_hash, decode=True
)
# self.debug_message('Retrieved metadata for {} from Substrate node'.format(self.runtime_version))

# Update metadata cache
self.__metadata_cache[self.runtime_version] = self.metadata

else:
metadata = self.metadata
# Update type registry
self.reload_type_registry(use_remote_preset=False, auto_discover=True)

Expand Down Expand Up @@ -1011,7 +1021,10 @@ async def get_runtime(block_hash, block_id) -> Runtime:
if block_id and block_hash:
raise ValueError("Cannot provide block_hash and block_id at the same time")

if not (runtime := self.runtime_cache.retrieve(block_id, block_hash)):
if (
not (runtime := self.runtime_cache.retrieve(block_id, block_hash))
or runtime.metadata is None
):
runtime = await get_runtime(block_hash, block_id)
self.runtime_cache.add_item(block_id, block_hash, runtime)
return runtime
Expand Down Expand Up @@ -2271,7 +2284,7 @@ async def get_metadata_constant(self, module_name, constant_name, block_hash=Non
MetadataModuleConstants
"""

# await self.init_runtime(block_hash=block_hash)
await self.init_runtime(block_hash=block_hash)

for module in self.metadata.pallets:
if module_name == module.name and module.constants:
Expand All @@ -2285,7 +2298,7 @@ async def get_constant(
constant_name: str,
block_hash: Optional[str] = None,
reuse_block_hash: bool = False,
) -> "ScaleType":
) -> Optional["ScaleType"]:
"""
Returns the decoded `ScaleType` object of the constant for given module name, call function name and block_hash
(or chaintip if block_hash is omitted)
Expand Down Expand Up @@ -2364,7 +2377,7 @@ async def query(
raw_storage_key: Optional[bytes] = None,
subscription_handler=None,
reuse_block_hash: bool = False,
) -> Union["ScaleType"]:
) -> "ScaleType":
"""
Queries subtensor. This should only be used when making a single request. For multiple requests,
you should use ``self.query_multiple``
Expand Down Expand Up @@ -2551,10 +2564,7 @@ def concat_hash_len(key_hasher: str) -> int:
item_key = None

try:
try:
item_bytes = bytes.fromhex(item[1][2:])
except ValueError:
item_bytes = bytes.fromhex(item[1])
item_bytes = hex_to_bytes(item[1])

item_value = await self.decode_scale(
type_string=value_type,
Expand Down Expand Up @@ -2720,7 +2730,7 @@ async def get_metadata_call_function(
return call
return None

async def get_block_number(self, block_hash: Optional[str] = None) -> int:
async def get_block_number(self, block_hash: Optional[str]) -> int:
"""Async version of `substrateinterface.base.get_block_number` method."""
response = await self.rpc_request("chain_getHeader", [block_hash])

Expand Down