Skip to content
This repository was archived by the owner on Nov 15, 2021. It is now read-only.

[refactor-prompt] Migrate command: wallet import contract_addr #777

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions neo/Prompt/Commands/LoadSmartContract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,41 @@
from neocore.BigInteger import BigInteger


def ImportContractAddr(wallet, args):
if wallet is None:
print("please open a wallet")
def ImportContractAddr(wallet, contract_hash, pubkey_script_hash):
"""
Args:
wallet (Wallet): a UserWallet instance
contract_hash (UInt160): hash of the contract to import
pubkey_script_hash (UInt160):

Returns:
neo.SmartContract.Contract.Contract
"""

contract = Blockchain.Default().GetContract(contract_hash)
if not contract or not pubkey_script_hash:
print("Could not find contract")
return

contract_hash = get_arg(args, 0)
pubkey = get_arg(args, 1)
reedeem_script = contract.Code.Script.hex()

if contract_hash and pubkey:
# there has to be at least 1 param, and the first one needs to be a signature param
param_list = bytearray(b'\x00')

if len(pubkey) != 66:
print("invalid public key format")
# if there's more than one param
# we set the first parameter to be the signature param
if len(contract.Code.ParameterList) > 1:
param_list = bytearray(contract.Code.ParameterList)
param_list[0] = 0

pubkey_script_hash = Crypto.ToScriptHash(pubkey, unhex=True)
verification_contract = Contract.Create(reedeem_script, param_list, pubkey_script_hash)

contract = Blockchain.Default().GetContract(contract_hash)
address = verification_contract.Address

if contract is not None:
wallet.AddContract(verification_contract)

reedeem_script = contract.Code.Script.hex()

# there has to be at least 1 param, and the first
# one needs to be a signature param
param_list = bytearray(b'\x00')

# if there's more than one param
# we set the first parameter to be the signature param
if len(contract.Code.ParameterList) > 1:
param_list = bytearray(contract.Code.ParameterList)
param_list[0] = 0

verification_contract = Contract.Create(reedeem_script, param_list, pubkey_script_hash)

address = verification_contract.Address

wallet.AddContract(verification_contract)

print("Added contract addres %s to wallet" % address)
return

print("Could not add contract. Invalid public key or contract address")
print(f"Added contract address {address} to wallet")
return verification_contract


def LoadContract(args):
Expand Down
61 changes: 48 additions & 13 deletions neo/Prompt/Commands/Wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,25 @@
from neo.Prompt.PromptData import PromptData
from neo.Prompt.Commands.Send import CommandWalletSend, CommandWalletSendMany, CommandWalletSign
from neo.Prompt.Commands.Tokens import CommandWalletToken
from neo.Prompt.Commands.LoadSmartContract import ImportContractAddr
from neo.logging import log_manager
from neocore.Utils import isValidPublicAddress

logger = log_manager.getLogger()


def _is_valid_public_key(key):
if len(key) != 66:
return False
try:
Crypto.ToScriptHash(key, unhex=True)
except Exception:
# the UINT160 inside ToScriptHash can throw Exception
return False
else:
return True


class CommandWallet(CommandBase):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -425,6 +438,7 @@ def __init__(self):
self.register_sub_command(CommandWalletImportWatchAddr())
self.register_sub_command(CommandWalletImportMultisigAddr())
self.register_sub_command(CommandWalletImportToken())
self.register_sub_command(CommandWalletImportContractAddr())

def command_desc(self):
return CommandDesc('import', 'import wallet items')
Expand Down Expand Up @@ -617,17 +631,6 @@ class CommandWalletImportMultisigAddr(CommandBase):
def __init__(self):
super().__init__()

def _is_valid_public_key(self, key):
if len(key) != 66:
return False
try:
Crypto.ToScriptHash(key, unhex=True)
except Exception:
# the UINT160 inside ToScriptHash can throw Exception
return False
else:
return True

def execute(self, arguments):
wallet = PromptData.Wallet

Expand All @@ -636,7 +639,7 @@ def execute(self, arguments):
return False

pubkey_in_wallet = arguments[0]
if not self._is_valid_public_key(pubkey_in_wallet):
if not _is_valid_public_key(pubkey_in_wallet):
print("Invalid public key format")
return False

Expand Down Expand Up @@ -666,7 +669,7 @@ def execute(self, arguments):

# validate remaining pub keys
for key in signing_keys:
if not self._is_valid_public_key(key):
if not _is_valid_public_key(key):
print(f"Invalid signing key {key}")
return False

Expand Down Expand Up @@ -712,6 +715,38 @@ def command_desc(self):
return CommandDesc('token', 'import a token', [p1])


class CommandWalletImportContractAddr(CommandBase):
def __init__(self):
super().__init__()

def execute(self, arguments):
wallet = PromptData.Wallet

if len(arguments) != 2:
print("Please specify the required parameters")
return

try:
contract_hash = UInt160.ParseString(arguments[0]).ToBytes()
except Exception:
print(f"Invalid contract hash: {arguments[0]}")
return

pubkey = arguments[1]
if not _is_valid_public_key(pubkey):
print(f"Invalid pubkey: {arguments[1]}")
return

pubkey_script_hash = Crypto.ToScriptHash(pubkey, unhex=True)

return ImportContractAddr(wallet, contract_hash, pubkey_script_hash)

def command_desc(self):
p1 = ParameterDesc('contract_hash', 'hash of the contract')
p2 = ParameterDesc('pubkey', 'pubkey of the contract')
return CommandDesc('contract_addr', 'import a contract address', [p1, p2])


#########################################################################
#########################################################################

Expand Down
59 changes: 59 additions & 0 deletions neo/Prompt/Commands/tests/test_wallet_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from neocore.UInt160 import UInt160
from neocore.Fixed8 import Fixed8
from neo.Core.TX.ClaimTransaction import ClaimTransaction
from neo.SmartContract.Contract import Contract
from neo.Core.TX.Transaction import ContractTransaction
from neo.Prompt.Commands.BuildNRun import BuildAndRun
from neo.Prompt.Commands.Wallet import CommandWallet
from neo.Prompt.Commands.Wallet import CreateAddress, DeleteAddress, ImportToken, ShowUnspentCoins, SplitUnspentCoin
from neo.Prompt.PromptData import PromptData
Expand Down Expand Up @@ -828,6 +830,63 @@ def test_wallet_import_token(self):
self.assertEqual(token.decimals, 8)
self.assertEqual(token.Address, 'Ab61S1rk2VtCVd3NtGNphmBckWk4cfBdmB')

def test_wallet_import_contract_addr(self):
# test with no wallet open
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr', 'contract_hash', 'pubkey']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("open a wallet", mock_print.getvalue())

self.OpenWallet1()

# test with not enough arguments (must have 2 arguments)
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("specify the required parameters", mock_print.getvalue())

# test with too many arguments (must have 2 arguments)
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr', 'arg1', 'arg2', 'arg3']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("specify the required parameters", mock_print.getvalue())

# test with invalid contract hash
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr', 'invalid_contract_hash', '03cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("Invalid contract hash", mock_print.getvalue())

# test with valid contract hash but that doesn't exist
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr', '31730cc9a1844891a3bafd1aa929000000000000', '03cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("Could not find contract", mock_print.getvalue())

# test with invalid pubkey
with patch('sys.stdout', new=StringIO()) as mock_print:
args = ['import', 'contract_addr', '31730cc9a1844891a3bafd1aa929a4142860d8d3', 'invalid_pubkey']
res = CommandWallet().execute(args)
self.assertIsNone(res)
self.assertIn("Invalid pubkey", mock_print.getvalue())

# test with valid arguments
contract_hash = UInt160.ParseString('31730cc9a1844891a3bafd1aa929a4142860d8d3')

with patch('sys.stdout', new=StringIO()) as mock_print:
self.assertIsNone(PromptData.Wallet.GetContract(contract_hash))

args = ['import', 'contract_addr', '31730cc9a1844891a3bafd1aa929a4142860d8d3', '03cbb45da6072c14761c9da545749d9cfd863f860c351066d16df480602a2024c6']
res = CommandWallet().execute(args)
self.assertIsInstance(res, Contract)
self.assertTrue(PromptData.Wallet.GetContract(contract_hash))
self.assertIn("Added contract address", mock_print.getvalue())

##########################################################
##########################################################

Expand Down