From dbab0177537fa75a1e058be00016a59c36b38892 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Tue, 10 Jan 2023 19:10:37 +0100 Subject: [PATCH] Refactor: Private keys could be deleted Never delete private keys, as this could put the user at risk. Instead, use temporary files in tests. --- src/aleph_client/chains/common.py | 31 ++++++------- src/aleph_client/chains/ethereum.py | 7 +-- src/aleph_client/chains/sol.py | 43 +++++++++++++----- src/aleph_client/chains/tezos.py | 7 +-- tests/unit/conftest.py | 35 ++++++++++++++- tests/unit/test_asynchronous.py | 64 +++++++-------------------- tests/unit/test_chain_ethereum.py | 23 +++++----- tests/unit/test_chain_nuls1_compat.py | 20 +++++---- tests/unit/test_chain_solana.py | 36 +++++++-------- tests/unit/test_chain_tezos.py | 37 ++++++++-------- 10 files changed, 163 insertions(+), 140 deletions(-) diff --git a/src/aleph_client/chains/common.py b/src/aleph_client/chains/common.py index 86008264..6462d7f7 100644 --- a/src/aleph_client/chains/common.py +++ b/src/aleph_client/chains/common.py @@ -1,6 +1,7 @@ import os from abc import abstractmethod, ABC -from typing import Dict +from pathlib import Path +from typing import Dict, Optional from coincurve.keys import PrivateKey from ecies import decrypt, encrypt @@ -70,24 +71,24 @@ def generate_key() -> bytes: return privkey.secret -def get_fallback_private_key() -> bytes: +def get_fallback_private_key(path: Optional[Path] = None) -> bytes: + path = path or settings.PRIVATE_KEY_FILE private_key: bytes - try: - with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile: + if path.exists() and path.stat().st_size > 0: + with open(path, "rb") as prvfile: private_key = prvfile.read() - except OSError: + else: private_key = generate_key() - os.makedirs(os.path.dirname(settings.PRIVATE_KEY_FILE), exist_ok=True) - with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile: + os.makedirs(path.parent, exist_ok=True) + with open(path, "wb") as prvfile: prvfile.write(private_key) - os.symlink(settings.PRIVATE_KEY_FILE, os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key")) - return private_key + with open(path, "rb") as prvfile: + print(prvfile.read()) -def delete_private_key_file(): - try: - os.remove(settings.PRIVATE_KEY_FILE) - os.unlink(os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key")) - except FileNotFoundError: - pass + default_key_path = path.parent / "default.key" + if not default_key_path.is_symlink(): + # Create a symlink to use this key by default + os.symlink(path, default_key_path) + return private_key diff --git a/src/aleph_client/chains/ethereum.py b/src/aleph_client/chains/ethereum.py index e1d7306f..2866788b 100644 --- a/src/aleph_client/chains/ethereum.py +++ b/src/aleph_client/chains/ethereum.py @@ -1,4 +1,5 @@ -from typing import Dict +from pathlib import Path +from typing import Dict, Optional from eth_account import Account from eth_account.signers.local import LocalAccount @@ -38,5 +39,5 @@ def get_public_key(self) -> str: return "0x" + get_public_key(private_key=self._account.key).hex() -def get_fallback_account() -> ETHAccount: - return ETHAccount(private_key=get_fallback_private_key()) +def get_fallback_account(path: Optional[Path] = None) -> ETHAccount: + return ETHAccount(private_key=get_fallback_private_key(path=path)) diff --git a/src/aleph_client/chains/sol.py b/src/aleph_client/chains/sol.py index 724c596e..c7988cc0 100644 --- a/src/aleph_client/chains/sol.py +++ b/src/aleph_client/chains/sol.py @@ -1,5 +1,7 @@ import json -from typing import Dict +import os +from pathlib import Path +from typing import Dict, Optional import base58 from nacl.public import PrivateKey, SealedBox @@ -53,17 +55,34 @@ async def decrypt(self, content) -> bytes: return value -def get_fallback_account() -> SOLAccount: - return SOLAccount(private_key=get_fallback_private_key()) +def get_fallback_account(path: Optional[Path] = None) -> SOLAccount: + return SOLAccount(private_key=get_fallback_private_key(path=path)) -def get_fallback_private_key(): - try: - with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile: - pkey = prvfile.read() - except OSError: - pkey = bytes(SigningKey.generate()) - with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile: - prvfile.write(pkey) +def generate_key() -> bytes: + privkey = bytes(SigningKey.generate()) + return privkey + + +def get_fallback_private_key(path: Optional[Path] = None) -> bytes: + path = path or settings.PRIVATE_KEY_FILE + private_key: bytes + if path.exists() and path.stat().st_size > 0: + with open(path, "rb") as prvfile: + private_key = prvfile.read() + else: + private_key = generate_key() + os.makedirs(path.parent, exist_ok=True) + with open(path, "wb") as prvfile: + prvfile.write(private_key) + + with open(path, "rb") as prvfile: + print(prvfile.read()) + + + default_key_path = path.parent / "default.key" + if not default_key_path.is_symlink(): + # Create a symlink to use this key by default + os.symlink(path, default_key_path) + return private_key - return pkey diff --git a/src/aleph_client/chains/tezos.py b/src/aleph_client/chains/tezos.py index 03cf0009..417b7847 100644 --- a/src/aleph_client/chains/tezos.py +++ b/src/aleph_client/chains/tezos.py @@ -1,5 +1,6 @@ import json -from typing import Dict +from pathlib import Path +from typing import Dict, Optional from aleph_pytezos.crypto.key import Key from nacl.public import SealedBox @@ -49,5 +50,5 @@ async def decrypt(self, content) -> bytes: return SealedBox(self._private_key).decrypt(content) -def get_fallback_account() -> TezosAccount: - return TezosAccount(private_key=get_fallback_private_key()) +def get_fallback_account(path: Optional[Path] = None) -> TezosAccount: + return TezosAccount(private_key=get_fallback_private_key(path=path)) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1a4c4cb8..30c3c49c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -6,5 +6,38 @@ Read more about conftest.py under: https://pytest.org/latest/plugins.html """ +from pathlib import Path +from tempfile import NamedTemporaryFile -# import pytest +import pytest + +from aleph_client.chains.common import get_fallback_private_key +import aleph_client.chains.ethereum as ethereum +import aleph_client.chains.sol as solana +import aleph_client.chains.tezos as tezos + +@pytest.fixture +def fallback_private_key() -> bytes: + with NamedTemporaryFile() as private_key_file: + yield get_fallback_private_key(path=Path(private_key_file.name)) + + +@pytest.fixture +def ethereum_account() -> ethereum.ETHAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + yield ethereum.get_fallback_account(path=Path(private_key_file.name)) + + +@pytest.fixture +def solana_account() -> solana.SOLAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + yield solana.get_fallback_account(path=Path(private_key_file.name)) + + +@pytest.fixture +def tezos_account() -> tezos.TezosAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + yield tezos.get_fallback_account(path=Path(private_key_file.name)) diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index 5caa81da..bb689f55 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -1,4 +1,3 @@ -import os from unittest.mock import MagicMock, patch, AsyncMock import pytest as pytest @@ -10,6 +9,8 @@ ForgetMessage, ) +from aleph_client.types import StorageEnum, MessageStatus + from aleph_client.asynchronous import ( create_post, _get_fallback_session, @@ -18,10 +19,6 @@ create_program, forget, ) -from aleph_client.chains.common import get_fallback_private_key, delete_private_key_file -from aleph_client.chains.ethereum import ETHAccount -from aleph_client.conf import settings -from aleph_client.types import StorageEnum, MessageStatus def new_mock_session_with_post_success(): @@ -41,21 +38,15 @@ def new_mock_session_with_post_success(): @pytest.mark.asyncio -async def test_create_post(): +async def test_create_post(ethereum_account): _get_fallback_session.cache_clear() - if os.path.exists(settings.PRIVATE_KEY_FILE): - delete_private_key_file() - - private_key = get_fallback_private_key() - account: ETHAccount = ETHAccount(private_key=private_key) - content = {"Hello": "World"} mock_session = new_mock_session_with_post_success() post_message, message_status = await create_post( - account=account, + account=ethereum_account, post_content=content, post_type="TEST", channel="TEST", @@ -70,21 +61,15 @@ async def test_create_post(): @pytest.mark.asyncio -async def test_create_aggregate(): +async def test_create_aggregate(ethereum_account): _get_fallback_session.cache_clear() - if os.path.exists(settings.PRIVATE_KEY_FILE): - delete_private_key_file() - - private_key = get_fallback_private_key() - account: ETHAccount = ETHAccount(private_key=private_key) - content = {"Hello": "World"} mock_session = new_mock_session_with_post_success() _ = await create_aggregate( - account=account, + account=ethereum_account, key="hello", content=content, channel="TEST", @@ -92,7 +77,7 @@ async def test_create_aggregate(): ) aggregate_message, message_status = await create_aggregate( - account=account, + account=ethereum_account, key="hello", content="world", channel="TEST", @@ -105,15 +90,9 @@ async def test_create_aggregate(): @pytest.mark.asyncio -async def test_create_store(): +async def test_create_store(ethereum_account): _get_fallback_session.cache_clear() - if os.path.exists(settings.PRIVATE_KEY_FILE): - delete_private_key_file() - - private_key = get_fallback_private_key() - account: ETHAccount = ETHAccount(private_key=private_key) - mock_session = new_mock_session_with_post_success() mock_ipfs_push_file = AsyncMock() @@ -121,7 +100,7 @@ async def test_create_store(): with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file): _ = await create_store( - account=account, + account=ethereum_account, file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.ipfs, @@ -130,7 +109,7 @@ async def test_create_store(): ) _ = await create_store( - account=account, + account=ethereum_account, file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", channel="TEST", storage_engine=StorageEnum.ipfs, @@ -144,8 +123,9 @@ async def test_create_store(): ) with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file): + store_message, message_status = await create_store( - account=account, + account=ethereum_account, file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.storage, @@ -158,19 +138,13 @@ async def test_create_store(): @pytest.mark.asyncio -async def test_create_program(): +async def test_create_program(ethereum_account): _get_fallback_session.cache_clear() - if os.path.exists(settings.PRIVATE_KEY_FILE): - delete_private_key_file() - - private_key = get_fallback_private_key() - account: ETHAccount = ETHAccount(private_key=private_key) - mock_session = new_mock_session_with_post_success() program_message, message_status = await create_program( - account=account, + account=ethereum_account, program_ref="FAKE-HASH", entrypoint="main:app", runtime="FAKE-HASH", @@ -184,19 +158,13 @@ async def test_create_program(): @pytest.mark.asyncio -async def test_forget(): +async def test_forget(ethereum_account): _get_fallback_session.cache_clear() - if os.path.exists(settings.PRIVATE_KEY_FILE): - delete_private_key_file() - - private_key = get_fallback_private_key() - account: ETHAccount = ETHAccount(private_key=private_key) - mock_session = new_mock_session_with_post_success() forget_message, message_status = await forget( - account=account, + account=ethereum_account, hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], reason="GDPR", channel="TEST", diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index 545dd1e0..3057c81a 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -1,7 +1,9 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile + import pytest from dataclasses import dataclass, asdict -from aleph_client.chains.common import delete_private_key_file from aleph_client.chains.ethereum import ETHAccount, get_fallback_account @@ -14,17 +16,16 @@ class Message: def test_get_fallback_account(): - delete_private_key_file() - account: ETHAccount = get_fallback_account() - - assert account.CHAIN == "ETH" - assert account.CURVE == "secp256k1" - assert account._account.address + with NamedTemporaryFile() as private_key_file: + account = get_fallback_account(path=Path(private_key_file.name)) + assert account.CHAIN == "ETH" + assert account.CURVE == "secp256k1" + assert account._account.address @pytest.mark.asyncio -async def test_ETHAccount(): - account: ETHAccount = get_fallback_account() +async def test_ETHAccount(ethereum_account): + account = ethereum_account message = Message("ETH", account.get_address(), "SomeType", "ItemHash") signed = await account.sign_message(asdict(message)) @@ -42,8 +43,8 @@ async def test_ETHAccount(): @pytest.mark.asyncio -async def test_decrypt_secp256k1(): - account: ETHAccount = get_fallback_account() +async def test_decrypt_secp256k1(ethereum_account): + account = ethereum_account assert account.CURVE == "secp256k1" content = b"SomeContent" diff --git a/tests/unit/test_chain_nuls1_compat.py b/tests/unit/test_chain_nuls1_compat.py index 868a8374..357b7afc 100644 --- a/tests/unit/test_chain_nuls1_compat.py +++ b/tests/unit/test_chain_nuls1_compat.py @@ -2,12 +2,14 @@ This file tests that both implementations returns identical results. """ +from pathlib import Path +from tempfile import NamedTemporaryFile import pytest import secp256k1 from coincurve.keys import PrivateKey -from aleph_client.chains.common import get_fallback_private_key, delete_private_key_file +from aleph_client.chains.common import get_fallback_private_key from aleph_client.chains.nuls1 import NulsSignature, VarInt, MESSAGE_TEMPLATE, LOGGER SECRET = ( @@ -63,16 +65,16 @@ def test_sign_data_deprecated(): data = None signature = NulsSignature(data=data) - delete_private_key_file() - private_key = get_fallback_private_key() + with NamedTemporaryFile() as private_key_file: + private_key = get_fallback_private_key(path=Path(private_key_file.name)) - assert signature - sign_deprecated: NulsSignatureSecp256k1 = ( - NulsSignatureSecp256k1.sign_data_deprecated( - pri_key=private_key, digest_bytes=b"x" * (256 // 8) + assert signature + sign_deprecated: NulsSignatureSecp256k1 = ( + NulsSignatureSecp256k1.sign_data_deprecated( + pri_key=private_key, digest_bytes=b"x" * (256 // 8) + ) ) - ) - assert sign_deprecated + assert sign_deprecated @pytest.mark.asyncio diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 4815667e..e62cc7ef 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -1,4 +1,6 @@ import json +from pathlib import Path +from tempfile import NamedTemporaryFile import base58 import pytest @@ -6,7 +8,7 @@ from nacl.signing import VerifyKey -from aleph_client.chains.common import delete_private_key_file, get_verification_buffer +from aleph_client.chains.common import get_verification_buffer from aleph_client.chains.sol import SOLAccount, get_fallback_account @@ -19,23 +21,21 @@ class Message: def test_get_fallback_account(): - delete_private_key_file() - account: SOLAccount = get_fallback_account() + with NamedTemporaryFile() as private_key_file: + account: SOLAccount = get_fallback_account(path=Path(private_key_file.name)) - assert account.CHAIN == "SOL" - assert account.CURVE == "curve25519" - assert account._signing_key.verify_key - assert type(account.private_key) == bytes - assert len(account.private_key) == 32 + assert account.CHAIN == "SOL" + assert account.CURVE == "curve25519" + assert account._signing_key.verify_key + assert type(account.private_key) == bytes + assert len(account.private_key) == 32 @pytest.mark.asyncio -async def test_SOLAccount(): - account: SOLAccount = get_fallback_account() - - message = asdict(Message("SOL", account.get_address(), "SomeType", "ItemHash")) +async def test_SOLAccount(solana_account): + message = asdict(Message("SOL", solana_account.get_address(), "SomeType", "ItemHash")) initial_message = message.copy() - await account.sign_message(message) + await solana_account.sign_message(message) assert message["signature"] address = message["sender"] @@ -61,14 +61,12 @@ async def test_SOLAccount(): @pytest.mark.asyncio -async def test_decrypt_curve25516(): - account: SOLAccount = get_fallback_account() - - assert account.CURVE == "curve25519" +async def test_decrypt_curve25516(solana_account): + assert solana_account.CURVE == "curve25519" content = b"SomeContent" - encrypted = await account.encrypt(content) + encrypted = await solana_account.encrypt(content) assert type(encrypted) == bytes - decrypted = await account.decrypt(encrypted) + decrypted = await solana_account.decrypt(encrypted) assert type(decrypted) == bytes assert content == decrypted diff --git a/tests/unit/test_chain_tezos.py b/tests/unit/test_chain_tezos.py index 6e49c071..026cbc0a 100644 --- a/tests/unit/test_chain_tezos.py +++ b/tests/unit/test_chain_tezos.py @@ -1,6 +1,8 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile + import pytest -from aleph_client.chains.common import delete_private_key_file from aleph_client.chains.tezos import TezosAccount, get_fallback_account from dataclasses import dataclass, asdict @@ -13,43 +15,40 @@ class Message: item_hash: str -def test_get_fallback_account(): - delete_private_key_file() - account: TezosAccount = get_fallback_account() +def test_get_fallback_account(tezos_account: TezosAccount): + with NamedTemporaryFile() as private_key_file: + account: TezosAccount = get_fallback_account(path=Path(private_key_file.name)) - assert account.CHAIN == "TEZOS" - assert account.CURVE == "secp256k1" - assert account._account.public_key() + assert account.CHAIN == "TEZOS" + assert account.CURVE == "secp256k1" + assert account._account.public_key() @pytest.mark.asyncio -async def test_tezos_account(): - account: TezosAccount = get_fallback_account() +async def test_tezos_account(tezos_account: TezosAccount): - message = Message("TEZOS", account.get_address(), "SomeType", "ItemHash") - signed = await account.sign_message(asdict(message)) + message = Message("TEZOS", tezos_account.get_address(), "SomeType", "ItemHash") + signed = await tezos_account.sign_message(asdict(message)) assert signed["signature"] assert len(signed["signature"]) == 188 - address = account.get_address() + address = tezos_account.get_address() assert address is not None assert isinstance(address, str) assert len(address) == 36 - pubkey = account.get_public_key() + pubkey = tezos_account.get_public_key() assert isinstance(pubkey, str) assert len(pubkey) == 55 @pytest.mark.asyncio -async def test_decrypt_secp256k1(): - account: TezosAccount = get_fallback_account() - - assert account.CURVE == "secp256k1" +async def test_decrypt_secp256k1(tezos_account: TezosAccount): + assert tezos_account.CURVE == "secp256k1" content = b"SomeContent" - encrypted = await account.encrypt(content) + encrypted = await tezos_account.encrypt(content) assert isinstance(encrypted, bytes) - decrypted = await account.decrypt(encrypted) + decrypted = await tezos_account.decrypt(encrypted) assert isinstance(decrypted, bytes) assert content == decrypted