Skip to content

Commit

Permalink
Post-SOL fixes (#178)
Browse files Browse the repository at this point in the history
* Missing chain field on auth

* Fix Signature of Solana operation for CRN

* Add export_private_key func for accounts

* Improve _load_account

* Add chain arg to _load_account

* Increase default HTTP_REQUEST_TIMEOUT

* Typing

---------

Co-authored-by: Olivier Le Thanh Duong <olivier@lethanh.be>
  • Loading branch information
philogicae and olethanh authored Oct 11, 2024
1 parent a636106 commit d54e9ac
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 45 deletions.
108 changes: 69 additions & 39 deletions src/aleph/sdk/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,95 @@
from aleph.sdk.chains.remote import RemoteAccount
from aleph.sdk.chains.solana import SOLAccount
from aleph.sdk.conf import load_main_configuration, settings
from aleph.sdk.evm_utils import get_chains_with_super_token
from aleph.sdk.types import AccountFromPrivateKey

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=AccountFromPrivateKey)

chain_account_map: Dict[Chain, Type[T]] = { # type: ignore
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.BASE: ETHAccount,
Chain.SOL: SOLAccount,
}


def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]:
chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = {
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.SOL: SOLAccount,
Chain.BASE: ETHAccount,
}
return chain_account_map.get(chain) or ETHAccount
return chain_account_map.get(chain) or ETHAccount # type: ignore


def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:
def account_from_hex_string(
private_key_str: str,
account_type: Optional[Type[T]],
chain: Optional[Chain] = None,
) -> AccountFromPrivateKey:
if private_key_str.startswith("0x"):
private_key_str = private_key_str[2:]
return account_type(bytes.fromhex(private_key_str))

if not chain:
if not account_type:
account_type = load_chain_account_type(Chain.ETH) # type: ignore
return account_type(bytes.fromhex(private_key_str)) # type: ignore

account_type = load_chain_account_type(chain)
account = account_type(bytes.fromhex(private_key_str))
if chain in get_chains_with_super_token():
account.switch_chain(chain)
return account # type: ignore

def account_from_file(private_key_path: Path, account_type: Type[T]) -> T:

def account_from_file(
private_key_path: Path,
account_type: Optional[Type[T]],
chain: Optional[Chain] = None,
) -> AccountFromPrivateKey:
private_key = private_key_path.read_bytes()
return account_type(private_key)

if not chain:
if not account_type:
account_type = load_chain_account_type(Chain.ETH) # type: ignore
return account_type(private_key) # type: ignore

account_type = load_chain_account_type(chain)
account = account_type(private_key)
if chain in get_chains_with_super_token():
account.switch_chain(chain)
return account


def _load_account(
private_key_str: Optional[str] = None,
private_key_path: Optional[Path] = None,
account_type: Optional[Type[AccountFromPrivateKey]] = None,
chain: Optional[Chain] = None,
) -> AccountFromPrivateKey:
"""Load private key from a string or a file. takes the string argument in priority"""
if private_key_str or (private_key_path and private_key_path.is_file()):
if account_type:
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")
"""Load an account from a private key string or file, or from the configuration file."""

# Loads configuration if no account_type is specified
if not account_type:
config = load_main_configuration(settings.CONFIG_FILE)
if config and hasattr(config, "chain"):
account_type = load_chain_account_type(config.chain)
logger.debug(
f"Detected {config.chain} account for path {settings.CONFIG_FILE}"
)
else:
main_configuration = load_main_configuration(settings.CONFIG_FILE)
if main_configuration:
account_type = load_chain_account_type(main_configuration.chain)
logger.debug(
f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}"
)
else:
account_type = ETHAccount # Defaults to ETHAccount
logger.warning(
f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}"
)
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")
account_type = account_type = load_chain_account_type(
Chain.ETH
) # Defaults to ETHAccount
logger.warning(
f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type and account_type.__name__}"
)

# Loads private key from a string
if private_key_str:
return account_from_hex_string(private_key_str, account_type, chain)
# Loads private key from a file
elif private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type, chain)
# For ledger keys
elif settings.REMOTE_CRYPTO_HOST:
logger.debug("Using remote account")
loop = asyncio.get_event_loop()
Expand All @@ -80,10 +108,12 @@ def _load_account(
unix_socket=settings.REMOTE_CRYPTO_UNIX_SOCKET,
)
)
# Fallback: config.path if set, else generate a new private key
else:
account_type = ETHAccount # Defaults to ETHAccount
new_private_key = get_fallback_private_key()
account = account_type(private_key=new_private_key)
account = account_from_hex_string(
bytes.hex(new_private_key), account_type, chain
)
logger.info(
f"Generated fallback private key with address {account.get_address()}"
)
Expand Down
5 changes: 5 additions & 0 deletions src/aleph/sdk/chains/ethereum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
from decimal import Decimal
from pathlib import Path
from typing import Awaitable, Optional, Union
Expand Down Expand Up @@ -61,6 +62,10 @@ def from_mnemonic(mnemonic: str, chain: Optional[Chain] = None) -> "ETHAccount":
private_key=Account.from_mnemonic(mnemonic=mnemonic).key, chain=chain
)

def export_private_key(self) -> str:
"""Export the private key using standard format."""
return f"0x{base64.b16encode(self.private_key).decode().lower()}"

def get_address(self) -> str:
return self._account.address

Expand Down
6 changes: 6 additions & 0 deletions src/aleph/sdk/chains/solana.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ async def sign_raw(self, buffer: bytes) -> bytes:
sig = self._signing_key.sign(buffer)
return sig.signature

def export_private_key(self) -> str:
"""Export the private key using Phantom format."""
return base58.b58encode(
self.private_key + self._signing_key.verify_key.encode()
).decode()

def get_address(self) -> str:
return encode(self._signing_key.verify_key)

Expand Down
17 changes: 12 additions & 5 deletions src/aleph/sdk/client/vm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from urllib.parse import urlparse

import aiohttp
from aleph_message.models import ItemHash
from aleph_message.models import Chain, ItemHash
from eth_account.messages import encode_defunct
from jwcrypto import jwk

from aleph.sdk.chains.solana import SOLAccount
from aleph.sdk.types import Account
from aleph.sdk.utils import (
create_vm_control_payload,
Expand Down Expand Up @@ -36,11 +37,13 @@ def __init__(
self.account = account
self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256")
self.node_url = node_url.rstrip("/")
self.pubkey_payload = self._generate_pubkey_payload()
self.pubkey_payload = self._generate_pubkey_payload(
Chain.SOL if isinstance(account, SOLAccount) else Chain.ETH
)
self.pubkey_signature_header = ""
self.session = session or aiohttp.ClientSession()

def _generate_pubkey_payload(self) -> Dict[str, Any]:
def _generate_pubkey_payload(self, chain: Chain = Chain.ETH) -> Dict[str, Any]:
return {
"pubkey": json.loads(self.ephemeral_key.export_public()),
"alg": "ECDSA",
Expand All @@ -50,12 +53,16 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]:
datetime.datetime.utcnow() + datetime.timedelta(days=1)
).isoformat()
+ "Z",
"chain": chain.value,
}

async def _generate_pubkey_signature_header(self) -> str:
pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex()
signable_message = encode_defunct(hexstr=pubkey_payload)
buffer_to_sign = signable_message.body
if isinstance(self.account, SOLAccount):
buffer_to_sign = bytes(pubkey_payload, encoding="utf-8")
else:
signable_message = encode_defunct(hexstr=pubkey_payload)
buffer_to_sign = signable_message.body

signed_message = await self.account.sign_raw(buffer_to_sign)
pubkey_signature = to_0x_hex(signed_message)
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Settings(BaseSettings):
REMOTE_CRYPTO_HOST: Optional[str] = None
REMOTE_CRYPTO_UNIX_SOCKET: Optional[str] = None
ADDRESS_TO_USE: Optional[str] = None
HTTP_REQUEST_TIMEOUT = 10.0
HTTP_REQUEST_TIMEOUT = 15.0

DEFAULT_CHANNEL: str = "ALEPH-CLOUDSOLUTIONS"
DEFAULT_RUNTIME_ID: str = (
Expand Down
4 changes: 4 additions & 0 deletions src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def __init__(self, private_key: bytes): ...

async def sign_raw(self, buffer: bytes) -> bytes: ...

def export_private_key(self) -> str: ...

def switch_chain(self, chain: Optional[str] = None) -> None: ...


GenericMessage = TypeVar("GenericMessage", bound=AlephMessage)

Expand Down

0 comments on commit d54e9ac

Please sign in to comment.