From dd0176e68d1c72567847402c7e2c54b99e8c46ef Mon Sep 17 00:00:00 2001 From: roshii Date: Thu, 3 Aug 2023 10:39:46 +0200 Subject: [PATCH] Check wallet lock file before asking for password Co-authored-by: Wukong --- jmclient/jmclient/storage.py | 70 ++++++++++++++++++++----------- jmclient/jmclient/wallet_utils.py | 6 ++- jmclient/test/test_storage.py | 5 +++ jmclient/test/test_wallet_rpc.py | 53 +++++++++++++++++++++-- scripts/joinmarket-qt.py | 46 +++++++++++++++----- 5 files changed, 141 insertions(+), 39 deletions(-) diff --git a/jmclient/jmclient/storage.py b/jmclient/jmclient/storage.py index 21f5aff66..ea4fc895c 100644 --- a/jmclient/jmclient/storage.py +++ b/jmclient/jmclient/storage.py @@ -107,7 +107,7 @@ def is_encrypted(self): return self._hash is not None def is_locked(self): - return self._lock_file and os.path.exists(self._lock_file) + return self._lock_file is not None and os.path.exists(self._lock_file) def was_changed(self): """ @@ -279,34 +279,52 @@ def _hash_password(cls, password, salt=None): return Argon2Hash(password, salt, hash_len=cls.ENC_KEY_BYTES, salt_len=cls.SALT_LENGTH) - def _create_lock(self): - if self.read_only: - return - (path_head, path_tail) = os.path.split(self.path) - lock_filename = os.path.join(path_head, '.' + path_tail + '.lock') - self._lock_file = lock_filename - if os.path.exists(self._lock_file): - with open(self._lock_file, 'r') as f: - try: - locked_by_pid = int(f.read()) - except ValueError: - locked_by_pid = None - self._lock_file = None + @staticmethod + def _get_lock_filename(path: str) -> str: + """Return lock filename""" + (path_head, path_tail) = os.path.split(path) + return os.path.join(path_head, '.' + path_tail + '.lock') + + @classmethod + def _get_locking_pid(cls, path: str) -> int: + """Return locking PID, -1 if no lockfile if found, 0 if PID cannot be read.""" + try: + with open(cls._get_lock_filename(path), 'r') as f: + return int(f.read()) + except FileNotFoundError: + return -1 + except ValueError: + return 0 + + @classmethod + def verify_lock(cls, path: str): + locked_by_pid = cls._get_locking_pid(path) + if locked_by_pid >= 0: raise RetryableStorageError( - "File is currently in use (locked by pid {}). " - "If this is a leftover from a crashed instance " - "you need to remove the lock file `{}` manually." . - format(locked_by_pid, lock_filename)) - #FIXME: in python >=3.3 use mode x - with open(self._lock_file, 'w') as f: - f.write(str(os.getpid())) + "File is currently in use (locked by pid {}). " + "If this is a leftover from a crashed instance " + "you need to remove the lock file `{}` manually.". + format(locked_by_pid, cls._get_lock_filename(path)) + ) + + def _create_lock(self): + if not self.read_only: + self._lock_file = self._get_lock_filename(self.path) + try: + with open(self._lock_file, 'x') as f: + f.write(str(os.getpid())) + except FileExistsError: + self._lock_file = None + self.verify_lock(self.path) atexit.register(self.close) def _remove_lock(self): - if self._lock_file: - os.remove(self._lock_file) - self._lock_file = None + if self._lock_file is not None: + try: + os.remove(self._lock_file) + except FileNotFoundError: + pass def close(self): if not self.read_only and self.was_changed(): @@ -338,6 +356,10 @@ def _create_lock(self): def _remove_lock(self): pass + @classmethod + def verify_lock(cls): + pass + def _write_file(self, data): self.file_data = data diff --git a/jmclient/jmclient/wallet_utils.py b/jmclient/jmclient/wallet_utils.py index 1955860cc..50d4db8dc 100644 --- a/jmclient/jmclient/wallet_utils.py +++ b/jmclient/jmclient/wallet_utils.py @@ -593,13 +593,13 @@ def get_addr_status(addr_path, utxos, utxos_enabled, is_new, is_internal): label = wallet_service.get_address_label(addr) timelock = datetime.utcfromtimestamp(0) + timedelta(seconds=path[-1]) - balance = sum([utxodata["value"] for _, utxodata in + balance = sum([utxodata["value"] for _, utxodata in utxos[m].items() if path == utxodata["path"]]) status = timelock.strftime("%Y-%m-%d") + " [" + ( "LOCKED" if datetime.now() < timelock else "UNLOCKED") + "]" status += get_utxo_status_string(utxos[m], utxos_enabled[m], path) - + privkey = "" if showprivkey: privkey = wallet_service.get_wif_path(path) @@ -1532,6 +1532,8 @@ def open_wallet(path, ask_for_password=True, password=None, read_only=False, if ask_for_password and Storage.is_encrypted_storage_file(path): while True: try: + # Verify lock status before trying to open wallet. + Storage.verify_lock(path) # do not try empty password, assume unencrypted on empty password pwd = get_password("Enter passphrase to decrypt wallet: ") or None storage = Storage(path, password=pwd, read_only=read_only) diff --git a/jmclient/test/test_storage.py b/jmclient/test/test_storage.py index 98f626f3c..a8a2925ae 100644 --- a/jmclient/test/test_storage.py +++ b/jmclient/test/test_storage.py @@ -131,3 +131,8 @@ def test_storage_lock(tmpdir): assert s.is_locked() assert s.data == {b'test': b'value'} + # Assert a new lock cannot be created + with pytest.raises(storage.StorageError): + s._create_lock() + pytest.fail("It should not be possible to re-create a lock") + diff --git a/jmclient/test/test_wallet_rpc.py b/jmclient/test/test_wallet_rpc.py index 44e05a898..b14f746ee 100644 --- a/jmclient/test/test_wallet_rpc.py +++ b/jmclient/test/test_wallet_rpc.py @@ -12,9 +12,16 @@ from jmbase import get_nontor_agent, hextobin, BytesProducer, get_log from jmbase.support import get_free_tcp_ports from jmbitcoin import CTransaction -from jmclient import (load_test_config, jm_single, SegwitWalletFidelityBonds, - JMWalletDaemon, validate_address, start_reactor, - SegwitWallet) +from jmclient import ( + load_test_config, + jm_single, + SegwitWalletFidelityBonds, + JMWalletDaemon, + validate_address, + start_reactor, + SegwitWallet, + storage, +) from jmclient.wallet_rpc import api_version_string, CJ_MAKER_RUNNING, CJ_NOT_RUNNING from commontest import make_wallets from test_coinjoin import make_wallets_to_list, sync_wallets @@ -105,6 +112,11 @@ def clean_out_wallet_files(self): if os.path.exists(wfn): os.remove(wfn) + parent, name = os.path.split(wfn) + lockfile = os.path.join(parent, f".{name}.lock") + if os.path.exists(lockfile): + os.remove(lockfile) + def get_wallet_file_name(self, i, fullpath=False): tfn = testfilename + str(i) + ".jmdat" if fullpath: @@ -413,6 +425,38 @@ def test_create_list_lock_unlock(self): yield self.do_request(agent, b"POST", addr, body, self.process_unlock_response) + @defer.inlineCallbacks + def test_unlock_locked(self): + """Assert if unlocking a wallet locked by another process fails.""" + self.clean_out_wallet_files() + self.daemon.services["wallet"] = None + self.daemon.stopService() + self.daemon.auth_disabled = False + + wfn = self.get_wallet_file_name(1) + self.wfnames = [wfn] + + agent = get_nontor_agent() + root = self.get_route_root() + + # Create first + p = self.get_wallet_file_name(1, True) + pw = "None" + + s = storage.Storage(p, bytes(pw, "utf-8"), create=True) + assert s.is_locked() + + # Unlocking a locked wallet should fail + + addr = root + "/wallet/" + wfn + "/unlock" + addr = addr.encode() + body = BytesProducer(json.dumps({"password": pw}).encode()) + yield self.do_request( + agent, b"POST", addr, body, self.process_failed_unlock_response + ) + + s.close() + def process_create_wallet_response(self, response, code): assert code == 201 json_body = json.loads(response.decode("utf-8")) @@ -610,6 +654,9 @@ def process_unlock_response(self, response, code): assert json_body["walletname"] in self.wfnames self.jwt_token = json_body["token"] + def process_failed_unlock_response(self, response, code): + assert code == 409 + def process_lock_response(self, response, code): assert code == 200 json_body = json.loads(response.decode("utf-8")) diff --git a/scripts/joinmarket-qt.py b/scripts/joinmarket-qt.py index 53f069aac..f9151b7ad 100755 --- a/scripts/joinmarket-qt.py +++ b/scripts/joinmarket-qt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 from future.utils import iteritems +from typing import Optional ''' Joinmarket GUI using PyQt for doing coinjoins. @@ -74,7 +75,7 @@ detect_script_type, general_custom_change_warning, \ nonwallet_custom_change_warning, sweep_custom_change_warning, EngineError,\ TYPE_P2WPKH, check_and_start_tor, is_extended_public_key, \ - ScheduleGenerationErrorNoFunds + ScheduleGenerationErrorNoFunds, Storage from jmclient.wallet import BaseWallet from qtsupport import ScheduleWizard, TumbleRestartWizard, config_tips,\ @@ -111,7 +112,11 @@ def update_config_for_gui(): from jmqtui import Ui_OpenWalletDialog + + class JMOpenWalletDialog(QDialog, Ui_OpenWalletDialog): + DEFAULT_WALLET_FILE_TEXT = "wallet.jmdat" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.setupUi(self) @@ -119,15 +124,35 @@ def __init__(self, *args, **kwargs): self.chooseWalletButton.clicked.connect(self.chooseWalletFile) - def chooseWalletFile(self): - wallets_path = os.path.join(jm_single().datadir, 'wallets') - (filename, _) = QFileDialog.getOpenFileName(self, - 'Choose Wallet File', - wallets_path, - options=QFileDialog.DontUseNativeDialog) + def chooseWalletFile(self, error_text: str = ""): + (filename, _) = QFileDialog.getOpenFileName( + self, + "Choose Wallet File", + self._get_wallets_path(), + options=QFileDialog.DontUseNativeDialog, + ) if filename: self.walletFileEdit.setText(filename) self.passphraseEdit.setFocus() + self.errorMessageLabel.setText(self.verify_lock(filename)) + + @staticmethod + def _get_wallets_path() -> str: + """Return wallets path""" + return os.path.join(jm_single().datadir, "wallets") + + @classmethod + def verify_lock(cls, filename: Optional[str] = None) -> str: + """Return an error text if wallet is locked, empty string otherwise""" + if filename is None: + filename = os.path.join( + cls._get_wallets_path(), cls.DEFAULT_WALLET_FILE_TEXT + ) + try: + Storage.verify_lock(filename) + return "" + except Exception as e: + return str(e) class HelpLabel(QLabel): @@ -1991,13 +2016,14 @@ def recoverWallet(self): def openWallet(self): wallet_loaded = False - wallet_file_text = "wallet.jmdat" error_text = "" while not wallet_loaded: openWalletDialog = JMOpenWalletDialog() - openWalletDialog.walletFileEdit.setText(wallet_file_text) - openWalletDialog.errorMessageLabel.setText(error_text) + # Set default wallet file name and verify its lock status + openWalletDialog.walletFileEdit.setText(openWalletDialog.DEFAULT_WALLET_FILE_TEXT) + openWalletDialog.errorMessageLabel.setText(openWalletDialog.verify_lock()) + if openWalletDialog.exec_() == QDialog.Accepted: wallet_file_text = openWalletDialog.walletFileEdit.text() wallet_path = wallet_file_text