Skip to content

Updates and tests #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion pycardano/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from fractions import Fraction
from typing import Dict, List, Union
from typing import Dict, List, Union, Optional

from pycardano.address import Address
from pycardano.exception import InvalidArgumentException
Expand Down Expand Up @@ -110,6 +110,19 @@ class ProtocolParameters:
The value will be a dict of cost model parameters."""


@dataclass(frozen=True)
class StakeAddressInfo:
"""The current delegation and reward account for a stake address"""

address: str

delegation: str

reward_account_balance: int

delegation_deposit: Optional[int] = None


@typechecked
class ChainContext:
"""Interfaces through which the library interacts with Cardano blockchain."""
Expand Down Expand Up @@ -139,6 +152,30 @@ def last_block_slot(self) -> int:
"""Slot number of last block"""
raise NotImplementedError()

def stake_address_info(
self, address: Union[str, Address]
) -> List[StakeAddressInfo]:
"""Get the current delegation and reward account for a stake address.

Args:
address (Union[str, Address]): An address, potentially bech32 encoded

Returns:
List[StakeAddressInfo]: A list of StakeAddressInfo objects
"""
return self._stake_address_info(str(address))

def _stake_address_info(self, address: str) -> List[StakeAddressInfo]:
"""Get the current delegation and reward account for a stake address.

Args:
address (str): An address encoded with bech32.

Returns:
List[StakeAddressInfo]: A list of StakeAddressInfo objects
"""
raise NotImplementedError()

def utxos(self, address: Union[str, Address]) -> List[UTxO]:
"""Get all UTxOs associated with an address.

Expand Down
20 changes: 20 additions & 0 deletions pycardano/backend/blockfrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ChainContext,
GenesisParameters,
ProtocolParameters,
StakeAddressInfo,
)
from pycardano.exception import TransactionFailedException
from pycardano.hash import SCRIPT_HASH_SIZE, DatumHash, ScriptHash
Expand Down Expand Up @@ -305,3 +306,22 @@ def evaluate_tx_cbor(self, cbor: Union[bytes, str]) -> Dict[str, ExecutionUnits]
getattr(result.EvaluationResult, k).steps,
)
return return_val

def _stake_address_info(self, address: str) -> List[StakeAddressInfo]:
"""Get the current delegation and reward account for a stake address.

Args:
address (str): An address encoded with bech32.

Returns:
List[StakeAddressInfo]: A list of StakeAddressInfo objects
"""
results = self.api.accounts(address).to_dict()

return [
StakeAddressInfo(
address=results.get("stake_address"),
delegation=results.get("pool_id"),
reward_account_balance=results.get("withdrawable_amount"),
)
]
34 changes: 34 additions & 0 deletions pycardano/backend/cardano_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ChainContext,
GenesisParameters,
ProtocolParameters,
StakeAddressInfo,
)
from pycardano.exception import (
CardanoCliError,
Expand Down Expand Up @@ -530,3 +531,36 @@ def submit_tx_cbor(self, cbor: Union[bytes, str]) -> str:
) from err

return txid

def _stake_address_info(self, address: str) -> List[StakeAddressInfo]:
"""Get the current delegation and reward account for a stake address.

Args:
address (str): An address encoded with bech32.

Returns:
List[StakeAddressInfo]: A list of StakeAddressInfo objects
"""
results = self._run_command(
[
"query",
"stake-address-info",
"--address",
address,
"--out-file",
"/dev/stdout",
]
+ self._network_args
)

result_json = json.loads(results)

return [
StakeAddressInfo(
address=stake_info["address"],
delegation=stake_info["delegation"],
delegation_deposit=stake_info["delegationDeposit"],
reward_account_balance=stake_info["rewardAccountBalance"],
)
for stake_info in result_json
]
73 changes: 34 additions & 39 deletions pycardano/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
from typing_extensions import Literal

from pycardano.address import Address, PointerAddress
from pycardano.backend.base import ChainContext
from pycardano.backend.base import ChainContext, StakeAddressInfo
from pycardano.backend.blockfrost import BlockFrostChainContext
from pycardano.certificate import (
StakeCredential,
StakeDelegation,
StakeDeregistration,
StakeRegistration,
Certificate,
)
from pycardano.cip.cip8 import sign
from pycardano.exception import PyCardanoException
from pycardano.exception import PyCardanoException, CardanoCliError
from pycardano.hash import PoolKeyHash, ScriptHash, TransactionId
from pycardano.key import (
ExtendedSigningKey,
Expand Down Expand Up @@ -668,10 +668,11 @@ class Wallet:
ada: Optional[Ada] = field(repr=True, default=Ada(0))
signing_key: Optional[SigningKey] = field(repr=False, default=None)
verification_key: Optional[VerificationKey] = field(repr=False, default=None)
uxtos: Optional[list] = field(repr=False, default_factory=list)
context: Optional[BlockFrostChainContext] = field(repr=False, default=None)
uxtos: Optional[List[UTxO]] = None
context: Optional[ChainContext] = field(repr=False, default=None)

def __post_init__(self):
self.utxos = []
# convert address into pycardano format
if isinstance(self.address, str):
self.address = Address.from_primitive(self.address)
Expand Down Expand Up @@ -763,20 +764,18 @@ def stake_info(self):
@property
def pool_id(self):
account_info = get_stake_info(self.stake_address, self.context)
if account_info.get("pool_id"):
return account_info.get("pool_id")
else:
logger.warn("Stake address is not registered yet.")
return None
if len(account_info):
return account_info[0].delegation
logger.warn("Stake address is not registered yet.")
return None

@property
def withdrawable_amount(self):
account_info = get_stake_info(self.stake_address, self.context)
if account_info.get("withdrawable_amount"):
return Lovelace(int(account_info.get("withdrawable_amount")))
else:
logger.warn("Stake address is not registered yet.")
return Lovelace(0)
if len(account_info):
return Lovelace(int(account_info[0].reward_account_balance))
logger.warn("Stake address is not registered yet.")
return Lovelace(0)

def _load_or_create_key_pair(self, stake=True):
"""Look for a key pair in the keys directory. If not found, create a new key pair."""
Expand Down Expand Up @@ -918,12 +917,11 @@ def sync(self, context: Optional[ChainContext] = None):
logger.warning(
f"Error getting UTxOs. Address has likely not transacted yet. Details: {e}"
)
self.utxos = []

# calculate total ada
if self.utxos:
self.lovelace = Lovelace(
sum([utxo.output.amount.coin for utxo in self.utxos])
sum(utxo.output.amount.coin for utxo in self.utxos)
)
self.ada = self.lovelace.as_ada()

Expand All @@ -937,8 +935,8 @@ def sync(self, context: Optional[ChainContext] = None):
else:
logger.info(f"Wallet {self.name} has no UTxOs.")

self.lovelace = Lovelace(0)
self.ada = Ada(0)
self.lovelace = Lovelace()
self.ada = Ada()

def to_address(self):
return Address(
Expand Down Expand Up @@ -1145,7 +1143,7 @@ def delegate(
inputs = [self]

# check registration, do not register if already registered
active = self.stake_info.get("active")
active = len(self.stake_info) > 0
if register:
register = not active
elif not active:
Expand Down Expand Up @@ -1195,7 +1193,7 @@ def withdraw_rewards(
return self.transact(
inputs=[self],
outputs=Output(self, output_amount),
withdrawals={self: withdrawal_amount},
withdrawals={str(self.stake_address): withdrawal_amount},
**kwargs,
)

Expand Down Expand Up @@ -1241,11 +1239,9 @@ def mint_tokens(
inputs = utxos
else:
inputs = [self]

# if token amounts are negative, remove them from the outputs
if not isinstance(mints, list):
mints = [mints]

tokens = []
for mint in mints:
if mint.amount > 0:
Expand Down Expand Up @@ -1521,9 +1517,7 @@ def transact(
auxiliary_data = AuxiliaryData(Metadata())

# create stake_registrations, delegations
certificates: List[
Union[StakeDelegation, StakeRegistration, StakeDeregistration]
] = []
certificates: List[Certificate] = []
if stake_registration_info: # add registrations
if isinstance(stake_registration_info, bool):
# register current wallet
Expand Down Expand Up @@ -1589,9 +1583,9 @@ def transact(
# withdrawals
withdraw = {}
if withdrawals and isinstance(withdrawals, bool): # withdraw current wallet
withdraw[
self.stake_address.to_primitive()
] = self.withdrawable_amount.lovelace
withdraw[self.stake_address.to_primitive()] = (
self.withdrawable_amount.lovelace
)
if self.stake_signing_key not in signers_list:
signers_list.append(self.stake_signing_key)
elif isinstance(withdrawals, dict):
Expand All @@ -1611,7 +1605,7 @@ def transact(
"Withdraw all is only supported with BlockFrostChainContext at the moment."
)
account_info = get_stake_info(stake_address, tx_context)
withdrawable_amount = account_info.get("withdrawable_amount")
withdrawable_amount = account_info[0].reward_account_balance

if withdrawable_amount:
if isinstance(withdrawable_amount, (int, float)):
Expand All @@ -1638,7 +1632,7 @@ def transact(
if not stake_address.staking_part:
raise ValueError(f"Stake Address {stake_address} is invalid.")

withdraw[stake_address.staking_part.to_primitive()] = (
withdraw[str(stake_address)] = (
withdrawal_amount.as_lovelace().amount
if isinstance(withdrawal_amount, Lovelace)
else withdrawal_amount.lovelace
Expand Down Expand Up @@ -1824,8 +1818,8 @@ def get_utxo_block_time(utxo: UTxO, context: BlockFrostChainContext) -> int:


def get_stake_info(
stake_address: Union[str, Address], context: BlockFrostChainContext
) -> dict:
stake_address: Union[str, Address], context: ChainContext
) -> List[StakeAddressInfo]:
"""Get the stake info of a stake address from Blockfrost.
For more info see: https://docs.blockfrost.io/#tag/Cardano-Accounts/paths/~1accounts~1{stake_address}/get

Expand All @@ -1834,22 +1828,23 @@ def get_stake_info(
context (ChainContext): The context to use for the query. For now must be BlockFrost.

Returns:
dict: Info regarding the given stake address.
List[StakeAddressInfo]: A list of StakeAddressInfo objects
"""

if isinstance(stake_address, str):
stake_address = Address.from_primitive(stake_address)

if not type(stake_address) == Address:
if not isinstance(stake_address, Address):
raise TypeError(f"Address {stake_address} is not a valid stake address.")

if not stake_address.staking_part:
raise TypeError(f"Address {stake_address} has no staking part.")

try:
return context.api.accounts(str(stake_address)).to_dict()
except ApiError:
return {}
# return context.api.accounts(str(stake_address)).to_dict()
return context.stake_address_info(str(stake_address))
except (ApiError, CardanoCliError):
return []


def get_stake_address(address: Union[str, Address]) -> Address:
Expand All @@ -1864,7 +1859,7 @@ def get_stake_address(address: Union[str, Address]) -> Address:
if isinstance(address, str):
address = Address.from_primitive(address)

if not type(address) == Address:
if not isinstance(address, Address):
raise TypeError(f"Address {address} is not a valid address.")

return Address.from_primitive(
Expand Down
Loading