Skip to content

Commit bc2d04d

Browse files
committed
refactor: review feedback
1 parent 9df6223 commit bc2d04d

File tree

3 files changed

+76
-34
lines changed

3 files changed

+76
-34
lines changed

src/_algopy_testing/primitives/fixed_bytes.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import typing
77

88
if typing.TYPE_CHECKING:
9-
from collections.abc import Iterator
9+
from collections.abc import Callable, Iterator
1010

1111
from itertools import zip_longest
1212

@@ -24,18 +24,19 @@
2424
_TBytesLength_Arg = typing.TypeVar("_TBytesLength_Arg", bound=int)
2525

2626

27-
def _create_class_from_type(cls: type, length_t: type) -> type:
27+
def _get_or_create_class_from_type(cls: type, length_t: type) -> type:
2828
_length = get_int_literal_from_type_generic(length_t)
29-
return _create_class(cls, _length, length_t)
29+
return _get_or_create_class(cls, _length, length_t)
3030

3131

32-
def _create_class_from_int(cls: type, length: int) -> type:
32+
def _get_or_create_class_from_int(cls: type, length: int) -> type:
3333
length_t = get_type_generic_from_int_literal(length)
34-
return _create_class(cls, length, length_t)
34+
return _get_or_create_class(cls, length, length_t)
3535

3636

37-
def _create_class(cls: type, length: int, length_t: type) -> type:
38-
cache = cls.__concrete__ if hasattr(cls, "__concrete__") else {}
37+
def _get_or_create_class(cls: type, length: int, length_t: type) -> type:
38+
"""Get or create a type that is parametrized with element_t and length."""
39+
cache = getattr(cls, "__concrete__", {})
3940
if c := cache.get(length_t, None):
4041
assert isinstance(c, type)
4142
return c
@@ -55,12 +56,11 @@ def _create_class(cls: type, length: int, length_t: type) -> type:
5556
class _FixedBytesMeta(type):
5657
__concrete__: typing.ClassVar[dict[type, type]] = {}
5758

58-
# get or create a type that is parametrized with element_t and length
5959
def __getitem__(cls, length_t: type) -> type:
6060
if length_t == typing.Any:
6161
return cls
6262

63-
return _create_class_from_type(cls, length_t)
63+
return _get_or_create_class_from_type(cls, length_t)
6464

6565

6666
class FixedBytes(
@@ -103,8 +103,6 @@ def __bool__(self) -> bool:
103103
def __len__(self) -> int:
104104
return self._length
105105

106-
# mypy suggests due to Liskov below should be other: object
107-
# need to consider ramifications here, ignoring it for now
108106
def __eq__(self, other: FixedBytes[_TBytesLength_Arg] | Bytes | bytes) -> bool: # type: ignore[override]
109107
"""FixedBytes can be compared using the `==` operator with another FixedBytes,
110108
Bytes or bytes."""
@@ -134,17 +132,18 @@ def __radd__(self, other: FixedBytes[_TBytesLength_Arg] | Bytes | bytes) -> Byte
134132
result = as_bytes(other) + self.value
135133
return _checked_result(result, "+")
136134

135+
def __iadd__(self, _other: Bytes | typing.Self | bytes) -> typing.Self: # type: ignore[misc]
136+
raise TypeError("FixedBytes does not support in-place addition")
137+
137138
@property
138139
def length(self) -> UInt64:
139140
"""Returns the specified length of the FixedBytes."""
140141
return UInt64(self._length)
141142

142-
def __getitem__(
143-
self, index: UInt64 | int | slice
144-
) -> Bytes: # maps to substring/substring3 if slice, extract/extract3 otherwise?
143+
def __getitem__(self, index: UInt64 | int | slice) -> Bytes:
145144
"""Returns a Bytes containing a single byte if indexed with UInt64 or int
146145
otherwise the substring of bytes described by the slice."""
147-
value = self.value[None : self.length]
146+
value = self.value[: self.length]
148147
if isinstance(index, slice):
149148
return Bytes(value[index])
150149
else:
@@ -174,15 +173,15 @@ def __and__(self, other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes) -> typin
174173
175174
Returns FixedBytes if other has the same length, otherwise returns Bytes.
176175
"""
177-
return self._operate_bitwise(other, "and_")
176+
return self._operate_bitwise(other, operator.and_)
178177

179178
def __rand__(self, other: bytes) -> Bytes:
180179
return self & other
181180

182181
def __iand__(self, other: Bytes | typing.Self | bytes) -> typing.Self: # type: ignore[misc]
183182
other_bytes = as_bytes(other)
184183
other_fixed_bytes = self.__class__(other_bytes)
185-
result = self._operate_bitwise(other_fixed_bytes, "and_")
184+
result = self._operate_bitwise(other_fixed_bytes, operator.and_)
186185
assert isinstance(result, self.__class__)
187186
return result
188187

@@ -191,15 +190,15 @@ def __or__(self, other: typing.Self) -> typing.Self: ... # type: ignore[overloa
191190
@typing.overload
192191
def __or__(self, other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes) -> Bytes: ...
193192
def __or__(self, other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes) -> typing.Self | Bytes:
194-
return self._operate_bitwise(other, "or_")
193+
return self._operate_bitwise(other, operator.or_)
195194

196195
def __ror__(self, other: bytes) -> Bytes:
197196
return self | other
198197

199198
def __ior__(self, other: Bytes | typing.Self | bytes) -> typing.Self: # type: ignore[misc]
200199
other_bytes = as_bytes(other)
201200
other_fixed_bytes = self.__class__(other_bytes)
202-
result = self._operate_bitwise(other_fixed_bytes, "or_")
201+
result = self._operate_bitwise(other_fixed_bytes, operator.or_)
203202
assert isinstance(result, self.__class__)
204203
return result
205204

@@ -208,15 +207,15 @@ def __xor__(self, other: typing.Self) -> typing.Self: ... # type: ignore[overlo
208207
@typing.overload
209208
def __xor__(self, other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes) -> Bytes: ...
210209
def __xor__(self, other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes) -> typing.Self | Bytes:
211-
return self._operate_bitwise(other, "xor")
210+
return self._operate_bitwise(other, operator.xor)
212211

213212
def __rxor__(self, other: bytes) -> Bytes:
214213
return self ^ other
215214

216215
def __ixor__(self, other: Bytes | typing.Self | bytes) -> typing.Self: # type: ignore[misc]
217216
other_bytes = as_bytes(other)
218217
other_fixed_bytes = self.__class__(other_bytes)
219-
result = self._operate_bitwise(other_fixed_bytes, "xor")
218+
result = self._operate_bitwise(other_fixed_bytes, operator.xor)
220219
assert isinstance(result, self.__class__)
221220
return result
222221

@@ -231,9 +230,8 @@ def __invert__(self) -> typing.Self:
231230
def _operate_bitwise(
232231
self,
233232
other: FixedBytes[_TBytesLength_Arg] | bytes | Bytes,
234-
operator_name: str,
233+
op: Callable[[int, int], int],
235234
) -> typing.Self | Bytes:
236-
op = getattr(operator, operator_name)
237235
maybe_bytes = as_bytes(other)
238236
# pad the shorter of self.value and other bytes with leading zero
239237
# by reversing them as zip_longest fills at the end
@@ -259,32 +257,46 @@ def from_base32(cls, value: str) -> typing.Self:
259257
"""Creates FixedBytes from a base32 encoded string e.g.
260258
`FixedBytes.from_base32("74======")`"""
261259
bytes_value = base64.b32decode(value)
262-
c = _create_class_from_int(cls, len(bytes_value)) if not hasattr(cls, "_length") else cls
263-
return c(bytes_value) # type: ignore[no-any-return]
260+
c = cls._ensure_class_with_length(bytes_value)
261+
return c(bytes_value)
264262

265263
@classmethod
266264
def from_base64(cls, value: str) -> typing.Self:
267265
"""Creates FixedBytes from a base64 encoded string e.g.
268266
`FixedBytes.from_base64("RkY=")`"""
269267
bytes_value = base64.b64decode(value)
270-
c = _create_class_from_int(cls, len(bytes_value)) if not hasattr(cls, "_length") else cls
271-
return c(bytes_value) # type: ignore[no-any-return]
268+
c = cls._ensure_class_with_length(bytes_value)
269+
return c(bytes_value)
272270

273271
@classmethod
274272
def from_hex(cls, value: str) -> typing.Self:
275-
"""Creates FixedBytes from a hex/octal encoded string e.g. `FixedBytes.from_hex("FF")`"""
273+
"""Creates FixedBytes from a hex/octal encoded string e.g.
274+
`FixedBytes.from_hex("FF")`"""
276275
bytes_value = base64.b16decode(value)
277-
c = _create_class_from_int(cls, len(bytes_value)) if not hasattr(cls, "_length") else cls
278-
return c(bytes_value) # type: ignore[no-any-return]
276+
c = cls._ensure_class_with_length(bytes_value)
277+
return c(bytes_value)
279278

280279
@classmethod
281280
def from_bytes(cls, value: Bytes | bytes) -> typing.Self:
282281
"""Construct an instance from the underlying bytes (no validation)"""
283282
bytes_value = as_bytes(value)
284-
c = _create_class_from_int(cls, len(bytes_value)) if not hasattr(cls, "_length") else cls
283+
c = cls._ensure_class_with_length(bytes_value)
285284
result = c()
286285
result.value = bytes_value
287-
return result # type: ignore[no-any-return]
286+
return result
287+
288+
@classmethod
289+
def _ensure_class_with_length(cls, bytes_value: bytes) -> type[typing.Self]:
290+
"""Returns the appropriate class for the given bytes value.
291+
292+
If cls has a fixed _length, returns cls. Otherwise, returns or creates a
293+
specialized class with the length set to match bytes_value.
294+
"""
295+
return (
296+
_get_or_create_class_from_int(cls, len(bytes_value))
297+
if not hasattr(cls, "_length")
298+
else cls
299+
)
288300

289301
@property
290302
def bytes(self) -> Bytes:

tests/models/test_box.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from _algopy_testing.primitives.biguint import BigUInt
2222
from _algopy_testing.primitives.bytes import Bytes
23+
from _algopy_testing.primitives.fixed_bytes import FixedBytes
2324
from _algopy_testing.primitives.string import String
2425
from _algopy_testing.primitives.uint64 import UInt64
2526
from _algopy_testing.state.box import Box
@@ -387,7 +388,7 @@ class Swapped2(Struct):
387388
b: Array[UInt64]
388389

389390

390-
def test_arrays_and_struct_in_boxes(context: AlgopyTestContext) -> None: # noqa: ARG001
391+
def test_arrays_and_struct_in_boxes(context: AlgopyTestContext) -> None: # noqa: ARG001, PLR0915
391392
# Array
392393
arr1 = Array([UInt64(1), UInt64(2), UInt64(3)])
393394
arr2 = Array([UInt64(4), UInt64(5), UInt64(6)])
@@ -460,6 +461,19 @@ def test_arrays_and_struct_in_boxes(context: AlgopyTestContext) -> None: # noqa
460461
box5.value.a[1] = UInt64(20)
461462
assert list(box5.value.a) == [UInt64(1), UInt64(20), UInt64(3)]
462463

464+
# FixedBytes in FixedArray
465+
box6 = Box(
466+
FixedArray[FixedBytes[typing.Literal[1024]], typing.Literal[10]], key=b"test_array_6"
467+
)
468+
box6.value = FixedArray(
469+
[FixedBytes[typing.Literal[1024]].from_hex(f"0{x}" * 1024) for x in range(10)]
470+
)
471+
assert box6.length == 10 * 1024
472+
assert box6.value[0].bytes == b"\x00" * 1024
473+
assert box6.value[9].bytes == b"\x09" * 1024
474+
box6.value[1] = FixedBytes[typing.Literal[1024]].from_hex("11" * 1024)
475+
assert box6.value[1].bytes == b"\x11" * 1024
476+
463477

464478
def test_box() -> None:
465479
with algopy_testing_context():

tests/primitives/test_fixed_bytes.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def test_fixed_bytes_init_wrong_length(value: bytes | Bytes, message: str) -> No
5454
)
5555
def test_fixed_bytes_bool(
5656
value: FixedBytes, # type: ignore[type-arg]
57-
expected: bool, # noqa: FBT001
57+
*,
58+
expected: bool,
5859
) -> None:
5960
assert bool(value) == expected
6061

@@ -596,6 +597,21 @@ def test_fixed_bytes_contains_edge_cases() -> None:
596597
assert b"testing" not in fb # longer than haystack
597598

598599

600+
def test_in_place_addition() -> None:
601+
"""Test that in-place addition is not supported."""
602+
a = FixedBytes[typing.Literal[4]](b"test")
603+
b = FixedBytes[typing.Literal[4]](b"data")
604+
605+
with pytest.raises(TypeError, match="FixedBytes does not support in-place addition"):
606+
a += b
607+
608+
with pytest.raises(TypeError, match="FixedBytes does not support in-place addition"):
609+
a += b"hello"
610+
611+
with pytest.raises(TypeError, match="FixedBytes does not support in-place addition"):
612+
a += Bytes(b"hello")
613+
614+
599615
def test_augmented_assignment() -> None:
600616
"""Test that augmented assignment operators are not supported."""
601617
a = FixedBytes[typing.Literal[2]](b"\x0f\x0f")

0 commit comments

Comments
 (0)