From 2790df9d2a5e25a0a826a538fc214a0f3493b9e2 Mon Sep 17 00:00:00 2001 From: 1yam <40899431+1yam@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:25:33 +0200 Subject: [PATCH] Feature: Account Handler (#175) * Feature: Internal account management + fix on _load_account to handle SolAccount * fixup! Feature: Internal account management + fix on _load_account to handle SolAccount * Fix: chains_config wasn't using settings.CONFIG_HOME for locations * Fix: blakc issue * Fix: rename CHAINS_CONFIG_FILE to CONFIG_FILE to avoid getting issue by conf of chain * Fix: base58 and pynacl is now needed for build * Fix: f string without nay placeholders * Fix: black error * Refactor: we now store single account at the time * Fix: ruff issue * fix: debug stuff remove * Fix: Improve code structure in pair-programming with Lyam --------- Co-authored-by: Andres D. Molins --- .gitignore | 2 +- pyproject.toml | 2 + src/aleph/sdk/account.py | 52 ++++++++++++++---- src/aleph/sdk/chains/solana.py | 94 +++++++++++++++++++++++++++++++-- src/aleph/sdk/conf.py | 68 +++++++++++++++++++++++- tests/unit/test_chain_solana.py | 60 ++++++++++++++++++++- 6 files changed, 263 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 2896a4e6..f18f4bd6 100644 --- a/.gitignore +++ b/.gitignore @@ -50,7 +50,7 @@ MANIFEST **/device.key # environment variables -.env +.config.json .env.local .gitsigners diff --git a/pyproject.toml b/pyproject.toml index 2cffe116..f533bfe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "aleph-superfluid>=0.2.1", "eth_typing==4.3.1", "web3==6.3.0", + "base58==2.1.1", # Needed now as default with _load_account changement + "pynacl==1.5.0" # Needed now as default with _load_account changement ] [project.optional-dependencies] diff --git a/src/aleph/sdk/account.py b/src/aleph/sdk/account.py index 59eef815..8c067283 100644 --- a/src/aleph/sdk/account.py +++ b/src/aleph/sdk/account.py @@ -1,12 +1,15 @@ import asyncio import logging from pathlib import Path -from typing import Optional, Type, TypeVar +from typing import Dict, Optional, Type, TypeVar + +from aleph_message.models import Chain from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.chains.remote import RemoteAccount -from aleph.sdk.conf import settings +from aleph.sdk.chains.solana import SOLAccount +from aleph.sdk.conf import load_main_configuration, settings from aleph.sdk.types import AccountFromPrivateKey logger = logging.getLogger(__name__) @@ -14,6 +17,16 @@ T = TypeVar("T", bound=AccountFromPrivateKey) +def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]: + chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = { + Chain.ETH: ETHAccount, + Chain.AVAX: ETHAccount, + Chain.SOL: SOLAccount, + Chain.BASE: ETHAccount, + } + return chain_account_map.get(chain) or ETHAccount + + def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T: if private_key_str.startswith("0x"): private_key_str = private_key_str[2:] @@ -28,16 +41,36 @@ def account_from_file(private_key_path: Path, account_type: Type[T]) -> T: def _load_account( private_key_str: Optional[str] = None, private_key_path: Optional[Path] = None, - account_type: Type[AccountFromPrivateKey] = ETHAccount, + account_type: Optional[Type[AccountFromPrivateKey]] = None, ) -> AccountFromPrivateKey: """Load private key from a string or a file. takes the string argument in priority""" + if private_key_str or (private_key_path and private_key_path.is_file()): + if account_type: + if private_key_path and private_key_path.is_file(): + return account_from_file(private_key_path, account_type) + elif private_key_str: + return account_from_hex_string(private_key_str, account_type) + else: + raise ValueError("Any private key specified") + else: + main_configuration = load_main_configuration(settings.CONFIG_FILE) + if main_configuration: + account_type = load_chain_account_type(main_configuration.chain) + logger.debug( + f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}" + ) + else: + account_type = ETHAccount # Defaults to ETHAccount + logger.warning( + f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}" + ) + if private_key_path and private_key_path.is_file(): + return account_from_file(private_key_path, account_type) + elif private_key_str: + return account_from_hex_string(private_key_str, account_type) + else: + raise ValueError("Any private key specified") - if private_key_str: - logger.debug("Using account from string") - return account_from_hex_string(private_key_str, account_type) - elif private_key_path and private_key_path.is_file(): - logger.debug("Using account from file") - return account_from_file(private_key_path, account_type) elif settings.REMOTE_CRYPTO_HOST: logger.debug("Using remote account") loop = asyncio.get_event_loop() @@ -48,6 +81,7 @@ def _load_account( ) ) else: + account_type = ETHAccount # Defaults to ETHAccount new_private_key = get_fallback_private_key() account = account_type(private_key=new_private_key) logger.info( diff --git a/src/aleph/sdk/chains/solana.py b/src/aleph/sdk/chains/solana.py index ff870a4d..a9352489 100644 --- a/src/aleph/sdk/chains/solana.py +++ b/src/aleph/sdk/chains/solana.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import base58 from nacl.exceptions import BadSignatureError as NaclBadSignatureError @@ -22,7 +22,7 @@ class SOLAccount(BaseAccount): _private_key: PrivateKey def __init__(self, private_key: bytes): - self.private_key = private_key + self.private_key = parse_private_key(private_key_from_bytes(private_key)) self._signing_key = SigningKey(self.private_key) self._private_key = self._signing_key.to_curve25519_private_key() @@ -79,7 +79,7 @@ def verify_signature( public_key: The public key to use for verification. Can be a base58 encoded string or bytes. message: The message to verify. Can be an utf-8 string or bytes. Raises: - BadSignatureError: If the signature is invalid. + BadSignatureError: If the signature is invalid.! """ if isinstance(signature, str): signature = base58.b58decode(signature) @@ -91,3 +91,91 @@ def verify_signature( VerifyKey(public_key).verify(message, signature) except NaclBadSignatureError as e: raise BadSignatureError from e + + +def private_key_from_bytes( + private_key_bytes: bytes, output_format: str = "base58" +) -> Union[str, List[int], bytes]: + """ + Convert a Solana private key in bytes back to different formats (base58 string, uint8 list, or raw bytes). + + - For base58 string: Encode the bytes into a base58 string. + - For uint8 list: Convert the bytes into a list of integers. + - For raw bytes: Return as-is. + + Args: + private_key_bytes (bytes): The private key in byte format. + output_format (str): The format to return ('base58', 'list', 'bytes'). + + Returns: + The private key in the requested format. + + Raises: + ValueError: If the output_format is not recognized or the private key length is invalid. + """ + if not isinstance(private_key_bytes, bytes): + raise ValueError("Expected the private key in bytes.") + + if len(private_key_bytes) != 32: + raise ValueError("Solana private key must be exactly 32 bytes long.") + + if output_format == "base58": + return base58.b58encode(private_key_bytes).decode("utf-8") + + elif output_format == "list": + return list(private_key_bytes) + + elif output_format == "bytes": + return private_key_bytes + + else: + raise ValueError("Invalid output format. Choose 'base58', 'list', or 'bytes'.") + + +def parse_private_key(private_key: Union[str, List[int], bytes]) -> bytes: + """ + Parse the private key which could be either: + - a base58-encoded string (which may contain both private and public key) + - a list of uint8 integers (which may contain both private and public key) + - a byte array (exactly 32 bytes) + + Returns: + bytes: The private key in byte format (32 bytes). + + Raises: + ValueError: If the private key format is invalid or the length is incorrect. + """ + # If the private key is already in byte format + if isinstance(private_key, bytes): + if len(private_key) != 32: + raise ValueError("The private key in bytes must be exactly 32 bytes long.") + return private_key + + # If the private key is a base58-encoded string + elif isinstance(private_key, str): + try: + decoded_key = base58.b58decode(private_key) + if len(decoded_key) not in [32, 64]: + raise ValueError( + "The base58 decoded private key must be either 32 or 64 bytes long." + ) + return decoded_key[:32] + except Exception as e: + raise ValueError(f"Invalid base58 encoded private key: {e}") + + # If the private key is a list of uint8 integers + elif isinstance(private_key, list): + if all(isinstance(i, int) and 0 <= i <= 255 for i in private_key): + byte_key = bytes(private_key) + if len(byte_key) < 32: + raise ValueError("The uint8 array must contain at least 32 elements.") + return byte_key[:32] # Take the first 32 bytes (private key) + else: + raise ValueError( + "Invalid uint8 array, must contain integers between 0 and 255." + ) + + else: + raise ValueError( + "Unsupported private key format. Must be a base58 string, bytes, or a list of uint8 integers." + ) diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index 4236370a..114652b7 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -1,3 +1,5 @@ +import json +import logging import os from pathlib import Path from shutil import which @@ -5,14 +7,21 @@ from aleph_message.models import Chain from aleph_message.models.execution.environment import HypervisorType -from pydantic import BaseSettings, Field +from pydantic import BaseModel, BaseSettings, Field from aleph.sdk.types import ChainInfo +logger = logging.getLogger(__name__) + class Settings(BaseSettings): CONFIG_HOME: Optional[str] = None + CONFIG_FILE: Path = Field( + default=Path("config.json"), + description="Path to the JSON file containing chain account configurations", + ) + # In case the user does not want to bother with handling private keys himself, # do an ugly and insecure write and read from disk to this file. PRIVATE_KEY_FILE: Path = Field( @@ -139,6 +148,18 @@ class Config: env_file = ".env" +class MainConfiguration(BaseModel): + """ + Intern Chain Management with Account. + """ + + path: Path + chain: Chain + + class Config: + use_enum_values = True + + # Settings singleton settings = Settings() @@ -162,6 +183,19 @@ class Config: settings.PRIVATE_MNEMONIC_FILE = Path( settings.CONFIG_HOME, "private-keys", "substrate.mnemonic" ) +if str(settings.CONFIG_FILE) == "config.json": + settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "config.json") + # If Config file exist and well filled we update the PRIVATE_KEY_FILE default + if settings.CONFIG_FILE.exists(): + try: + with open(settings.CONFIG_FILE, "r", encoding="utf-8") as f: + config_data = json.load(f) + + if "path" in config_data: + settings.PRIVATE_KEY_FILE = Path(config_data["path"]) + except json.JSONDecodeError: + pass + # Update CHAINS settings and remove placeholders CHAINS_ENV = [(key[7:], value) for key, value in settings if key.startswith("CHAINS_")] @@ -172,3 +206,35 @@ class Config: field = field.lower() settings.CHAINS[chain].__dict__[field] = value settings.__delattr__(f"CHAINS_{fields}") + + +def save_main_configuration(file_path: Path, data: MainConfiguration): + """ + Synchronously save a single ChainAccount object as JSON to a file. + """ + with file_path.open("w") as file: + data_serializable = data.dict() + data_serializable["path"] = str(data_serializable["path"]) + json.dump(data_serializable, file, indent=4) + + +def load_main_configuration(file_path: Path) -> Optional[MainConfiguration]: + """ + Synchronously load the private key and chain type from a file. + If the file does not exist or is empty, return None. + """ + if not file_path.exists() or file_path.stat().st_size == 0: + logger.debug(f"File {file_path} does not exist or is empty. Returning None.") + return None + + try: + with file_path.open("rb") as file: + content = file.read() + data = json.loads(content.decode("utf-8")) + return MainConfiguration(**data) + except UnicodeDecodeError as e: + logger.error(f"Unable to decode {file_path} as UTF-8: {e}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON format in {file_path}.") + + return None diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index ed2fff78..0fbd717e 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -8,7 +8,12 @@ from nacl.signing import VerifyKey from aleph.sdk.chains.common import get_verification_buffer -from aleph.sdk.chains.solana import SOLAccount, get_fallback_account, verify_signature +from aleph.sdk.chains.solana import ( + SOLAccount, + get_fallback_account, + parse_private_key, + verify_signature, +) from aleph.sdk.exceptions import BadSignatureError @@ -136,3 +141,56 @@ async def test_sign_raw(solana_account): assert isinstance(signature, bytes) verify_signature(signature, solana_account.get_address(), buffer) + + +def test_parse_solana_private_key_bytes(): + # Valid 32-byte private key + private_key_bytes = bytes(range(32)) + parsed_key = parse_private_key(private_key_bytes) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + assert parsed_key == private_key_bytes + + # Invalid private key (too short) + with pytest.raises( + ValueError, match="The private key in bytes must be exactly 32 bytes long." + ): + parse_private_key(bytes(range(31))) + + +def test_parse_solana_private_key_base58(): + # Valid base58 private key (32 bytes) + base58_key = base58.b58encode(bytes(range(32))).decode("utf-8") + parsed_key = parse_private_key(base58_key) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + + # Invalid base58 key (not decodable) + with pytest.raises(ValueError, match="Invalid base58 encoded private key"): + parse_private_key("invalid_base58_key") + + # Invalid base58 key (wrong length) + with pytest.raises( + ValueError, + match="The base58 decoded private key must be either 32 or 64 bytes long.", + ): + parse_private_key(base58.b58encode(bytes(range(31))).decode("utf-8")) + + +def test_parse_solana_private_key_list(): + # Valid list of uint8 integers (64 elements, but we only take the first 32 for private key) + uint8_list = list(range(64)) + parsed_key = parse_private_key(uint8_list) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + assert parsed_key == bytes(range(32)) + + # Invalid list (contains non-integers) + with pytest.raises(ValueError, match="Invalid uint8 array"): + parse_private_key([1, 2, "not an int", 4]) # type: ignore # Ignore type check for string + + # Invalid list (less than 32 elements) + with pytest.raises( + ValueError, match="The uint8 array must contain at least 32 elements." + ): + parse_private_key(list(range(31)))