Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def default_verify(self, synapse: bittensor.Synapse) -> Request:
keypair = Keypair(ss58_address=synapse.dendrite.hotkey)

# Build the signature messages.
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}"
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{synapse.body_hash}"

# Build the unique endpoint key.
endpoint_key = f"{synapse.dendrite.hotkey}:{synapse.dendrite.uuid}"
Expand Down
2 changes: 1 addition & 1 deletion bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def preprocess_synapse_for_request(
)

# Sign the request using the dendrite and axon information
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}"
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{synapse.body_hash}"
synapse.dendrite.signature = f"0x{self.keypair.sign(message).hex()}"

return synapse
Expand Down
126 changes: 120 additions & 6 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@

import ast
import sys
import torch
import pickle
import base64
import typing
import hashlib
import pydantic
from pydantic.schema import schema
import bittensor
from abc import abstractmethod
from fastapi.responses import Response
from fastapi import Request
from typing import Dict, Optional, Tuple, Union, List, Callable
from typing import Optional, List, Any


def get_size(obj, seen=None):
Expand Down Expand Up @@ -204,7 +201,37 @@ class Config:
)


class Synapse(pydantic.BaseModel):
class ProtectOverride(type):
"""
Metaclass to prevent subclasses from overriding specified methods or attributes.

When a subclass attempts to override a protected attribute or method, a `TypeError` is raised.
The current implementation specifically checks for overriding the 'body_hash' attribute.

Overriding `protected_method` in a subclass of `MyClass` will raise a TypeError.
"""

def __new__(cls, name, bases, class_dict):
# Check if the derived class tries to override the 'body_hash' method or attribute.
if (
any(base for base in bases if hasattr(base, "body_hash"))
and "body_hash" in class_dict
):
raise TypeError("You can't override the body_hash attribute!")
return super(ProtectOverride, cls).__new__(cls, name, bases, class_dict)


class CombinedMeta(ProtectOverride, type(pydantic.BaseModel)):
"""
Metaclass combining functionality of ProtectOverride and BaseModel's metaclass.

Inherits the attributes and methods from both parent metaclasses to provide combined behavior.
"""

pass


class Synapse(pydantic.BaseModel, metaclass=CombinedMeta):
class Config:
validate_assignment = True

Expand Down Expand Up @@ -285,6 +312,16 @@ def set_name_type(cls, values):
repr=False,
)

def __setattr__(self, name, value):
"""
Override the __setattr__ method to make the body_hash property read-only.
"""
if name == "body_hash":
raise AttributeError(
"body_hash property is read-only and cannot be overridden."
)
super().__setattr__(name, value)

def get_total_size(self) -> int:
"""
Get the total size of the current object.
Expand All @@ -298,6 +335,83 @@ def get_total_size(self) -> int:
self.total_size = get_size(self)
return self.total_size

def get_body(self) -> List[Any]:
"""
Retrieve the serialized and encoded non-optional fields of the Synapse instance.

This method filters through the fields of the Synapse instance and identifies
non-optional attributes that have non-null values, excluding specific attributes
such as `name`, `timeout`, `total_size`, `header_size`, `dendrite`, and `axon`.
It returns a list containing these selected field values.

Returns:
List[Any]: A list of values from the non-optional fields of the Synapse instance.

Note:
The determination of whether a field is optional or not is based on the
schema definition for the Synapse class.
"""
fields = []

# Getting the fields of the instance
instance_fields = self.__dict__

# Iterating over the fields of the instance
for field, value in instance_fields.items():
# If the object is not optional and non-null, add to the list of returned body fields
required = schema([self.__class__])["definitions"][self.name].get(
"required"
)
if (
required
and value != None
and field
not in [
"name",
"timeout",
"total_size",
"header_size",
"dendrite",
"axon",
]
):
fields.append(value)

return fields

@property
def body_hash(self) -> str:
"""
Compute a SHA-256 hash of the serialized body of the Synapse instance.

The body of the Synapse instance comprises its serialized and encoded
non-optional fields. This property retrieves these fields using the
`get_body` method, then concatenates their string representations, and
finally computes a SHA-256 hash of the resulting string.

Note:
This property is intended to be read-only. Any attempts to override
or set its value will raise an AttributeError due to the protections
set in the __setattr__ method.

Returns:
str: The hexadecimal representation of the SHA-256 hash of the instance's body.
"""
# Hash the body for verification
body = self.get_body()

# Convert elements to string and concatenate
concat = "".join(map(str, body))

# Create a SHA-256 hash object
sha256 = hashlib.sha256()

# Update the hash object with the concatenated string
sha256.update(concat.encode("utf-8"))

# Produce the hash
return sha256.hexdigest()

@property
def is_success(self) -> bool:
"""
Expand Down
24 changes: 21 additions & 3 deletions tests/unit_tests/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import typing
import pytest
import bittensor
import unittest
import unittest.mock as mock
from unittest.mock import MagicMock


def test_parse_headers_to_inputs():
Expand Down Expand Up @@ -238,3 +235,24 @@ class Test(bittensor.Synapse):
assert next_synapse.a["cat"].shape == [10]
assert next_synapse.a["dog"].dtype == "torch.float32"
assert next_synapse.a["dog"].shape == [11]


def test_override_protection():
with pytest.raises(TypeError, match="You can't override the body_hash attribute!"):

class DerivedModel(bittensor.Synapse):
@property
def body_hash(self):
return "new_value"


def test_body_hash_override():
# Create a Synapse instance
synapse_instance = bittensor.Synapse()

# Try to set the body_hash property and expect an AttributeError
with pytest.raises(
AttributeError,
match="body_hash property is read-only and cannot be overridden.",
):
synapse_instance.body_hash = "some_value"