diff --git a/src/aleph/sdk/account.py b/src/aleph/sdk/account.py index 8c067283..9bfafcd3 100644 --- a/src/aleph/sdk/account.py +++ b/src/aleph/sdk/account.py @@ -10,67 +10,95 @@ from aleph.sdk.chains.remote import RemoteAccount from aleph.sdk.chains.solana import SOLAccount from aleph.sdk.conf import load_main_configuration, settings +from aleph.sdk.evm_utils import get_chains_with_super_token from aleph.sdk.types import AccountFromPrivateKey logger = logging.getLogger(__name__) T = TypeVar("T", bound=AccountFromPrivateKey) +chain_account_map: Dict[Chain, Type[T]] = { # type: ignore + Chain.ETH: ETHAccount, + Chain.AVAX: ETHAccount, + Chain.BASE: ETHAccount, + Chain.SOL: SOLAccount, +} + def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]: - chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = { - Chain.ETH: ETHAccount, - Chain.AVAX: ETHAccount, - Chain.SOL: SOLAccount, - Chain.BASE: ETHAccount, - } - return chain_account_map.get(chain) or ETHAccount + return chain_account_map.get(chain) or ETHAccount # type: ignore -def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T: +def account_from_hex_string( + private_key_str: str, + account_type: Optional[Type[T]], + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: if private_key_str.startswith("0x"): private_key_str = private_key_str[2:] - return account_type(bytes.fromhex(private_key_str)) + if not chain: + if not account_type: + account_type = load_chain_account_type(Chain.ETH) # type: ignore + return account_type(bytes.fromhex(private_key_str)) # type: ignore + + account_type = load_chain_account_type(chain) + account = account_type(bytes.fromhex(private_key_str)) + if chain in get_chains_with_super_token(): + account.switch_chain(chain) + return account # type: ignore -def account_from_file(private_key_path: Path, account_type: Type[T]) -> T: + +def account_from_file( + private_key_path: Path, + account_type: Optional[Type[T]], + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: private_key = private_key_path.read_bytes() - return account_type(private_key) + + if not chain: + if not account_type: + account_type = load_chain_account_type(Chain.ETH) # type: ignore + return account_type(private_key) # type: ignore + + account_type = load_chain_account_type(chain) + account = account_type(private_key) + if chain in get_chains_with_super_token(): + account.switch_chain(chain) + return account def _load_account( private_key_str: Optional[str] = None, private_key_path: Optional[Path] = None, account_type: Optional[Type[AccountFromPrivateKey]] = None, + chain: Optional[Chain] = None, ) -> AccountFromPrivateKey: - """Load private key from a string or a file. takes the string argument in priority""" - if private_key_str or (private_key_path and private_key_path.is_file()): - if account_type: - if private_key_path and private_key_path.is_file(): - return account_from_file(private_key_path, account_type) - elif private_key_str: - return account_from_hex_string(private_key_str, account_type) - else: - raise ValueError("Any private key specified") + """Load an account from a private key string or file, or from the configuration file.""" + + # Loads configuration if no account_type is specified + if not account_type: + config = load_main_configuration(settings.CONFIG_FILE) + if config and hasattr(config, "chain"): + account_type = load_chain_account_type(config.chain) + logger.debug( + f"Detected {config.chain} account for path {settings.CONFIG_FILE}" + ) else: - main_configuration = load_main_configuration(settings.CONFIG_FILE) - if main_configuration: - account_type = load_chain_account_type(main_configuration.chain) - logger.debug( - f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}" - ) - else: - account_type = ETHAccount # Defaults to ETHAccount - logger.warning( - f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}" - ) - if private_key_path and private_key_path.is_file(): - return account_from_file(private_key_path, account_type) - elif private_key_str: - return account_from_hex_string(private_key_str, account_type) - else: - raise ValueError("Any private key specified") + account_type = account_type = load_chain_account_type( + Chain.ETH + ) # Defaults to ETHAccount + logger.warning( + f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type and account_type.__name__}" + ) + # Loads private key from a string + if private_key_str: + return account_from_hex_string(private_key_str, account_type, chain) + # Loads private key from a file + elif private_key_path and private_key_path.is_file(): + return account_from_file(private_key_path, account_type, chain) + # For ledger keys elif settings.REMOTE_CRYPTO_HOST: logger.debug("Using remote account") loop = asyncio.get_event_loop() @@ -80,10 +108,12 @@ def _load_account( unix_socket=settings.REMOTE_CRYPTO_UNIX_SOCKET, ) ) + # Fallback: config.path if set, else generate a new private key else: - account_type = ETHAccount # Defaults to ETHAccount new_private_key = get_fallback_private_key() - account = account_type(private_key=new_private_key) + account = account_from_hex_string( + bytes.hex(new_private_key), account_type, chain + ) logger.info( f"Generated fallback private key with address {account.get_address()}" ) diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 32f459b7..ab93df56 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -1,4 +1,5 @@ import asyncio +import base64 from decimal import Decimal from pathlib import Path from typing import Awaitable, Optional, Union @@ -61,6 +62,10 @@ def from_mnemonic(mnemonic: str, chain: Optional[Chain] = None) -> "ETHAccount": private_key=Account.from_mnemonic(mnemonic=mnemonic).key, chain=chain ) + def export_private_key(self) -> str: + """Export the private key using standard format.""" + return f"0x{base64.b16encode(self.private_key).decode().lower()}" + def get_address(self) -> str: return self._account.address diff --git a/src/aleph/sdk/chains/solana.py b/src/aleph/sdk/chains/solana.py index a9352489..920ca8a0 100644 --- a/src/aleph/sdk/chains/solana.py +++ b/src/aleph/sdk/chains/solana.py @@ -43,6 +43,12 @@ async def sign_raw(self, buffer: bytes) -> bytes: sig = self._signing_key.sign(buffer) return sig.signature + def export_private_key(self) -> str: + """Export the private key using Phantom format.""" + return base58.b58encode( + self.private_key + self._signing_key.verify_key.encode() + ).decode() + def get_address(self) -> str: return encode(self._signing_key.verify_key) diff --git a/src/aleph/sdk/client/vm_client.py b/src/aleph/sdk/client/vm_client.py index 18d280cc..83b00dc9 100644 --- a/src/aleph/sdk/client/vm_client.py +++ b/src/aleph/sdk/client/vm_client.py @@ -5,10 +5,11 @@ from urllib.parse import urlparse import aiohttp -from aleph_message.models import ItemHash +from aleph_message.models import Chain, ItemHash from eth_account.messages import encode_defunct from jwcrypto import jwk +from aleph.sdk.chains.solana import SOLAccount from aleph.sdk.types import Account from aleph.sdk.utils import ( create_vm_control_payload, @@ -36,11 +37,13 @@ def __init__( self.account = account self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256") self.node_url = node_url.rstrip("/") - self.pubkey_payload = self._generate_pubkey_payload() + self.pubkey_payload = self._generate_pubkey_payload( + Chain.SOL if isinstance(account, SOLAccount) else Chain.ETH + ) self.pubkey_signature_header = "" self.session = session or aiohttp.ClientSession() - def _generate_pubkey_payload(self) -> Dict[str, Any]: + def _generate_pubkey_payload(self, chain: Chain = Chain.ETH) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", @@ -50,12 +53,16 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: datetime.datetime.utcnow() + datetime.timedelta(days=1) ).isoformat() + "Z", + "chain": chain.value, } async def _generate_pubkey_signature_header(self) -> str: pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() - signable_message = encode_defunct(hexstr=pubkey_payload) - buffer_to_sign = signable_message.body + if isinstance(self.account, SOLAccount): + buffer_to_sign = bytes(pubkey_payload, encoding="utf-8") + else: + signable_message = encode_defunct(hexstr=pubkey_payload) + buffer_to_sign = signable_message.body signed_message = await self.account.sign_raw(buffer_to_sign) pubkey_signature = to_0x_hex(signed_message) diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index 114652b7..5fe4cd4b 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -41,7 +41,7 @@ class Settings(BaseSettings): REMOTE_CRYPTO_HOST: Optional[str] = None REMOTE_CRYPTO_UNIX_SOCKET: Optional[str] = None ADDRESS_TO_USE: Optional[str] = None - HTTP_REQUEST_TIMEOUT = 10.0 + HTTP_REQUEST_TIMEOUT = 15.0 DEFAULT_CHANNEL: str = "ALEPH-CLOUDSOLUTIONS" DEFAULT_RUNTIME_ID: str = ( diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 081a3465..dab90379 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -39,6 +39,10 @@ def __init__(self, private_key: bytes): ... async def sign_raw(self, buffer: bytes) -> bytes: ... + def export_private_key(self) -> str: ... + + def switch_chain(self, chain: Optional[str] = None) -> None: ... + GenericMessage = TypeVar("GenericMessage", bound=AlephMessage)