-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
35dab56
commit 80ceb8c
Showing
7 changed files
with
147 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from pyteal import Expr, For, Int, Seq, Subroutine, TealType | ||
|
||
from pytealext import AutoLoadScratchVar | ||
|
||
|
||
@Subroutine(TealType.uint64) | ||
def sum_of_integers_in_range(start: Expr, end: Expr) -> Expr: | ||
"""Calculate the sum of integers in the range [start, end)""" | ||
i = AutoLoadScratchVar(TealType.uint64) | ||
s = AutoLoadScratchVar(TealType.uint64) | ||
|
||
return Seq( | ||
s.store(Int(0)), | ||
For(i.store(start), i < end, i.increment()).Do( | ||
# with regular scratch vars, this would be: | ||
# s.store(s.load() + i.load()) | ||
s.increment(i) | ||
), | ||
s.load(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from pyteal import Bytes, CompileOptions, Expr, Int, ScratchVar, TealBlock, TealSimpleBlock, TealType | ||
|
||
|
||
class AutoLoadScratchVar(ScratchVar, Expr): | ||
"""Makes ScratchVars more convenient to use. | ||
When used inline with another expression, a load is performed automatically | ||
without the need for explicit load() calls. | ||
Store still has to be called manually. | ||
Example: | ||
``` | ||
s = AutoLoadScratchVar(TealType.uint64) | ||
sum = AutoLoadScratchVar(TealType.uint64) | ||
program = For(s.store(Int(0)), s < Int(10), s.increment()).Do( | ||
sum.increment(s) | ||
) | ||
``` | ||
""" | ||
|
||
def __init__(self, type: TealType = TealType.anytype, slotId: int = None): # pylint: disable=redefined-builtin | ||
ScratchVar.__init__(self, type, slotId) | ||
Expr.__init__(self) | ||
|
||
def store(self, value: Expr | int | str | bytes) -> Expr: | ||
match value: | ||
case int(v): | ||
value = Int(v) | ||
case str(v) | bytes(v): | ||
value = Bytes(v) # type: ignore | ||
case Expr(): | ||
pass | ||
case _: | ||
raise TypeError(f"Invalid type for ScratchVarPro.store: {type(value)}") | ||
# superclass's store will check for correct stack type | ||
return ScratchVar.store(self, value) # type: ignore | ||
|
||
def increment(self, value: Expr | int = 1) -> Expr: | ||
"""Increase the value in the scratch space by the given value (1 by default)""" | ||
if isinstance(value, int): | ||
value = Int(value) | ||
return self.store(self.load() + value) | ||
|
||
def decrement(self, value: Expr | int = 1) -> Expr: | ||
"""Decrease the value in the scratch space by the given value (1 by default)""" | ||
if isinstance(value, int): | ||
value = Int(value) | ||
return self.store(self.load() - value) | ||
|
||
def type_of(self) -> TealType: | ||
return self.type | ||
|
||
def has_return(self) -> bool: | ||
return False | ||
|
||
def __str__(self) -> str: | ||
return self.load().__str__() | ||
|
||
def __teal__(self, options: CompileOptions) -> tuple[TealBlock, TealSimpleBlock]: | ||
return self.load().__teal__(options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
pyteal>=0.15.0 | ||
pyteal>=0.20.0, <1.0.0 | ||
py-algorand-sdk >= 1.15.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import pytest | ||
from pyteal import Bytes, For, Int, Seq, TealType | ||
|
||
from examples.auto_load_scratch_var import sum_of_integers_in_range | ||
from pytealext import AutoLoadScratchVar | ||
from pytealext.evaluator import compile_and_run | ||
|
||
|
||
def test_AutoLoadScratchVar(): | ||
i = AutoLoadScratchVar(TealType.uint64) | ||
s = AutoLoadScratchVar(TealType.uint64) | ||
|
||
program = sum_of_integers_in_range(Int(0), Int(10)) | ||
|
||
stack, _ = compile_and_run(program) | ||
assert stack[0] == 45 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"source,expected", | ||
[(Int(1123), 1123), (Bytes("hello"), b"hello"), (123123, 123123), (b"hello\x00", b"hello\x00"), ("'sup", b"'sup")], | ||
) | ||
def test_store_different_types(source, expected): | ||
s = AutoLoadScratchVar(slotId=10) | ||
program = Seq(s.store(source), Int(1)) | ||
|
||
_, slots = compile_and_run(program) | ||
assert slots[10] == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"amt", | ||
[ | ||
0, | ||
1, | ||
500, | ||
1000, | ||
], | ||
) | ||
def test_increment(amt: int): | ||
s = AutoLoadScratchVar() | ||
program = Seq(s.store(1000), s.increment(amt), s.load()) | ||
|
||
stack, _ = compile_and_run(program) | ||
assert stack[0] == 1000 + amt | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"amt", | ||
[ | ||
0, | ||
1, | ||
500, | ||
1000, | ||
], | ||
) | ||
def test_decrement(amt: int): | ||
s = AutoLoadScratchVar() | ||
program = Seq(s.store(1000), s.decrement(amt), s.load()) | ||
|
||
stack, _ = compile_and_run(program) | ||
assert stack[0] == 1000 - amt |