Skip to content

Commit 21ee7a8

Browse files
committed
Add TokenPolicy class
1 parent ae84ee0 commit 21ee7a8

File tree

1 file changed

+135
-1
lines changed

1 file changed

+135
-1
lines changed

pycardano/easy.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,143 @@ def ad_ada(self):
240240
return Ada(self.ada)
241241

242242

243+
@dataclass(unsafe_hash=True)
244+
class TokenPolicy:
245+
name: str
246+
policy: Optional[Union[NativeScript, dict]] = field(repr=False, default=None)
247+
policy_dir: Optional[Union[str, Path]] = field(
248+
repr=False, default=Path("./priv/policies")
249+
)
250+
251+
def __post_init__(self):
252+
253+
# streamline inputs
254+
if isinstance(self.policy_dir, str):
255+
self.policy_dir = Path(self.policy_dir)
256+
257+
if not self.policy_dir.exists():
258+
self.policy_dir.mkdir(parents=True, exist_ok=True)
259+
260+
# look for the policy
261+
if Path(self.policy_dir / f"{self.name}.script").exists():
262+
with open(
263+
Path(self.policy_dir / f"{self.name}.script"), "r"
264+
) as policy_file:
265+
self.policy = NativeScript.from_dict(json.load(policy_file))
266+
267+
elif isinstance(self.policy, dict):
268+
self.policy = NativeScript.from_dict(self.policy)
269+
270+
@property
271+
def policy_id(self):
272+
273+
if self.policy:
274+
return str(self.policy.hash())
275+
276+
@property
277+
def expiration_slot(self):
278+
"""Get the expiration slot for a simple minting policy,
279+
like one generated by generate_minting_policy
280+
"""
281+
282+
if self.policy:
283+
scripts = getattr(self.policy, "native_scripts", None)
284+
285+
if scripts:
286+
for script in scripts:
287+
if script._TYPE == 5:
288+
return script.after
289+
290+
def get_expiration_timestamp(self, context: ChainContext):
291+
"""Get the expiration timestamp for a simple minting policy,
292+
like one generated by generate_minting_policy
293+
"""
294+
295+
if self.expiration_slot:
296+
297+
seconds_diff = self.expiration_slot - context.last_block_slot
298+
299+
return datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
300+
seconds=seconds_diff
301+
)
302+
303+
def is_expired(self, context: ChainContext):
304+
"""Get the expiration timestamp for a simple minting policy,
305+
like one generated by generate_minting_policy
306+
"""
307+
308+
if self.expiration_slot:
309+
310+
seconds_diff = self.expiration_slot - context.last_block_slot
311+
312+
return seconds_diff < 0
313+
314+
def generate_minting_policy(
315+
self,
316+
signers: Union["Wallet", Address, List["Wallet"], List[Address]],
317+
expiration: Optional[Union[datetime.datetime, int]] = None,
318+
context: Optional[ChainContext] = None,
319+
):
320+
321+
script_filepath = Path(self.policy_dir / f"{self.name}.script")
322+
323+
if script_filepath.exists() or self.policy:
324+
raise FileExistsError(f"Policy named {self.name} already exists")
325+
326+
if isinstance(expiration, datetime.datetime) and not context:
327+
raise AttributeError(
328+
"If input expiration is provided as a datetime, please also provide a context."
329+
)
330+
331+
# get pub key hashes
332+
if not isinstance(signers, list):
333+
signers = [signers]
334+
335+
pub_keys = [ScriptPubkey(self._get_pub_key_hash(signer)) for signer in signers]
336+
337+
# calculate when to lock
338+
if expiration:
339+
if isinstance(expiration, int): # assume this is directly the block no.
340+
must_before_slot = InvalidHereAfter(expiration)
341+
elif isinstance(expiration, datetime.datetime):
342+
if expiration.tzinfo:
343+
time_until_expiration = expiration - datetime.datetime.now(
344+
datetime.datetime.utc
345+
)
346+
else:
347+
time_until_expiration = expiration - datetime.datetime.now()
348+
349+
last_block_slot = context.last_block_slot
350+
351+
must_before_slot = InvalidHereAfter(
352+
last_block_slot + int(time_until_expiration.total_seconds())
353+
)
354+
355+
policy = ScriptAll(pub_keys + [must_before_slot])
356+
357+
else:
358+
policy = ScriptAll(pub_keys)
359+
360+
# save policy to file
361+
with open(script_filepath, "w") as script_file:
362+
json.dump(policy.to_dict(), script_file, indent=4)
363+
364+
self.policy = policy
365+
366+
@staticmethod
367+
def _get_pub_key_hash(signer=Union["Wallet", Address]):
368+
369+
if hasattr(signer, "verification_key"):
370+
return signer.verification_key.hash()
371+
elif isinstance(signer, Address):
372+
return str(signer.payment_part)
373+
else:
374+
raise TypeError("Input signer must be of type Wallet or Address.")
375+
376+
243377
@dataclass(unsafe_hash=True)
244378
class Token:
245-
policy: Union[NativeScript, str]
379+
policy: Union[NativeScript, TokenPolicy]
246380
amount: int
247381
name: Optional[str] = field(default="")
248382
hex_name: Optional[str] = field(default="")

0 commit comments

Comments
 (0)