Skip to content

Improved Query Map Decodes #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 18 additions & 62 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
)
from async_substrate_interface.utils.storage import StorageKey
from async_substrate_interface.type_registry import _TYPE_REGISTRY
from async_substrate_interface.utils.decoding import (
decode_query_map,
)

if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
Expand Down Expand Up @@ -898,7 +901,7 @@ async def decode_scale(
else:
return obj

async def load_runtime(self, runtime):
def load_runtime(self, runtime):
self.runtime = runtime

# Update type registry
Expand Down Expand Up @@ -954,7 +957,7 @@ async def init_runtime(
)

if self.runtime and runtime_version == self.runtime.runtime_version:
return
return self.runtime

runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
if not runtime:
Expand Down Expand Up @@ -990,7 +993,7 @@ async def init_runtime(
runtime_version=runtime_version, runtime=runtime
)

await self.load_runtime(runtime)
self.load_runtime(runtime)

if self.ss58_format is None:
# Check and apply runtime constants
Expand All @@ -1000,6 +1003,7 @@ async def init_runtime(

if ss58_prefix_constant:
self.ss58_format = ss58_prefix_constant
return runtime

async def create_storage_key(
self,
Expand Down Expand Up @@ -2892,12 +2896,11 @@ async def query_map(
Returns:
AsyncQueryMapResult object
"""
hex_to_bytes_ = hex_to_bytes
params = params or []
block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
if block_hash:
self.last_block_hash = block_hash
await self.init_runtime(block_hash=block_hash)
runtime = await self.init_runtime(block_hash=block_hash)

metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
if not metadata_pallet:
Expand Down Expand Up @@ -2952,19 +2955,6 @@ async def query_map(
result = []
last_key = None

def concat_hash_len(key_hasher: str) -> int:
"""
Helper function to avoid if statements
"""
if key_hasher == "Blake2_128Concat":
return 16
elif key_hasher == "Twox64Concat":
return 8
elif key_hasher == "Identity":
return 0
else:
raise ValueError("Unsupported hash type")

if len(result_keys) > 0:
last_key = result_keys[-1]

Expand All @@ -2975,51 +2965,17 @@ def concat_hash_len(key_hasher: str) -> int:

if "error" in response:
raise SubstrateRequestException(response["error"]["message"])

for result_group in response["result"]:
for item in result_group["changes"]:
try:
# Determine type string
key_type_string = []
for n in range(len(params), len(param_types)):
key_type_string.append(
f"[u8; {concat_hash_len(key_hashers[n])}]"
)
key_type_string.append(param_types[n])

item_key_obj = await self.decode_scale(
type_string=f"({', '.join(key_type_string)})",
scale_bytes=bytes.fromhex(item[0][len(prefix) :]),
return_scale_obj=True,
)

# strip key_hashers to use as item key
if len(param_types) - len(params) == 1:
item_key = item_key_obj[1]
else:
item_key = tuple(
item_key_obj[key + 1]
for key in range(len(params), len(param_types) + 1, 2)
)

except Exception as _:
if not ignore_decoding_errors:
raise
item_key = None

try:
item_bytes = hex_to_bytes_(item[1])

item_value = await self.decode_scale(
type_string=value_type,
scale_bytes=item_bytes,
return_scale_obj=True,
)
except Exception as _:
if not ignore_decoding_errors:
raise
item_value = None
result.append([item_key, item_value])
result = decode_query_map(
result_group["changes"],
prefix,
runtime,
param_types,
params,
value_type,
key_hashers,
ignore_decoding_errors,
)
return AsyncQueryMapResult(
records=result,
page_size=page_size,
Expand Down
76 changes: 17 additions & 59 deletions async_substrate_interface/sync_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from async_substrate_interface.utils.decoding import (
_determine_if_old_runtime_call,
_bt_decode_to_dict_or_list,
decode_query_map,
)
from async_substrate_interface.utils.storage import StorageKey
from async_substrate_interface.type_registry import _TYPE_REGISTRY
Expand Down Expand Up @@ -525,7 +526,9 @@ def __enter__(self):
return self

def __del__(self):
    """Close the underlying websocket when the instance is garbage-collected."""
    try:
        self.ws.close()
    except Exception:
        # A finalizer must never propagate exceptions: the connection (or the
        # interpreter itself) may already be torn down when __del__ runs.
        pass

def initialize(self):
"""
Expand Down Expand Up @@ -703,7 +706,7 @@ def init_runtime(
)

if self.runtime and runtime_version == self.runtime.runtime_version:
return
return self.runtime

runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
if not runtime:
Expand Down Expand Up @@ -757,6 +760,7 @@ def init_runtime(

if ss58_prefix_constant:
self.ss58_format = ss58_prefix_constant
return runtime

def create_storage_key(
self,
Expand Down Expand Up @@ -2598,7 +2602,7 @@ def query_map(
block_hash = self._get_current_block_hash(block_hash, reuse_block_hash)
if block_hash:
self.last_block_hash = block_hash
self.init_runtime(block_hash=block_hash)
runtime = self.init_runtime(block_hash=block_hash)

metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
if not metadata_pallet:
Expand Down Expand Up @@ -2654,19 +2658,6 @@ def query_map(
result = []
last_key = None

def concat_hash_len(key_hasher: str) -> int:
"""
Helper function to avoid if statements
"""
if key_hasher == "Blake2_128Concat":
return 16
elif key_hasher == "Twox64Concat":
return 8
elif key_hasher == "Identity":
return 0
else:
raise ValueError("Unsupported hash type")

if len(result_keys) > 0:
last_key = result_keys[-1]

Expand All @@ -2679,49 +2670,16 @@ def concat_hash_len(key_hasher: str) -> int:
raise SubstrateRequestException(response["error"]["message"])

for result_group in response["result"]:
for item in result_group["changes"]:
try:
# Determine type string
key_type_string = []
for n in range(len(params), len(param_types)):
key_type_string.append(
f"[u8; {concat_hash_len(key_hashers[n])}]"
)
key_type_string.append(param_types[n])

item_key_obj = self.decode_scale(
type_string=f"({', '.join(key_type_string)})",
scale_bytes=bytes.fromhex(item[0][len(prefix) :]),
return_scale_obj=True,
)

# strip key_hashers to use as item key
if len(param_types) - len(params) == 1:
item_key = item_key_obj[1]
else:
item_key = tuple(
item_key_obj[key + 1]
for key in range(len(params), len(param_types) + 1, 2)
)

except Exception as _:
if not ignore_decoding_errors:
raise
item_key = None

try:
item_bytes = hex_to_bytes_(item[1])

item_value = self.decode_scale(
type_string=value_type,
scale_bytes=item_bytes,
return_scale_obj=True,
)
except Exception as _:
if not ignore_decoding_errors:
raise
item_value = None
result.append([item_key, item_value])
result = decode_query_map(
result_group["changes"],
prefix,
runtime,
param_types,
params,
value_type,
key_hashers,
ignore_decoding_errors,
)
return QueryMapResult(
records=result,
page_size=page_size,
Expand Down
94 changes: 92 additions & 2 deletions async_substrate_interface/utils/decoding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Union
from typing import Union, TYPE_CHECKING

from bt_decode import AxonInfo, PrometheusInfo
from bt_decode import AxonInfo, PrometheusInfo, decode_list
from scalecodec import ss58_encode
from bittensor_wallet.utils import SS58_FORMAT

from async_substrate_interface.utils import hex_to_bytes
from async_substrate_interface.types import ScaleObj

if TYPE_CHECKING:
from async_substrate_interface.types import Runtime


def _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value) -> bool:
Expand Down Expand Up @@ -44,3 +52,85 @@ def _bt_decode_to_dict_or_list(obj) -> Union[dict, list[dict]]:
else:
as_dict[key] = val
return as_dict


def _decode_scale_list_with_runtime(
    type_strings: list[str],
    scale_bytes_list: list[bytes],
    runtime_registry,
    return_scale_obj: bool = False,
):
    """
    Batch-decode SCALE-encoded byte blobs against a runtime registry.

    Args:
        type_strings: one type string per entry in ``scale_bytes_list``.
        scale_bytes_list: raw SCALE-encoded byte blobs to decode.
        runtime_registry: registry passed through to ``bt_decode.decode_list``.
        return_scale_obj: when True, wrap each decoded entry in ``ScaleObj``.

    Returns:
        The decoded objects, optionally wrapped in ``ScaleObj``.
    """
    decoded = decode_list(type_strings, runtime_registry, scale_bytes_list)
    if not return_scale_obj:
        return decoded
    return [ScaleObj(entry) for entry in decoded]


def decode_query_map(
    result_group_changes,
    prefix,
    runtime: "Runtime",
    param_types,
    params,
    value_type,
    key_hashers,
    ignore_decoding_errors,
):
    """
    Decode the raw ``changes`` of a storage query-map response into a list of
    ``[item_key, item_value]`` pairs.

    Keys and values are batch-decoded in a single call into the runtime decoder
    (one list of type strings plus one list of SCALE byte blobs) instead of one
    RPC item at a time.

    Args:
        result_group_changes: list of ``(storage_key_hex, value_hex)`` pairs as
            returned by the node.
        prefix: hex storage-key prefix shared by all entries; stripped from each
            key before decoding.
        runtime: Runtime whose ``registry`` is used for decoding.
        param_types: type strings of all key parts of the storage map.
        params: key-part values already supplied to the query (encoded into
            ``prefix``), so only the remaining parts are decoded.
        value_type: type string of the stored value.
        key_hashers: hasher name per key part; determines how many hash-prefix
            bytes precede each concatenated key part.
        ignore_decoding_errors: when True, entries that fail to decode yield
            ``None`` instead of raising.

    Returns:
        list of ``[item_key, item_value]`` pairs; ``item_value`` is a
        ``ScaleObj`` (or ``None`` for a tolerated decode failure).
    """

    def concat_hash_len(key_hasher: str) -> int:
        """Number of hash bytes prepended to a key part by a "Concat" hasher."""
        if key_hasher == "Blake2_128Concat":
            return 16
        elif key_hasher == "Twox64Concat":
            return 8
        elif key_hasher == "Identity":
            return 0
        else:
            raise ValueError("Unsupported hash type")

    if not result_group_changes:
        # Nothing to decode; skip the call into the decoder entirely.
        return []

    # Build the composite type string for the not-yet-determined key parts:
    # each part is preceded by a fixed-size byte array covering its hasher's
    # concatenated-hash prefix.
    key_type_parts = []
    for n in range(len(params), len(param_types)):
        key_type_parts.append(f"[u8; {concat_hash_len(key_hashers[n])}]")
        key_type_parts.append(param_types[n])
    key_type_string = f"({', '.join(key_type_parts)})"

    pre_decoded_keys = []
    pre_decoded_values = []
    for item in result_group_changes:
        pre_decoded_keys.append(bytes.fromhex(item[0][len(prefix):]))
        pre_decoded_values.append(hex_to_bytes(item[1]))

    n_items = len(result_group_changes)
    all_type_strings = [key_type_string] * n_items + [value_type] * n_items
    all_scale_bytes = pre_decoded_keys + pre_decoded_values

    # Sentinel marking entries that failed to decode (distinct from a decoded
    # value that is legitimately None, e.g. a null Option).
    decode_failed = object()
    try:
        all_decoded = _decode_scale_list_with_runtime(
            all_type_strings, all_scale_bytes, runtime.registry
        )
    except Exception:
        if not ignore_decoding_errors:
            raise
        # The batch decode raises as a whole, which would otherwise defeat
        # ignore_decoding_errors. Retry item-by-item so that only the
        # undecodable entries degrade to None, matching the behavior of the
        # previous per-item decode path.
        all_decoded = []
        for type_string, scale_bytes in zip(all_type_strings, all_scale_bytes):
            try:
                all_decoded.extend(
                    _decode_scale_list_with_runtime(
                        [type_string], [scale_bytes], runtime.registry
                    )
                )
            except Exception:
                all_decoded.append(decode_failed)

    # First half of the batch is the keys, second half the matching values.
    mid_index = len(all_decoded) // 2
    decoded_keys = all_decoded[:mid_index]
    decoded_values = [
        None if x is decode_failed else ScaleObj(x)
        for x in all_decoded[mid_index:]
    ]

    result = []
    for dk, dv in zip(decoded_keys, decoded_values):
        if dk is decode_failed:
            item_key = None
        else:
            try:
                # Strip the hash-prefix components so only the actual key
                # part(s) remain: a single part directly, several as a tuple.
                if len(param_types) - len(params) == 1:
                    item_key = dk[1]
                else:
                    item_key = tuple(
                        dk[key + 1]
                        for key in range(len(params), len(param_types) + 1, 2)
                    )
            except Exception:
                if not ignore_decoding_errors:
                    raise
                item_key = None

        result.append([item_key, dv])
    return result
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [
"wheel",
"asyncstdlib~=3.13.0",
"bittensor-wallet>=2.1.3",
"bt-decode==v0.5.0",
"bt-decode==v0.6.0",
"scalecodec~=1.2.11",
"websockets>=14.1",
"xxhash"
Expand Down
Loading
Loading