Skip to content

[mypyc] Refactor IR building for generator functions #19008

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 2, 2025
34 changes: 30 additions & 4 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def flatten_classes(self, arg: RefExpr | TupleExpr) -> list[ClassIR] | None:
return None
return res

def enter(self, fn_info: FuncInfo | str = "") -> None:
def enter(self, fn_info: FuncInfo | str = "", *, ret_type: RType = none_rprimitive) -> None:
if isinstance(fn_info, str):
fn_info = FuncInfo(name=fn_info)
self.builder = LowLevelIRBuilder(self.errors, self.options)
Expand All @@ -1179,7 +1179,7 @@ def enter(self, fn_info: FuncInfo | str = "") -> None:
self.runtime_args.append([])
self.fn_info = fn_info
self.fn_infos.append(self.fn_info)
self.ret_types.append(none_rprimitive)
self.ret_types.append(ret_type)
if fn_info.is_generator:
self.nonlocal_control.append(GeneratorNonlocalControl())
else:
Expand Down Expand Up @@ -1219,10 +1219,9 @@ def enter_method(
self_type: If not None, override default type of the implicit 'self'
argument (by default, derive type from class_ir)
"""
self.enter(fn_info)
self.enter(fn_info, ret_type=ret_type)
self.function_name_stack.append(name)
self.class_ir_stack.append(class_ir)
self.ret_types[-1] = ret_type
if self_type is None:
self_type = RInstance(class_ir)
self.add_argument(SELF_NAME, self_type)
Expand Down Expand Up @@ -1498,3 +1497,30 @@ def create_type_params(
builder.init_type_var(tv, type_param.name, line)
tvs.append(tv)
return tvs


def calculate_arg_defaults(
    builder: IRBuilder,
    fn_info: FuncInfo,
    func_reg: Value | None,
    symtable: dict[SymbolNode, SymbolTarget],
) -> None:
    """Evaluate non-constant argument defaults and store the results.

    For a function that is not nested, each default is stored in a static
    named "<function fullname>.<argument name>". For a nested function it
    is stored as an attribute on the function object (func_reg), which
    must not be None in that case. Constant defaults are not stored here;
    they are recomputed on demand instead.
    """
    func_item = fn_info.fitem
    for argument in func_item.arguments:
        default_expr = argument.initializer
        # Constant values don't get stored but just recomputed
        if not default_expr or is_constant(default_expr):
            continue

        # Evaluate the initializer first, then coerce it to the type of the
        # argument's target in the function's local symbol table.
        raw_value = builder.accept(default_expr)
        target_type = symtable[argument.variable].type
        default_value = builder.coerce(raw_value, target_type, argument.line)

        if fn_info.is_nested:
            assert func_reg is not None
            builder.add(SetAttr(func_reg, argument.variable.name, default_value, argument.line))
        else:
            static_name = func_item.fullname + "." + argument.variable.name
            builder.add(InitStatic(default_value, static_name, builder.module_name))
35 changes: 35 additions & 0 deletions mypyc/irbuild/env_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,41 @@ def add_args_to_env(
builder.add_var_to_env_class(arg.variable, rtype, base, reassign=reassign)


def add_vars_to_env(builder: IRBuilder) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this code here since this is used for both regular and generator functions.

"""Add relevant local variables and nested functions to the environment class.

Add all variables and functions that are declared/defined within current
function and are referenced in functions nested within this one to this
function's environment class so the nested functions can reference
them even if they are declared after the nested function's definition.
Note that this is done before visiting the body of the function.
"""
env_for_func: FuncInfo | ImplicitClass = builder.fn_info
if builder.fn_info.is_generator:
env_for_func = builder.fn_info.generator_class
elif builder.fn_info.is_nested or builder.fn_info.in_non_ext:
env_for_func = builder.fn_info.callable_class

if builder.fn_info.fitem in builder.free_variables:
# Sort the variables to keep things deterministic
for var in sorted(builder.free_variables[builder.fn_info.fitem], key=lambda x: x.name):
if isinstance(var, Var):
rtype = builder.type_to_rtype(var.type)
builder.add_var_to_env_class(var, rtype, env_for_func, reassign=False)

if builder.fn_info.fitem in builder.encapsulating_funcs:
for nested_fn in builder.encapsulating_funcs[builder.fn_info.fitem]:
if isinstance(nested_fn, FuncDef):
# The return type is 'object' instead of an RInstance of the
# callable class because differently defined functions with
# the same name and signature across conditional blocks
# will generate different callable classes, so the callable
# class that gets instantiated must be generic.
builder.add_var_to_env_class(
nested_fn, object_rprimitive, env_for_func, reassign=False
)


def setup_func_for_recursive_call(builder: IRBuilder, fdef: FuncDef, base: ImplicitClass) -> None:
"""Enable calling a nested function (with a callable class) recursively.

Expand Down
174 changes: 45 additions & 129 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
FuncItem,
LambdaExpr,
OverloadedFuncDef,
SymbolNode,
TypeInfo,
Var,
)
Expand All @@ -44,7 +43,6 @@
from mypyc.ir.ops import (
BasicBlock,
GetAttr,
InitStatic,
Integer,
LoadAddress,
LoadLiteral,
Expand All @@ -62,31 +60,22 @@
int_rprimitive,
object_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder, SymbolTarget, gen_arg_defaults
from mypyc.irbuild.builder import IRBuilder, calculate_arg_defaults, gen_arg_defaults
from mypyc.irbuild.callable_class import (
add_call_to_callable_class,
add_get_to_callable_class,
instantiate_callable_class,
setup_callable_class,
)
from mypyc.irbuild.context import FuncInfo, ImplicitClass
from mypyc.irbuild.context import FuncInfo
from mypyc.irbuild.env_class import (
add_vars_to_env,
finalize_env_class,
load_env_registers,
load_outer_envs,
setup_env_class,
setup_func_for_recursive_call,
)
from mypyc.irbuild.generator import (
add_methods_to_generator_class,
add_raise_exception_blocks_to_generator_class,
create_switch_for_generator_class,
gen_generator_func,
populate_switch_for_generator_class,
setup_env_for_generator_class,
)
from mypyc.irbuild.generator import gen_generator_func, gen_generator_func_body
from mypyc.irbuild.targets import AssignmentTarget
from mypyc.irbuild.util import is_constant
from mypyc.primitives.dict_ops import dict_get_method_with_none, dict_new_op, dict_set_item_op
from mypyc.primitives.generic_ops import py_setattr_op
from mypyc.primitives.misc_ops import register_function
Expand Down Expand Up @@ -235,123 +224,77 @@ def c() -> None:
func_name = singledispatch_main_func_name(name)
else:
func_name = name
builder.enter(
FuncInfo(
fitem=fitem,
name=func_name,
class_name=class_name,
namespace=gen_func_ns(builder),
is_nested=is_nested,
contains_nested=contains_nested,
is_decorated=is_decorated,
in_non_ext=in_non_ext,
add_nested_funcs_to_env=add_nested_funcs_to_env,
)

fn_info = FuncInfo(
fitem=fitem,
name=func_name,
class_name=class_name,
namespace=gen_func_ns(builder),
is_nested=is_nested,
contains_nested=contains_nested,
is_decorated=is_decorated,
in_non_ext=in_non_ext,
add_nested_funcs_to_env=add_nested_funcs_to_env,
)
is_generator = fn_info.is_generator
builder.enter(fn_info, ret_type=sig.ret_type)

# Functions that contain nested functions need an environment class to store variables that
# are free in their nested functions. Generator functions need an environment class to
# store a variable denoting the next instruction to be executed when the __next__ function
# is called, along with all the variables inside the function itself.
if builder.fn_info.contains_nested or builder.fn_info.is_generator:
if contains_nested or is_generator:
setup_env_class(builder)

if builder.fn_info.is_nested or builder.fn_info.in_non_ext:
if is_nested or in_non_ext:
setup_callable_class(builder)

if builder.fn_info.is_generator:
# Do a first-pass and generate a function that just returns a generator object.
gen_generator_func(builder)
args, _, blocks, ret_type, fn_info = builder.leave()
func_ir, func_reg = gen_func_ir(
builder, args, blocks, sig, fn_info, cdef, is_singledispatch
if is_generator:
# First generate a function that just constructs and returns a generator object.
func_ir, func_reg = gen_generator_func(
builder,
lambda args, blocks, fn_info: gen_func_ir(
builder, args, blocks, sig, fn_info, cdef, is_singledispatch
),
)

# Re-enter the FuncItem and visit the body of the function this time.
builder.enter(fn_info)
setup_env_for_generator_class(builder)

load_outer_envs(builder, builder.fn_info.generator_class)
top_level = builder.top_level_fn_info()
if (
builder.fn_info.is_nested
and isinstance(fitem, FuncDef)
and top_level
and top_level.add_nested_funcs_to_env
):
setup_func_for_recursive_call(builder, fitem, builder.fn_info.generator_class)
create_switch_for_generator_class(builder)
add_raise_exception_blocks_to_generator_class(builder, fitem.line)
gen_generator_func_body(builder, fn_info, sig, func_reg)
else:
load_env_registers(builder)
gen_arg_defaults(builder)
func_ir, func_reg = gen_func_body(builder, sig, cdef, is_singledispatch)

if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
finalize_env_class(builder)
if is_singledispatch:
# add the generated main singledispatch function
builder.functions.append(func_ir)
# create the dispatch function
assert isinstance(fitem, FuncDef)
return gen_dispatch_func_ir(builder, fitem, fn_info.name, name, sig)

builder.ret_types[-1] = sig.ret_type
return func_ir, func_reg

# Add all variables and functions that are declared/defined within this
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following code was moved to add_vars_to_env.

# function and are referenced in functions nested within this one to this
# function's environment class so the nested functions can reference
# them even if they are declared after the nested function's definition.
# Note that this is done before visiting the body of this function.

env_for_func: FuncInfo | ImplicitClass = builder.fn_info
if builder.fn_info.is_generator:
env_for_func = builder.fn_info.generator_class
elif builder.fn_info.is_nested or builder.fn_info.in_non_ext:
env_for_func = builder.fn_info.callable_class

if builder.fn_info.fitem in builder.free_variables:
# Sort the variables to keep things deterministic
for var in sorted(builder.free_variables[builder.fn_info.fitem], key=lambda x: x.name):
if isinstance(var, Var):
rtype = builder.type_to_rtype(var.type)
builder.add_var_to_env_class(var, rtype, env_for_func, reassign=False)

if builder.fn_info.fitem in builder.encapsulating_funcs:
for nested_fn in builder.encapsulating_funcs[builder.fn_info.fitem]:
if isinstance(nested_fn, FuncDef):
# The return type is 'object' instead of an RInstance of the
# callable class because differently defined functions with
# the same name and signature across conditional blocks
# will generate different callable classes, so the callable
# class that gets instantiated must be generic.
builder.add_var_to_env_class(
nested_fn, object_rprimitive, env_for_func, reassign=False
)

builder.accept(fitem.body)
def gen_func_body(
builder: IRBuilder, sig: FuncSignature, cdef: ClassDef | None, is_singledispatch: bool
) -> tuple[FuncIR, Value | None]:
load_env_registers(builder)
gen_arg_defaults(builder)
if builder.fn_info.contains_nested:
finalize_env_class(builder)
add_vars_to_env(builder)
builder.accept(builder.fn_info.fitem.body)
builder.maybe_add_implicit_return()

if builder.fn_info.is_generator:
populate_switch_for_generator_class(builder)

# Hang on to the local symbol table for a while, since we use it
# to calculate argument defaults below.
symtable = builder.symtables[-1]

args, _, blocks, ret_type, fn_info = builder.leave()

if fn_info.is_generator:
add_methods_to_generator_class(builder, fn_info, sig, args, blocks, fitem.is_coroutine)
else:
func_ir, func_reg = gen_func_ir(
builder, args, blocks, sig, fn_info, cdef, is_singledispatch
)
func_ir, func_reg = gen_func_ir(builder, args, blocks, sig, fn_info, cdef, is_singledispatch)

# Evaluate argument defaults in the surrounding scope, since we
# calculate them *once* when the function definition is evaluated.
calculate_arg_defaults(builder, fn_info, func_reg, symtable)

if is_singledispatch:
# add the generated main singledispatch function
builder.functions.append(func_ir)
# create the dispatch function
assert isinstance(fitem, FuncDef)
return gen_dispatch_func_ir(builder, fitem, fn_info.name, name, sig)

return func_ir, func_reg


Expand Down Expand Up @@ -512,33 +455,6 @@ def handle_non_ext_method(
builder.add_to_non_ext_dict(non_ext, name, func_reg, fdef.line)


def calculate_arg_defaults(
    builder: IRBuilder,
    fn_info: FuncInfo,
    func_reg: Value | None,
    symtable: dict[SymbolNode, SymbolTarget],
) -> None:
    """Evaluate non-constant argument defaults and store the results.

    For a function that is not nested, each default is stored in a static
    named "<function fullname>.<argument name>". For a nested function it
    is stored as an attribute on the function object (func_reg), which
    must not be None in that case. Constant defaults are not stored here;
    they are recomputed on demand instead.
    """
    func_item = fn_info.fitem
    for argument in func_item.arguments:
        default_expr = argument.initializer
        # Constant values don't get stored but just recomputed
        if not default_expr or is_constant(default_expr):
            continue

        # Evaluate the initializer first, then coerce it to the type of the
        # argument's target in the function's local symbol table.
        raw_value = builder.accept(default_expr)
        target_type = symtable[argument.variable].type
        default_value = builder.coerce(raw_value, target_type, argument.line)

        if fn_info.is_nested:
            assert func_reg is not None
            builder.add(SetAttr(func_reg, argument.variable.name, default_value, argument.line))
        else:
            static_name = func_item.fullname + "." + argument.variable.name
            builder.add(InitStatic(default_value, static_name, builder.module_name))


def gen_func_ns(builder: IRBuilder) -> str:
"""Generate a namespace for a nested function using its outer function names."""
return "_".join(
Expand Down
Loading