Skip to content

Commit

Permalink
Feature: VmClient
Browse files Browse the repository at this point in the history
  • Loading branch information
1yam committed Jun 19, 2024
1 parent 1d3d5e5 commit f771d21
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"coincurve>=19.0.0; python_version>=\"3.11\"",
"eth_abi>=4.0.0; python_version>=\"3.11\"",
"eth_account>=0.4.0,<0.11.0",
"jwcrypto==1.5.6",
"python-magic",
"typer",
"typing_extensions",
Expand Down
7 changes: 0 additions & 7 deletions src/aleph/sdk/chains/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,3 @@ def get_fallback_private_key(path: Optional[Path] = None) -> bytes:
if not default_key_path.exists():
default_key_path.symlink_to(path)
return private_key


def bytes_from_hex(hex_string: str) -> bytes:
if hex_string.startswith("0x"):
hex_string = hex_string[2:]
hex_string = bytes.fromhex(hex_string)
return hex_string
8 changes: 2 additions & 6 deletions src/aleph/sdk/chains/ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
from eth_keys.exceptions import BadSignature as EthBadSignatureError

from ..exceptions import BadSignatureError
from .common import (
BaseAccount,
bytes_from_hex,
get_fallback_private_key,
get_public_key,
)
from ..utils import bytes_from_hex
from .common import BaseAccount, get_fallback_private_key, get_public_key


class ETHAccount(BaseAccount):
Expand Down
3 changes: 2 additions & 1 deletion src/aleph/sdk/chains/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from ..conf import settings
from ..exceptions import BadSignatureError
from .common import BaseAccount, bytes_from_hex, get_verification_buffer
from ..utils import bytes_from_hex
from .common import BaseAccount, get_verification_buffer

logger = logging.getLogger(__name__)

Expand Down
161 changes: 161 additions & 0 deletions src/aleph/sdk/client/vmclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import datetime
import json
import logging
from typing import Any, Dict, Tuple

import aiohttp
from eth_account.messages import encode_defunct
from jwcrypto import jwk
from jwcrypto.jwa import JWA

from aleph.sdk.types import Account
from aleph.sdk.utils import to_0x_hex

logger = logging.getLogger(__name__)


class VmClient:
def __init__(self, account: Account, domain: str = ""):
self.account: Account = account
self.ephemeral_key: jwk.JWK = jwk.JWK.generate(kty="EC", crv="P-256")
self.domain: str = domain
self.pubkey_payload = self._generate_pubkey_payload()
self.pubkey_signature_header: str = ""
self.session = aiohttp.ClientSession()

def _generate_pubkey_payload(self) -> Dict[str, Any]:
return {
"pubkey": json.loads(self.ephemeral_key.export_public()),
"alg": "ECDSA",
"domain": self.domain,
"address": self.account.get_address(),
"expires": (
datetime.datetime.utcnow() + datetime.timedelta(days=1)
).isoformat()
+ "Z",
}

async def _generate_pubkey_signature_header(self) -> str:
pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex()
signable_message = encode_defunct(hexstr=pubkey_payload)
buffer_to_sign = signable_message.body

signed_message = await self.account.sign_raw(buffer_to_sign)
pubkey_signature = to_0x_hex(signed_message)

return json.dumps(
{
"sender": self.account.get_address(),
"payload": pubkey_payload,
"signature": pubkey_signature,
"content": {"domain": self.domain},
}
)

async def _generate_header(
self, vm_id: str, operation: str
) -> Tuple[str, Dict[str, str]]:
base_url = f"http://{self.domain}"
path = (
f"/logs/{vm_id}"
if operation == "logs"
else f"/control/machine/{vm_id}/{operation}"
)

payload = {
"time": datetime.datetime.utcnow().isoformat() + "Z",
"method": "POST",
"path": path,
}
payload_as_bytes = json.dumps(payload).encode("utf-8")
headers = {"X-SignedPubKey": self.pubkey_signature_header}
payload_signature = JWA.signing_alg("ES256").sign(
self.ephemeral_key, payload_as_bytes
)
headers["X-SignedOperation"] = json.dumps(
{
"payload": payload_as_bytes.hex(),
"signature": payload_signature.hex(),
}
)

return f"{base_url}{path}", headers

async def perform_operation(self, vm_id, operation):
if not self.pubkey_signature_header:
self.pubkey_signature_header = (
await self._generate_pubkey_signature_header()
)

url, header = await self._generate_header(vm_id=vm_id, operation=operation)

try:
async with self.session.post(url, headers=header) as response:
response_text = await response.text()
return response.status, response_text
except aiohttp.ClientError as e:
logger.error(f"HTTP error during operation {operation}: {str(e)}")
return None, str(e)

async def get_logs(self, vm_id):
if not self.pubkey_signature_header:
self.pubkey_signature_header = (
await self._generate_pubkey_signature_header()
)

ws_url, header = await self._generate_header(vm_id=vm_id, operation="logs")

async with aiohttp.ClientSession() as session:
async with session.ws_connect(ws_url) as ws:
auth_message = {
"auth": {
"X-SignedPubKey": header["X-SignedPubKey"],
"X-SignedOperation": header["X-SignedOperation"],
}
}
await ws.send_json(auth_message)
async for msg in ws: # msg is of type aiohttp.WSMessage
if msg.type == aiohttp.WSMsgType.TEXT:
yield msg.data
elif msg.type == aiohttp.WSMsgType.ERROR:
break

async def start_instance(self, vm_id):
return await self.notify_allocation(vm_id)

async def stop_instance(self, vm_id):
return await self.perform_operation(vm_id, "stop")

async def reboot_instance(self, vm_id):

return await self.perform_operation(vm_id, "reboot")

async def erase_instance(self, vm_id):
return await self.perform_operation(vm_id, "erase")

async def expire_instance(self, vm_id):
return await self.perform_operation(vm_id, "expire")

async def notify_allocation(self, vm_id) -> Tuple[Any, str]:
json_data = {"instance": vm_id}
async with self.session.post(
f"https://{self.domain}/control/allocation/notify", json=json_data
) as s:
form_response_text = await s.text()
return s.status, form_response_text

async def manage_instance(self, vm_id, operations):
for operation in operations:
status, response = await self.perform_operation(vm_id, operation)
if status != 200:
return status, response
return

async def close(self):
await self.session.close()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
2 changes: 2 additions & 0 deletions src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Account(Protocol):
@abstractmethod
async def sign_message(self, message: Dict) -> Dict: ...

@abstractmethod
async def sign_raw(self, buffer: bytes) -> bytes: ...
@abstractmethod
def get_address(self) -> str: ...

Expand Down
11 changes: 11 additions & 0 deletions src/aleph/sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,14 @@ def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume:
def compute_sha256(s: str) -> str:
"""Compute the SHA256 hash of a string."""
return hashlib.sha256(s.encode()).hexdigest()


def to_0x_hex(b: bytes) -> str:
return "0x" + bytes.hex(b)


def bytes_from_hex(hex_string: str) -> bytes:
if hex_string.startswith("0x"):
hex_string = hex_string[2:]
hex_string = bytes.fromhex(hex_string)
return hex_string
3 changes: 2 additions & 1 deletion src/aleph/sdk/wallets/ledger/ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ledgereth.messages import sign_message
from ledgereth.objects import LedgerAccount, SignedMessage

from ...chains.common import BaseAccount, bytes_from_hex, get_verification_buffer
from ...chains.common import BaseAccount, get_verification_buffer
from ...utils import bytes_from_hex


class LedgerETHAccount(BaseAccount):
Expand Down

0 comments on commit f771d21

Please sign in to comment.