Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[test]: add more coverage to abi_decode fuzzer tests #4153

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 105 additions & 19 deletions tests/functional/builtins/codegen/test_abi_decode_fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
IntegerT,
SArrayT,
StringT,
StructT,
TupleT,
VyperType,
_get_primitive_types,
_get_sequence_types,
)
from vyper.semantics.types.shortcuts import UINT256_T

from .abi_decode import DecodeError, spec_decode

Expand All @@ -39,7 +39,7 @@
continue
type_ctors.append(t)

complex_static_ctors = [SArrayT, TupleT]
complex_static_ctors = [SArrayT, TupleT, StructT]
complex_dynamic_ctors = [DArrayT]
leaf_ctors = [t for t in type_ctors if t not in _get_sequence_types().values()]
static_leaf_ctors = [t for t in leaf_ctors if t._is_prim_word]
Expand All @@ -50,10 +50,12 @@

@st.composite
# max type nesting
def vyper_type(draw, nesting=3, skip=None):
def vyper_type(draw, nesting=3, skip=None, source_fragments=None):
assert nesting >= 0

skip = skip or []
if source_fragments is None:
source_fragments = []

st_leaves = st.one_of(st.sampled_from(dynamic_leaf_ctors), st.sampled_from(static_leaf_ctors))
st_complex = st.one_of(
Expand All @@ -71,39 +73,52 @@ def vyper_type(draw, nesting=3, skip=None):
# note: maybe st.deferred is good here, we could define it with
# mutual recursion
def _go(skip=skip):
return draw(vyper_type(nesting=nesting - 1, skip=skip))
_, typ = draw(vyper_type(nesting=nesting - 1, skip=skip, source_fragments=source_fragments))
return typ

def finalize(typ):
return source_fragments, typ

if t in (BytesT, StringT):
# arbitrary max_value
bound = draw(st.integers(min_value=1, max_value=1024))
return t(bound)
return finalize(t(bound))

if t == SArrayT:
subtype = _go(skip=[TupleT, BytesT, StringT])
bound = draw(st.integers(min_value=1, max_value=6))
return t(subtype, bound)
return finalize(t(subtype, bound))
if t == DArrayT:
subtype = _go(skip=[TupleT])
bound = draw(st.integers(min_value=1, max_value=16))
return t(subtype, bound)
return finalize(t(subtype, bound))

if t == TupleT:
# zero-length tuples are not allowed in vyper
n = draw(st.integers(min_value=1, max_value=6))
subtypes = [_go() for _ in range(n)]
return TupleT(subtypes)
return finalize(TupleT(subtypes))

if t == StructT:
n = draw(st.integers(min_value=1, max_value=6))
subtypes = {f"x{i}": _go() for i in range(n)}
_id = len(source_fragments) # poor man's unique id
name = f"MyStruct{_id}"
typ = StructT(name, subtypes)
source_fragments.append(typ.def_source_str())
return finalize(StructT(name, subtypes))

if t in (BoolT, AddressT):
return t()
return finalize(t())

if t == IntegerT:
signed = draw(st.booleans())
bits = 8 * draw(st.integers(min_value=1, max_value=32))
return t(signed, bits)
return finalize(t(signed, bits))

if t == BytesM_T:
m = draw(st.integers(min_value=1, max_value=32))
return t(m)
return finalize(t(m))

raise RuntimeError("unreachable")

Expand All @@ -116,6 +131,9 @@ def _go(t):
if isinstance(typ, TupleT):
return tuple(_go(item_t) for item_t in typ.member_types)

if isinstance(typ, StructT):
return tuple(_go(item_t) for item_t in typ.tuple_members())

if isinstance(typ, SArrayT):
return [_go(typ.value_type) for _ in range(typ.length)]

Expand Down Expand Up @@ -294,6 +312,13 @@ def _finalize(): # little trick to save re-typing the arguments
num_dynamic_types = sum(s.num_dynamic_types for s in substats)
return _finalize()

if isinstance(typ, StructT):
substats = [_type_stats(t) for t in typ.tuple_members()]
nesting = 1 + max(s.nesting for s in substats)
breadth = max(len(typ.member_types), *[s.breadth for s in substats])
num_dynamic_types = sum(s.num_dynamic_types for s in substats)
return _finalize()

if isinstance(typ, DArrayT):
substat = _type_stats(typ.value_type)
nesting = 1 + substat.nesting
Expand Down Expand Up @@ -332,8 +357,8 @@ def payload_copier(get_contract_from_ir):
@pytest.mark.parametrize("_n", list(range(PARALLELISM)))
@hp.given(typ=vyper_type())
@hp.settings(max_examples=100, **_settings)
@hp.example(typ=DArrayT(DArrayT(UINT256_T, 2), 2))
def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier, env):
source_fragments, typ = typ
# import time
# t0 = time.time()
# print("ENTER", typ)
Expand All @@ -350,12 +375,13 @@ def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
# by bytes length check at function entry
type_bound = wrapped_type.abi_type.size_bound()
buffer_bound = type_bound + MAX_MUTATIONS
type_str = repr(typ) # annotation in vyper code
# TODO: intrinsic decode from staticcall/extcall
# TODO: _abi_decode from other sources (staticcall/extcall?)
# TODO: dirty the buffer
# TODO: check unwrap_tuple=False

preamble = "\n\n".join(source_fragments)
type_str = str(typ) # annotation in vyper code

code = f"""
{preamble}

@external
def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
ret: {type_str} = abi_decode(xs, {type_str})
Expand All @@ -375,14 +401,20 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}:
assert len(xs) <= {type_bound}
return (extcall copier.bar(xs))
"""
try:
c = get_contract(code)
except EvmError as e:
if env.contract_size_limit_error in str(e):
hp.assume(False)
# print(code)
hp.note(code)
c = get_contract(code)

@hp.given(data=payload_from(wrapped_type))
@hp.settings(max_examples=100, **_settings)
def _fuzz(data):
hp.note(f"type: {typ}")
hp.note(f"abi_t: {wrapped_type.abi_type.selector_name()}")
hp.note(code)
hp.note(data.hex())

try:
Expand Down Expand Up @@ -414,3 +446,57 @@ def _fuzz(data):

# t1 = time.time()
# print(f"elapsed {t1 - t0}s")


@pytest.mark.parametrize("_n", list(range(PARALLELISM)))
@hp.given(typ=vyper_type())
@hp.settings(max_examples=100, **_settings)
def test_abi_decode_no_wrap_fuzz(_n, typ, get_contract, tx_failed, env):
source_fragments, typ = typ
# import time
# t0 = time.time()
# print("ENTER", typ)

stats = _type_stats(typ)
hp.target(stats.num_dynamic_types)

# add max_mutations bytes worth of padding so we don't just get caught
# by bytes length check at function entry
type_bound = typ.abi_type.size_bound()
buffer_bound = type_bound + MAX_MUTATIONS

type_str = str(typ) # annotation in vyper code
preamble = "\n\n".join(source_fragments)

code = f"""
{preamble}

@external
def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
ret: {type_str} = abi_decode(xs, {type_str}, unwrap_tuple=False)
return ret
"""
try:
c = get_contract(code)
except EvmError as e:
if env.contract_size_limit_error in str(e):
hp.assume(False)

@hp.given(data=payload_from(typ))
@hp.settings(max_examples=100, **_settings)
def _fuzz(data):
hp.note(code)
hp.note(data.hex())
try:
expected = spec_decode(typ, data)
hp.note(f"expected {expected}")
assert expected == c.run(data)
except DecodeError:
hp.note("expect failure")
with tx_failed(EvmError):
c.run(data)

_fuzz()

# t1 = time.time()
# print(f"elapsed {t1 - t0}s")
11 changes: 10 additions & 1 deletion vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,11 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":

return cls(struct_name, members, ast_def=base_node)

def __str__(self):
return f"{self._id}"

def __repr__(self):
return f"{self._id} declaration object"
return f"{self._id} {self.members}"

def _try_fold(self, node):
if len(node.args) != 1:
Expand All @@ -384,6 +387,12 @@ def _try_fold(self, node):
# it can't be reduced, but this lets upstream code know it's constant
return node

def def_source_str(self):
ret = f"struct {self._id}:\n"
for k, v in self.member_types.items():
ret += f" {k}: {v}\n"
return ret

@property
def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types.values())
Expand Down
Loading