Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Implement lowering pass and add primitives for int (in)equality #17027

Merged
merged 15 commits into from
Mar 16, 2024
4 changes: 4 additions & 0 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
RegisterOp,
Return,
Expand Down Expand Up @@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
def visit_call_c(self, op: CallC) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
return self.visit_register_op(op)

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
def visit_call_c(self, op: CallC) -> None:
pass

def visit_primitive_op(self, op: PrimitiveOp) -> None:
pass

def visit_truncate(self, op: Truncate) -> None:
pass

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/selfleaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
LoadStatic,
MethodCall,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
RegisterOp,
Expand Down Expand Up @@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
def visit_call_c(self, op: CallC) -> GenAndKill:
return self.check_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill:
return self.check_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill:
return CLEAN

Expand Down
6 changes: 6 additions & 0 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None:
args = ", ".join(self.reg(arg) for arg in op.args)
self.emitter.emit_line(f"{dest}{op.function_name}({args});")

def visit_primitive_op(self, op: PrimitiveOp) -> None:
raise RuntimeError(
f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen"
)

def visit_truncate(self, op: Truncate) -> None:
dest = self.reg(op)
value = self.reg(op.src)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.lower import lower_ir
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.uninit import insert_uninit_checks

Expand Down Expand Up @@ -235,6 +236,8 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)
# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)
Expand Down
79 changes: 78 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_method_call(self)


class PrimitiveDescription:
"""Description of a primitive op.

Primitives get lowered into lower-level ops before code generation.

If c_function_name is provided, a primitive will be lowered into a CallC op.
Otherwise custom logic will need to be implemented to transform the
primitive into lower-level ops.
"""

def __init__(
self,
name: str,
arg_types: list[RType],
return_type: RType, # TODO: What about generic?
var_arg_type: RType | None,
truncated_type: RType | None,
c_function_name: str | None,
error_kind: int,
steals: StealsDescription,
is_borrowed: bool,
ordering: list[int] | None,
extra_int_constants: list[tuple[int, RType]],
priority: int,
) -> None:
# Each primitive much have a distinct name, but otherwise they are arbitrary.
self.name: Final = name
self.arg_types: Final = arg_types
self.return_type: Final = return_type
self.var_arg_type: Final = var_arg_type
self.truncated_type: Final = truncated_type
# If non-None, this will map to a call of a C helper function; if None,
# there must be a custom handler function that gets invoked during the lowering
# pass to generate low-level IR for the primitive (in the mypyc.lower package)
self.c_function_name: Final = c_function_name
self.error_kind: Final = error_kind
self.steals: Final = steals
self.is_borrowed: Final = is_borrowed
self.ordering: Final = ordering
self.extra_int_constants: Final = extra_int_constants
self.priority: Final = priority

def __repr__(self) -> str:
return f"<PrimitiveDescription {self.name}>"


class PrimitiveOp(RegisterOp):
"""A higher-level primitive operation.

Some of these have special compiler support. These will be lowered
(transformed) into lower-level IR ops before code generation, and after
reference counting op insertion. Others will be transformed into CallC
ops.

Tagged integer equality is a typical primitive op with non-trivial
lowering. It gets transformed into a tag check, followed by different
code paths for short and long representations.
"""

def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
self.args = args
self.type = desc.return_type
self.error_kind = desc.error_kind
self.desc = desc

def sources(self) -> list[Value]:
return self.args

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_primitive_op(self)


class LoadErrorValue(RegisterOp):
"""Load an error value.

Expand Down Expand Up @@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp):

error_kind = ERR_NEVER

def __init__(self, src: Value) -> None:
def __init__(self, src: Value, line: int = -1) -> None:
super().__init__(line)
assert src.is_borrowed
self.src = src
self.type = src.type
Expand Down Expand Up @@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
def visit_call_c(self, op: CallC) -> T:
raise NotImplementedError

@abstractmethod
def visit_primitive_op(self, op: PrimitiveOp) -> T:
raise NotImplementedError

@abstractmethod
def visit_truncate(self, op: Truncate) -> T:
raise NotImplementedError
Expand Down
17 changes: 17 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str:
else:
return self.format("%r = %s(%s)", op, op.function_name, args_str)

def visit_primitive_op(self, op: PrimitiveOp) -> str:
args = []
arg_index = 0
type_arg_index = 0
for arg_type in zip(op.desc.arg_types):
if arg_type:
args.append(self.format("%r", op.args[arg_index]))
arg_index += 1
else:
assert op.type_args
args.append(self.format("%r", op.type_args[type_arg_index]))
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s ", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)

Expand Down
6 changes: 5 additions & 1 deletion mypyc/irbuild/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def maybe_process_conditional_comparison(
self.add_bool_branch(reg, true, false)
else:
# "left op right" for two tagged integers
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
if op in ("==", "!="):
reg = self.builder.binary_op(left, right, op, e.line)
self.add_bool_branch(reg, true, false)
else:
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
return True


Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
set_literal = precompute_set_literal(builder, e.operands[1])
if set_literal is not None:
lhs = e.operands[0]
result = builder.builder.call_c(
result = builder.builder.primitive_op(
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
)
if first_op == "not in":
Expand All @@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
borrow_left = is_borrow_friendly_expr(builder, right_expr)
left = builder.accept(left_expr, can_borrow=borrow_left)
right = builder.accept(right_expr, can_borrow=True)
return builder.compare_tagged(left, right, first_op, e.line)
return builder.binary_op(left, right, first_op, e.line)

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
Expand Down
Loading
Loading