Skip to content

Commit

Permalink
remove skip_nonpayable_check kwarg
Browse files Browse the repository at this point in the history
it's no longer needed! the nonpayable checks can be computed during
selector section construction.
  • Loading branch information
charles-cooper committed Jul 15, 2023
1 parent 37456c3 commit 4c0aa48
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 48 deletions.
10 changes: 2 additions & 8 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ class InternalFuncIR(FuncIR):

# TODO: should split this into external and internal ir generation?
def generate_ir_for_function(
code: vy_ast.FunctionDef,
global_ctx: GlobalContext,
skip_nonpayable_check: bool,
is_ctor_context: bool = False,
code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False
) -> FuncIR:
"""
Parse a function and produce IR code for the function, includes:
Expand Down Expand Up @@ -130,13 +127,10 @@ def generate_ir_for_function(
)

if func_t.is_internal:
assert skip_nonpayable_check is False
ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context))
func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore
else:
kwarg_handlers, common = generate_ir_for_external_function(
code, func_t, context, skip_nonpayable_check
)
kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context)
entry_points = {
k: EntryPointInfo(func_t, mincalldatasize, ir_node)
for k, (mincalldatasize, ir_node) in kwarg_handlers.items()
Expand Down
10 changes: 1 addition & 9 deletions vyper/codegen/function_definitions/external_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def handler_for(calldata_kwargs, default_kwargs):
# TODO it would be nice if this returned a data structure which were
# amenable to generating a jump table instead of the linear search for
# method_id we have now.
def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check):
def generate_ir_for_external_function(code, func_t, context):
# TODO type hints:
# def generate_ir_for_external_function(
# code: vy_ast.FunctionDef,
Expand All @@ -167,14 +167,6 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che
# generate the main body of the function
body += handle_base_args

if not func_t.is_payable and not skip_nonpayable_check:
# if the contract contains payable functions, but this is not one of them
# add an assertion that the value of the call is zero
nonpayable_check = IRnode.from_list(
["assert", ["iszero", "callvalue"]], error_msg="nonpayable check"
)
body.append(nonpayable_check)

body += nonreentrant_pre

body += [parse_body(code.body, context, ensure_terminated=True)]
Expand Down
49 changes: 18 additions & 31 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,30 +60,23 @@ def label_for_entry_point(abi_sig, entry_point):
return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}"


# TODO: probably dead code
def _ir_for_external_function(func_ast, *args, **kwargs):
# adapt whatever generate_ir_for_function gives us into an IR node
ret = ["seq"]
# adapt whatever generate_ir_for_function gives us into an IR node
def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs):
func_t = func_ast._metadata["type"]
func_ir = generate_ir_for_function(func_ast, *args, **kwargs)
assert func_t.is_fallback or func_t.is_constructor

if func_t.is_fallback or func_t.is_constructor:
assert len(func_ir.entry_points) == 1
# add a goto to make the function entry look like other functions
# (for zksync interpreter)
ret.append(["goto", func_t._ir_info.external_function_base_entry_label])
ret.append(func_ir.common_ir)
ret = ["seq"]
if not func_t.is_payable:
callvalue_check = ["assert", ["iszero", "callvalue"]]
ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check"))

else:
for sig, ir_node in func_ir.entry_points.items():
method_id = _annotated_method_id(sig)
ret.append(["if", ["eq", "_calldata_method_id", method_id], ir_node])
func_ir = generate_ir_for_function(func_ast, *args, **kwargs)
assert len(func_ir.entry_points) == 1

# stick function common body into final entry point to save a jump
# TODO: this would not really be necessary if we had basic block
# reordering in optimizer.
ir_node = ["seq", ir_node, func_ir.common_ir]
func_ir.entry_points[sig] = ir_node
# add a goto to make the function entry look like other functions
# (for zksync interpreter)
ret.append(["goto", func_t._ir_info.external_function_base_entry_label])
ret.append(func_ir.common_ir)

return IRnode.from_list(ret)

Expand All @@ -107,7 +100,7 @@ def _selector_section_dense(external_functions, global_ctx):
return IRnode.from_list(["seq"])

for code in external_functions:
func_ir = generate_ir_for_function(code, global_ctx, skip_nonpayable_check=True)
func_ir = generate_ir_for_function(code, global_ctx)
for abi_sig, entry_point in func_ir.entry_points.items():
assert abi_sig not in entry_points
entry_points[abi_sig] = entry_point
Expand Down Expand Up @@ -264,7 +257,7 @@ def _selector_section_sparse(external_functions, global_ctx):
return selector_section

for code in external_functions:
func_ir = generate_ir_for_function(code, global_ctx, skip_nonpayable_check=True)
func_ir = generate_ir_for_function(code, global_ctx)
for abi_sig, entry_point in func_ir.entry_points.items():
assert abi_sig not in entry_points
entry_points[abi_sig] = entry_point
Expand Down Expand Up @@ -392,9 +385,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
selector_section = _selector_section_sparse(external_functions, global_ctx)

if default_function:
fallback_ir = _ir_for_external_function(
default_function, global_ctx, skip_nonpayable_check=False
)
fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx)
else:
fallback_ir = IRnode.from_list(
["revert", 0, 0], annotation="Default function", error_msg="fallback function"
Expand All @@ -408,9 +399,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
immutables_len = global_ctx.immutable_section_bytes
if init_function:
# TODO might be cleaner to separate this into an _init_ir helper func
init_func_ir = _ir_for_external_function(
init_function, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True)

# pass the amount of memory allocated for the init function
# so that deployment does not clobber while preparing immutables
Expand Down Expand Up @@ -446,9 +435,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
# unreachable code, delete it
continue

func_ir = _ir_for_internal_function(
f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True)
deploy_code.append(func_ir)

else:
Expand Down

0 comments on commit 4c0aa48

Please sign in to comment.