From 4c0aa48c3a79fdfe0c87375920c6df107e6cdb6a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 15 Jul 2023 15:15:22 -0400 Subject: [PATCH] remove skip_nonpayable_check kwarg it's no longer needed! the nonpayable checks can be computed during selector section construction. --- vyper/codegen/function_definitions/common.py | 10 +--- .../function_definitions/external_function.py | 10 +--- vyper/codegen/module.py | 49 +++++++------------ 3 files changed, 21 insertions(+), 48 deletions(-) diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index b8d9a7eb8d..3fd5ce0b29 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -87,10 +87,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, - global_ctx: GlobalContext, - skip_nonpayable_check: bool, - is_ctor_context: bool = False, + code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -130,13 +127,10 @@ def generate_ir_for_function( ) if func_t.is_internal: - assert skip_nonpayable_check is False ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: - kwarg_handlers, common = generate_ir_for_external_function( - code, func_t, context, skip_nonpayable_check - ) + kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) entry_points = { k: EntryPointInfo(func_t, mincalldatasize, ir_node) for k, (mincalldatasize, ir_node) in kwarg_handlers.items() diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 76156e4ebd..dadea75a86 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -142,7 +142,7 @@ def handler_for(calldata_kwargs, default_kwargs): # TODO it would be nice if this returned a data structure which were # amenable to generating a jump table instead of the linear search for # method_id we have now. -def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check): +def generate_ir_for_external_function(code, func_t, context): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, @@ -167,14 +167,6 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che # generate the main body of the function body += handle_base_args - if not func_t.is_payable and not skip_nonpayable_check: - # if the contract contains payable functions, but this is not one of them - # add an assertion that the value of the call is zero - nonpayable_check = IRnode.from_list( - ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" - ) - body.append(nonpayable_check) - body += nonreentrant_pre body += [parse_body(code.body, context, ensure_terminated=True)] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 403fac5a6b..8e66405c50 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -60,30 +60,23 @@ def label_for_entry_point(abi_sig, entry_point): return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}" -# TODO: probably dead code -def _ir_for_external_function(func_ast, *args, **kwargs): - # adapt whatever generate_ir_for_function gives us into an IR node - ret = ["seq"] +# adapt whatever generate_ir_for_function gives us into an IR node +def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): func_t = func_ast._metadata["type"] - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + assert func_t.is_fallback or func_t.is_constructor - if func_t.is_fallback or func_t.is_constructor: - assert len(func_ir.entry_points) == 1 - # add a goto to make the function entry look like other functions - # (for zksync interpreter) - ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) - ret.append(func_ir.common_ir) + ret = ["seq"] + if not func_t.is_payable: + callvalue_check = ["assert", ["iszero", "callvalue"]] + ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - else: - for sig, ir_node in func_ir.entry_points.items(): - method_id = _annotated_method_id(sig) - ret.append(["if", ["eq", "_calldata_method_id", method_id], ir_node]) + func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + assert len(func_ir.entry_points) == 1 - # stick function common body into final entry point to save a jump - # TODO: this would not really be necessary if we had basic block - # reordering in optimizer. - ir_node = ["seq", ir_node, func_ir.common_ir] - func_ir.entry_points[sig] = ir_node + # add a goto to make the function entry look like other functions + # (for zksync interpreter) + ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) + ret.append(func_ir.common_ir) return IRnode.from_list(ret) @@ -107,7 +100,7 @@ def _selector_section_dense(external_functions, global_ctx): return IRnode.from_list(["seq"]) for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx, skip_nonpayable_check=True) + func_ir = generate_ir_for_function(code, global_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): assert abi_sig not in entry_points entry_points[abi_sig] = entry_point @@ -264,7 +257,7 @@ def _selector_section_sparse(external_functions, global_ctx): return selector_section for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx, skip_nonpayable_check=True) + func_ir = generate_ir_for_function(code, global_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): assert abi_sig not in entry_points entry_points[abi_sig] = entry_point @@ -392,9 +385,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: selector_section = _selector_section_sparse(external_functions, global_ctx) if default_function: - fallback_ir = _ir_for_external_function( - default_function, global_ctx, skip_nonpayable_check=False - ) + fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -408,9 +399,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: immutables_len = global_ctx.immutable_section_bytes if init_function: # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_external_function( - init_function, global_ctx, skip_nonpayable_check=False, is_ctor_context=True - ) + init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables @@ -446,9 +435,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # unreachable code, delete it continue - func_ir = _ir_for_internal_function( - f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True - ) + func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) deploy_code.append(func_ir) else: