Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Implement lowering for remaining tagged integer comparisons #17040

Merged
merged 11 commits into from
Mar 19, 2024
2 changes: 1 addition & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def visit_primitive_op(self, op: PrimitiveOp) -> str:
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s ", op, op.desc.name, args_str)
return self.format("%r = %s %s", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
Expand Down
9 changes: 3 additions & 6 deletions mypyc/irbuild/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,9 @@ def maybe_process_conditional_comparison(
self.add_bool_branch(reg, true, false)
else:
# "left op right" for two tagged integers
if op in ("==", "!="):
reg = self.builder.binary_op(left, right, op, e.line)
self.flush_keep_alives()
self.add_bool_branch(reg, true, false)
else:
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
reg = self.builder.binary_op(left, right, op, e.line)
self.flush_keep_alives()
self.add_bool_branch(reg, true, false)
return True


Expand Down
3 changes: 0 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Va
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
return self.builder.int_op(type, lhs, rhs, op, line)

def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
return self.builder.compare_tagged(lhs, rhs, op, line)

def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
return self.builder.compare_tuples(lhs, rhs, op, line)

Expand Down
6 changes: 0 additions & 6 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,6 @@ def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Va
def transform_basic_comparison(
builder: IRBuilder, op: str, left: Value, right: Value, line: int
) -> Value:
if (
is_int_rprimitive(left.type)
and is_int_rprimitive(right.type)
and op in int_comparison_op_mapping
):
return builder.compare_tagged(left, right, op, line)
if is_fixed_width_rtype(left.type) and op in int_comparison_op_mapping:
if right.type == left.type:
if left.type.is_signed:
Expand Down
5 changes: 2 additions & 3 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
call_impl, next_impl = BasicBlock(), BasicBlock()

current_id = builder.load_int(i)
builder.builder.compare_tagged_condition(
passed_id, current_id, "==", call_impl, next_impl, line
)
cond = builder.binary_op(passed_id, current_id, "==", line)
builder.add_bool_branch(cond, call_impl, next_impl)

# Call the registered implementation
builder.activate_block(call_impl)
Expand Down
94 changes: 13 additions & 81 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,13 +1315,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.compare_strings(lreg, rreg, op, line)
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
return self.compare_bytes(lreg, rreg, op, line)
if (
is_tagged(ltype)
and is_tagged(rtype)
and op in int_comparison_op_mapping
and op not in ("==", "!=")
):
return self.compare_tagged(lreg, rreg, op, line)
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
if op in ComparisonOp.signed_ops:
return self.bool_comparison_op(lreg, rreg, op, line)
Expand Down Expand Up @@ -1384,16 +1377,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
if is_fixed_width_rtype(lreg.type):
return self.comparison_op(lreg, rreg, op_id, line)

# Mixed int comparisons
if op in ("==", "!="):
pass # TODO: Do we need anything here?
elif op in op in int_comparison_op_mapping:
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_float_rprimitive(ltype) or is_float_rprimitive(rtype):
if isinstance(lreg, Integer):
lreg = Float(float(lreg.numeric_value()))
Expand Down Expand Up @@ -1445,18 +1428,16 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
result = Register(bool_rprimitive)
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line)
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if op in ("==", "!="):
check = check_lhs
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# for non-equality logical ops (less/greater than, etc.), need to check both sides
check_rhs = self.check_tagged_short_int(rhs, line)
check = self.int_op(bit_rprimitive, check_lhs, check_rhs, IntOp.AND, line)
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
self.goto(out)
short_lhs = BasicBlock()
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
self.activate_block(short_lhs)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
Expand All @@ -1469,62 +1450,12 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
else:
call_result = call
self.add(Assign(result, call_result, line))
self.goto_and_activate(out)
return result

def compare_tagged_condition(
self, lhs: Value, rhs: Value, op: str, true: BasicBlock, false: BasicBlock, line: int
) -> None:
"""Compare two tagged integers using given operator (conditional context).

Assume lhs and rhs are tagged integers.

Args:
lhs: Left operand
rhs: Right operand
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
is_eq = op in ("==", "!=")
if (is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) or (
is_eq and (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type))
):
# We can skip the tag check
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
self.flush_keep_alives()
self.add(Branch(check, true, false, Branch.BOOL))
return
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
int_block, short_int_block = BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if is_eq or is_short_int_rprimitive(rhs.type):
self.flush_keep_alives()
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# For non-equality logical ops (less/greater than, etc.), need to check both sides
rhs_block = BasicBlock()
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
self.activate_block(rhs_block)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.flush_keep_alives()
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
# Arbitrary integers (slow path)
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
self.add(Branch(call, false, true, Branch.BOOL))
else:
self.flush_keep_alives()
self.add(Branch(call, true, false, Branch.BOOL))
# Short integers (fast path)
self.goto(out)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Branch(eq, true, false, Branch.BOOL))
self.add(Assign(result, eq, line))
self.goto_and_activate(out)
return result

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
Expand Down Expand Up @@ -2309,7 +2240,8 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
length = self.gen_method_call(val, "__len__", [], int_rprimitive, line)
length = self.coerce(length, int_rprimitive, line)
ok, fail = BasicBlock(), BasicBlock()
self.compare_tagged_condition(length, Integer(0), ">=", ok, fail, line)
cond = self.binary_op(length, Integer(0), ">=", line)
self.add_bool_branch(cond, ok, fail)
self.activate_block(fail)
self.add(
RaiseStandardError(
Expand Down
20 changes: 20 additions & 0 deletions mypyc/lower/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,23 @@ def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Va
@lower_binary_op("int_ne")
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "!=", line)


@lower_binary_op("int_lt")
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<", line)


@lower_binary_op("int_le")
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<=", line)


@lower_binary_op("int_gt")
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">", line)


@lower_binary_op("int_ge")
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">=", line)
4 changes: 4 additions & 0 deletions mypyc/primitives/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def int_binary_primitive(

int_eq = int_binary_primitive(op="==", primitive_name="int_eq", return_type=bit_rprimitive)
int_ne = int_binary_primitive(op="!=", primitive_name="int_ne", return_type=bit_rprimitive)
int_lt = int_binary_primitive(op="<", primitive_name="int_lt", return_type=bit_rprimitive)
int_le = int_binary_primitive(op="<=", primitive_name="int_le", return_type=bit_rprimitive)
int_gt = int_binary_primitive(op=">", primitive_name="int_gt", return_type=bit_rprimitive)
int_ge = int_binary_primitive(op=">=", primitive_name="int_ge", return_type=bit_rprimitive)


def int_binary_op(
Expand Down
Loading
Loading