Skip to content

Commit

Permalink
Refactor: Private keys could be deleted
Browse files Browse the repository at this point in the history
Never delete private keys, as this could put the user at risk.
Instead, use temporary files in tests.
  • Loading branch information
hoh committed Jan 12, 2023
1 parent 258cd58 commit dbab017
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 140 deletions.
31 changes: 16 additions & 15 deletions src/aleph_client/chains/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from abc import abstractmethod, ABC
from typing import Dict
from pathlib import Path
from typing import Dict, Optional

from coincurve.keys import PrivateKey
from ecies import decrypt, encrypt
Expand Down Expand Up @@ -70,24 +71,24 @@ def generate_key() -> bytes:
return privkey.secret


def get_fallback_private_key() -> bytes:
def get_fallback_private_key(path: Optional[Path] = None) -> bytes:
path = path or settings.PRIVATE_KEY_FILE
private_key: bytes
try:
with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile:
if path.exists() and path.stat().st_size > 0:
with open(path, "rb") as prvfile:
private_key = prvfile.read()
except OSError:
else:
private_key = generate_key()
os.makedirs(os.path.dirname(settings.PRIVATE_KEY_FILE), exist_ok=True)
with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile:
os.makedirs(path.parent, exist_ok=True)
with open(path, "wb") as prvfile:
prvfile.write(private_key)
os.symlink(settings.PRIVATE_KEY_FILE, os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key"))

return private_key
with open(path, "rb") as prvfile:
print(prvfile.read())


def delete_private_key_file():
try:
os.remove(settings.PRIVATE_KEY_FILE)
os.unlink(os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key"))
except FileNotFoundError:
pass
default_key_path = path.parent / "default.key"
if not default_key_path.is_symlink():
# Create a symlink to use this key by default
os.symlink(path, default_key_path)
return private_key
7 changes: 4 additions & 3 deletions src/aleph_client/chains/ethereum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from pathlib import Path
from typing import Dict, Optional

from eth_account import Account
from eth_account.signers.local import LocalAccount
Expand Down Expand Up @@ -38,5 +39,5 @@ def get_public_key(self) -> str:
return "0x" + get_public_key(private_key=self._account.key).hex()


def get_fallback_account() -> ETHAccount:
return ETHAccount(private_key=get_fallback_private_key())
def get_fallback_account(path: Optional[Path] = None) -> ETHAccount:
return ETHAccount(private_key=get_fallback_private_key(path=path))
43 changes: 31 additions & 12 deletions src/aleph_client/chains/sol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from typing import Dict
import os
from pathlib import Path
from typing import Dict, Optional

import base58
from nacl.public import PrivateKey, SealedBox
Expand Down Expand Up @@ -53,17 +55,34 @@ async def decrypt(self, content) -> bytes:
return value


def get_fallback_account() -> SOLAccount:
return SOLAccount(private_key=get_fallback_private_key())
def get_fallback_account(path: Optional[Path] = None) -> SOLAccount:
return SOLAccount(private_key=get_fallback_private_key(path=path))


def get_fallback_private_key():
try:
with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile:
pkey = prvfile.read()
except OSError:
pkey = bytes(SigningKey.generate())
with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile:
prvfile.write(pkey)
def generate_key() -> bytes:
privkey = bytes(SigningKey.generate())
return privkey


def get_fallback_private_key(path: Optional[Path] = None) -> bytes:
path = path or settings.PRIVATE_KEY_FILE
private_key: bytes
if path.exists() and path.stat().st_size > 0:
with open(path, "rb") as prvfile:
private_key = prvfile.read()
else:
private_key = generate_key()
os.makedirs(path.parent, exist_ok=True)
with open(path, "wb") as prvfile:
prvfile.write(private_key)

with open(path, "rb") as prvfile:
print(prvfile.read())


default_key_path = path.parent / "default.key"
if not default_key_path.is_symlink():
# Create a symlink to use this key by default
os.symlink(path, default_key_path)
return private_key

return pkey
7 changes: 4 additions & 3 deletions src/aleph_client/chains/tezos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Dict
from pathlib import Path
from typing import Dict, Optional

from aleph_pytezos.crypto.key import Key
from nacl.public import SealedBox
Expand Down Expand Up @@ -49,5 +50,5 @@ async def decrypt(self, content) -> bytes:
return SealedBox(self._private_key).decrypt(content)


def get_fallback_account() -> TezosAccount:
return TezosAccount(private_key=get_fallback_private_key())
def get_fallback_account(path: Optional[Path] = None) -> TezosAccount:
return TezosAccount(private_key=get_fallback_private_key(path=path))
35 changes: 34 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,38 @@
Read more about conftest.py under:
https://pytest.org/latest/plugins.html
"""
from pathlib import Path
from tempfile import NamedTemporaryFile

# import pytest
import pytest

from aleph_client.chains.common import get_fallback_private_key
import aleph_client.chains.ethereum as ethereum
import aleph_client.chains.sol as solana
import aleph_client.chains.tezos as tezos

@pytest.fixture
def fallback_private_key() -> bytes:
with NamedTemporaryFile() as private_key_file:
yield get_fallback_private_key(path=Path(private_key_file.name))


@pytest.fixture
def ethereum_account() -> ethereum.ETHAccount:
with NamedTemporaryFile(delete=False) as private_key_file:
private_key_file.close()
yield ethereum.get_fallback_account(path=Path(private_key_file.name))


@pytest.fixture
def solana_account() -> solana.SOLAccount:
with NamedTemporaryFile(delete=False) as private_key_file:
private_key_file.close()
yield solana.get_fallback_account(path=Path(private_key_file.name))


@pytest.fixture
def tezos_account() -> tezos.TezosAccount:
with NamedTemporaryFile(delete=False) as private_key_file:
private_key_file.close()
yield tezos.get_fallback_account(path=Path(private_key_file.name))
64 changes: 16 additions & 48 deletions tests/unit/test_asynchronous.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from unittest.mock import MagicMock, patch, AsyncMock

import pytest as pytest
Expand All @@ -10,6 +9,8 @@
ForgetMessage,
)

from aleph_client.types import StorageEnum, MessageStatus

from aleph_client.asynchronous import (
create_post,
_get_fallback_session,
Expand All @@ -18,10 +19,6 @@
create_program,
forget,
)
from aleph_client.chains.common import get_fallback_private_key, delete_private_key_file
from aleph_client.chains.ethereum import ETHAccount
from aleph_client.conf import settings
from aleph_client.types import StorageEnum, MessageStatus


def new_mock_session_with_post_success():
Expand All @@ -41,21 +38,15 @@ def new_mock_session_with_post_success():


@pytest.mark.asyncio
async def test_create_post():
async def test_create_post(ethereum_account):
_get_fallback_session.cache_clear()

if os.path.exists(settings.PRIVATE_KEY_FILE):
delete_private_key_file()

private_key = get_fallback_private_key()
account: ETHAccount = ETHAccount(private_key=private_key)

content = {"Hello": "World"}

mock_session = new_mock_session_with_post_success()

post_message, message_status = await create_post(
account=account,
account=ethereum_account,
post_content=content,
post_type="TEST",
channel="TEST",
Expand All @@ -70,29 +61,23 @@ async def test_create_post():


@pytest.mark.asyncio
async def test_create_aggregate():
async def test_create_aggregate(ethereum_account):
_get_fallback_session.cache_clear()

if os.path.exists(settings.PRIVATE_KEY_FILE):
delete_private_key_file()

private_key = get_fallback_private_key()
account: ETHAccount = ETHAccount(private_key=private_key)

content = {"Hello": "World"}

mock_session = new_mock_session_with_post_success()

_ = await create_aggregate(
account=account,
account=ethereum_account,
key="hello",
content=content,
channel="TEST",
session=mock_session,
)

aggregate_message, message_status = await create_aggregate(
account=account,
account=ethereum_account,
key="hello",
content="world",
channel="TEST",
Expand All @@ -105,23 +90,17 @@ async def test_create_aggregate():


@pytest.mark.asyncio
async def test_create_store():
async def test_create_store(ethereum_account):
_get_fallback_session.cache_clear()

if os.path.exists(settings.PRIVATE_KEY_FILE):
delete_private_key_file()

private_key = get_fallback_private_key()
account: ETHAccount = ETHAccount(private_key=private_key)

mock_session = new_mock_session_with_post_success()

mock_ipfs_push_file = AsyncMock()
mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"

with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file):
_ = await create_store(
account=account,
account=ethereum_account,
file_content=b"HELLO",
channel="TEST",
storage_engine=StorageEnum.ipfs,
Expand All @@ -130,7 +109,7 @@ async def test_create_store():
)

_ = await create_store(
account=account,
account=ethereum_account,
file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
channel="TEST",
storage_engine=StorageEnum.ipfs,
Expand All @@ -144,8 +123,9 @@ async def test_create_store():
)

with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file):

store_message, message_status = await create_store(
account=account,
account=ethereum_account,
file_content=b"HELLO",
channel="TEST",
storage_engine=StorageEnum.storage,
Expand All @@ -158,19 +138,13 @@ async def test_create_store():


@pytest.mark.asyncio
async def test_create_program():
async def test_create_program(ethereum_account):
_get_fallback_session.cache_clear()

if os.path.exists(settings.PRIVATE_KEY_FILE):
delete_private_key_file()

private_key = get_fallback_private_key()
account: ETHAccount = ETHAccount(private_key=private_key)

mock_session = new_mock_session_with_post_success()

program_message, message_status = await create_program(
account=account,
account=ethereum_account,
program_ref="FAKE-HASH",
entrypoint="main:app",
runtime="FAKE-HASH",
Expand All @@ -184,19 +158,13 @@ async def test_create_program():


@pytest.mark.asyncio
async def test_forget():
async def test_forget(ethereum_account):
_get_fallback_session.cache_clear()

if os.path.exists(settings.PRIVATE_KEY_FILE):
delete_private_key_file()

private_key = get_fallback_private_key()
account: ETHAccount = ETHAccount(private_key=private_key)

mock_session = new_mock_session_with_post_success()

forget_message, message_status = await forget(
account=account,
account=ethereum_account,
hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
reason="GDPR",
channel="TEST",
Expand Down
23 changes: 12 additions & 11 deletions tests/unit/test_chain_ethereum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pathlib import Path
from tempfile import NamedTemporaryFile

import pytest
from dataclasses import dataclass, asdict

from aleph_client.chains.common import delete_private_key_file
from aleph_client.chains.ethereum import ETHAccount, get_fallback_account


Expand All @@ -14,17 +16,16 @@ class Message:


def test_get_fallback_account():
delete_private_key_file()
account: ETHAccount = get_fallback_account()

assert account.CHAIN == "ETH"
assert account.CURVE == "secp256k1"
assert account._account.address
with NamedTemporaryFile() as private_key_file:
account = get_fallback_account(path=Path(private_key_file.name))
assert account.CHAIN == "ETH"
assert account.CURVE == "secp256k1"
assert account._account.address


@pytest.mark.asyncio
async def test_ETHAccount():
account: ETHAccount = get_fallback_account()
async def test_ETHAccount(ethereum_account):
account = ethereum_account

message = Message("ETH", account.get_address(), "SomeType", "ItemHash")
signed = await account.sign_message(asdict(message))
Expand All @@ -42,8 +43,8 @@ async def test_ETHAccount():


@pytest.mark.asyncio
async def test_decrypt_secp256k1():
account: ETHAccount = get_fallback_account()
async def test_decrypt_secp256k1(ethereum_account):
account = ethereum_account

assert account.CURVE == "secp256k1"
content = b"SomeContent"
Expand Down
Loading

0 comments on commit dbab017

Please sign in to comment.