Skip to content
21 changes: 12 additions & 9 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
self.priority_fns = {}
self.forward_fns = {}
self.verify_fns = {}
self.required_hash_fields = {}

# Instantiate FastAPI
self.app = FastAPI()
Expand Down Expand Up @@ -397,6 +398,12 @@ def attach(
) # Use 'default_verify' if 'verify_fn' is None
self.forward_fns[request_name] = forward_fn

# Parse required hash fields from the forward function protocol defaults
required_hash_fields = request_class.__dict__["__fields__"][
"required_hash_fields"
].default
self.required_hash_fields[request_name] = required_hash_fields

return self

@classmethod
Expand Down Expand Up @@ -481,8 +488,7 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None):
# Exception handling for re-parsing arguments
pass

@staticmethod
async def verify_body_integrity(request: Request):
async def verify_body_integrity(self, request: Request):
"""
Asynchronously verifies the integrity of the body of a request by comparing the hash of required fields
with the corresponding hashes provided in the request headers. This method is critical for ensuring
Expand Down Expand Up @@ -514,12 +520,9 @@ async def some_endpoint(body_dict: dict = Depends(verify_body_integrity)):
body = await request.body()
request_body = body.decode() if isinstance(body, bytes) else body

# Gather the required field names from the headers of the request
required_hash_fields = json.loads(
base64.b64decode(request.headers.get("hash_fields", "").encode()).decode(
"utf-8"
)
)
# Gather the required field names from the axon's required_hash_fields dict
request_name = request.url.path.split("/")[1]
required_hash_fields = self.required_hash_fields[request_name]

# Load the body dict and check if all required field hashes match
body_dict = json.loads(request_body)
Expand Down Expand Up @@ -628,7 +631,7 @@ def serve(
subtensor.serve_axon(netuid=netuid, axon=self)
return self

def default_verify(self, synapse: bittensor.Synapse):
async def default_verify(self, synapse: bittensor.Synapse):
"""
This method is used to verify the authenticity of a received message using a digital signature.
It ensures that the message was not tampered with and was sent by the expected sender.
Expand Down
73 changes: 8 additions & 65 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def set_name_type(cls, values) -> dict:
repr=False,
)

hash_fields: Optional[List[str]] = pydantic.Field(
title="hash_fields",
required_hash_fields: Optional[List[str]] = pydantic.Field(
title="required_hash_fields",
description="The list of required fields to compute the body hash.",
examples=["roles", "messages"],
default=[],
Expand All @@ -319,61 +319,12 @@ def __setattr__(self, name: str, value: Any):
"""
Override the __setattr__ method to make the `required_hash_fields` property read-only.
"""
if name == "required_hash_fields":
raise AttributeError(
"required_hash_fields property is read-only and cannot be overridden."
)
if name == "body_hash":
raise AttributeError(
"body_hash property is read-only and cannot be overridden."
)
super().__setattr__(name, value)

@property
def required_hash_fields(self) -> List[str]:
"""
Retrieve the list of non-optional fields of the Synapse instance.

This default method identifies and returns the names of non-optional attributes of the Synapse
instance that have non-null values, excluding specific attributes such as `name`, `timeout`,
`total_size`, `header_size`, `dendrite`, and `axon`. The determination of whether a field is
optional or not is based on the schema definition for the Synapse class.

Subclasses are encouraged to override this method to provide their own implementation for
determining required fields. If not overridden, the default implementation provided by the
Synapse superclass will be used, which returns the fields based on the schema definition.

Returns:
List[str]: A list of names of the non-optional fields of the Synapse instance.
"""
fields = []
# Getting the fields of the instance
instance_fields = self.__dict__

# Iterating over the fields of the instance
for field, value in instance_fields.items():
# If the object is not optional and non-null, add to the list of returned body fields
required = schema([self.__class__])["definitions"][self.name].get(
"required"
)
if (
required
and field in required
and value != None
and field
not in [
"name",
"timeout",
"total_size",
"header_size",
"dendrite",
"axon",
]
and "_hash" not in field
):
fields.append(field)
return fields

def get_total_size(self) -> int:
"""
Get the total size of the current object.
Expand Down Expand Up @@ -542,33 +493,30 @@ def to_headers(self) -> dict:
headers["header_size"] = str(sys.getsizeof(headers))
headers["total_size"] = str(self.get_total_size())
headers["computed_body_hash"] = self.body_hash
headers["hash_fields"] = base64.b64encode(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the axon need to know the other parsed hash_fields other then the required_hash_field?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, that is a hold over, nice catch! Will remove.

json.dumps(self.required_hash_fields).encode()
).decode("utf-8")

return headers

@property
def body_hash(self) -> str:
"""
Compute a SHA-256 hash of the serialized body of the Synapse instance.
Compute a SHA3-256 hash of the serialized body of the Synapse instance.

The body of the Synapse instance comprises its serialized and encoded
non-optional fields. This property retrieves these fields using the
`get_body` method, then concatenates their string representations, and
finally computes a SHA-256 hash of the resulting string.
`required_fields_hash` field, then concatenates their string representations,
and finally computes a SHA3-256 hash of the resulting string.

Returns:
str: The hexadecimal representation of the SHA-256 hash of the instance's body.
str: The hexadecimal representation of the SHA3-256 hash of the instance's body.
"""
# Hash the body for verification
hashes = []

# Getting the fields of the instance
instance_fields = self.__dict__

# Iterating over the fields of the instance
for field, value in instance_fields.items():
# If the field is required in the subclass schema, add it.
# If the field is required in the subclass schema, hash and add it.
if field in self.required_hash_fields:
hashes.append(bittensor.utils.hash(str(value)))

Expand Down Expand Up @@ -689,11 +637,6 @@ def parse_headers_to_inputs(cls, headers: dict) -> dict:
inputs_dict["header_size"] = headers.get("header_size", None)
inputs_dict["total_size"] = headers.get("total_size", None)
inputs_dict["computed_body_hash"] = headers.get("computed_body_hash", None)
inputs_dict["hash_fields"] = json.loads(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same question as the previous one) does the axon need to know the other parsed hash_fields other then the required_hash_field?

base64.b64decode(headers.get("hash_fields", "W10=").encode()).decode(
"utf-8"
)
)

return inputs_dict

Expand Down
8 changes: 4 additions & 4 deletions bittensor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ def u8_key_to_ss58(u8_key: List[int]) -> str:
return scalecodec.ss58_encode(bytes(u8_key).hex(), bittensor.__ss58_format__)


def hash(content, hash_type="md5", encoding="utf-8"):
algo = hashlib.md5() if hash_type == "md5" else hashlib.sha256()
def hash(content, encoding="utf-8"):
sha3 = hashlib.sha3_256()

# Update the hash object with the concatenated string
algo.update(content.encode(encoding))
sha3.update(content.encode(encoding))

# Produce the hash
return algo.hexdigest()
return sha3.hexdigest()
13 changes: 3 additions & 10 deletions tests/unit_tests/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class Test(bittensor.Synapse):
"header_size": "111",
"total_size": "111",
"computed_body_hash": "0xabcdef",
"hash_fields": base64.b64encode(
json.dumps(["key1", "key2"]).encode("utf-8")
).decode("utf-8"),
}
print(headers)

Expand All @@ -60,7 +57,6 @@ class Test(bittensor.Synapse):
"header_size": "111",
"total_size": "111",
"computed_body_hash": "0xabcdef",
"hash_fields": ["key1", "key2"],
}


Expand All @@ -82,9 +78,6 @@ class Test(bittensor.Synapse):
"header_size": "111",
"total_size": "111",
"computed_body_hash": "0xabcdef",
"hash_fields": base64.b64encode(
json.dumps(["key1", "key2"]).encode("utf-8")
).decode("utf-8"),
}

# Run the function to test
Expand Down Expand Up @@ -264,9 +257,9 @@ def test_required_fields_override():
# Create a Synapse instance
synapse_instance = bittensor.Synapse()

# Try to set the body_hash property and expect an AttributeError
# Try to set the required_hash_fields property and expect a TypeError
with pytest.raises(
AttributeError,
match="required_hash_fields property is read-only and cannot be overridden.",
TypeError,
match='"required_hash_fields" has allow_mutation set to False and cannot be assigned',
):
synapse_instance.required_hash_fields = []