Skip to content

Commit

Permalink
[mypyc] Implement lowering for remaining tagged integer comparisons (#…
Browse files Browse the repository at this point in the history
…17040)

Support lowering of tagged integer `<`, `<=`, `>` and `>=` operations.

Previously we had separate code paths for integer comparisons in values
vs conditions. Unify these and remove the duplicate code path. The
different code paths produced subtly different code, but now they are
identical.

The generated code is now sometimes slightly more verbose in the slow
path (big integer). I may look into simplifying it in a follow-up PR.

This also makes the output of many irbuild test cases significantly more
compact.

Follow-up to #17027. Work on mypyc/mypyc#854.
  • Loading branch information
JukkaL authored Mar 19, 2024
1 parent 7d0a8e7 commit afdd9d5
Show file tree
Hide file tree
Showing 21 changed files with 622 additions and 968 deletions.
2 changes: 1 addition & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def visit_primitive_op(self, op: PrimitiveOp) -> str:
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s ", op, op.desc.name, args_str)
return self.format("%r = %s %s", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
Expand Down
9 changes: 3 additions & 6 deletions mypyc/irbuild/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,9 @@ def maybe_process_conditional_comparison(
self.add_bool_branch(reg, true, false)
else:
# "left op right" for two tagged integers
if op in ("==", "!="):
reg = self.builder.binary_op(left, right, op, e.line)
self.flush_keep_alives()
self.add_bool_branch(reg, true, false)
else:
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
reg = self.builder.binary_op(left, right, op, e.line)
self.flush_keep_alives()
self.add_bool_branch(reg, true, false)
return True


Expand Down
3 changes: 0 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Va
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
return self.builder.int_op(type, lhs, rhs, op, line)

def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
return self.builder.compare_tagged(lhs, rhs, op, line)

def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
return self.builder.compare_tuples(lhs, rhs, op, line)

Expand Down
6 changes: 0 additions & 6 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,6 @@ def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Va
def transform_basic_comparison(
builder: IRBuilder, op: str, left: Value, right: Value, line: int
) -> Value:
if (
is_int_rprimitive(left.type)
and is_int_rprimitive(right.type)
and op in int_comparison_op_mapping
):
return builder.compare_tagged(left, right, op, line)
if is_fixed_width_rtype(left.type) and op in int_comparison_op_mapping:
if right.type == left.type:
if left.type.is_signed:
Expand Down
5 changes: 2 additions & 3 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
call_impl, next_impl = BasicBlock(), BasicBlock()

current_id = builder.load_int(i)
builder.builder.compare_tagged_condition(
passed_id, current_id, "==", call_impl, next_impl, line
)
cond = builder.binary_op(passed_id, current_id, "==", line)
builder.add_bool_branch(cond, call_impl, next_impl)

# Call the registered implementation
builder.activate_block(call_impl)
Expand Down
94 changes: 13 additions & 81 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,13 +1315,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.compare_strings(lreg, rreg, op, line)
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
return self.compare_bytes(lreg, rreg, op, line)
if (
is_tagged(ltype)
and is_tagged(rtype)
and op in int_comparison_op_mapping
and op not in ("==", "!=")
):
return self.compare_tagged(lreg, rreg, op, line)
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
if op in ComparisonOp.signed_ops:
return self.bool_comparison_op(lreg, rreg, op, line)
Expand Down Expand Up @@ -1384,16 +1377,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
if is_fixed_width_rtype(lreg.type):
return self.comparison_op(lreg, rreg, op_id, line)

# Mixed int comparisons
if op in ("==", "!="):
pass # TODO: Do we need anything here?
elif op in op in int_comparison_op_mapping:
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_float_rprimitive(ltype) or is_float_rprimitive(rtype):
if isinstance(lreg, Integer):
lreg = Float(float(lreg.numeric_value()))
Expand Down Expand Up @@ -1445,18 +1428,16 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
result = Register(bool_rprimitive)
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line)
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if op in ("==", "!="):
check = check_lhs
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# for non-equality logical ops (less/greater than, etc.), need to check both sides
check_rhs = self.check_tagged_short_int(rhs, line)
check = self.int_op(bit_rprimitive, check_lhs, check_rhs, IntOp.AND, line)
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
self.goto(out)
short_lhs = BasicBlock()
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
self.activate_block(short_lhs)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
Expand All @@ -1469,62 +1450,12 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
else:
call_result = call
self.add(Assign(result, call_result, line))
self.goto_and_activate(out)
return result

def compare_tagged_condition(
self, lhs: Value, rhs: Value, op: str, true: BasicBlock, false: BasicBlock, line: int
) -> None:
"""Compare two tagged integers using given operator (conditional context).
Assume lhs and rhs are tagged integers.
Args:
lhs: Left operand
rhs: Right operand
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
is_eq = op in ("==", "!=")
if (is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) or (
is_eq and (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type))
):
# We can skip the tag check
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
self.flush_keep_alives()
self.add(Branch(check, true, false, Branch.BOOL))
return
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
int_block, short_int_block = BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if is_eq or is_short_int_rprimitive(rhs.type):
self.flush_keep_alives()
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# For non-equality logical ops (less/greater than, etc.), need to check both sides
rhs_block = BasicBlock()
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
self.activate_block(rhs_block)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.flush_keep_alives()
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
# Arbitrary integers (slow path)
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
self.add(Branch(call, false, true, Branch.BOOL))
else:
self.flush_keep_alives()
self.add(Branch(call, true, false, Branch.BOOL))
# Short integers (fast path)
self.goto(out)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Branch(eq, true, false, Branch.BOOL))
self.add(Assign(result, eq, line))
self.goto_and_activate(out)
return result

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
Expand Down Expand Up @@ -2309,7 +2240,8 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
length = self.gen_method_call(val, "__len__", [], int_rprimitive, line)
length = self.coerce(length, int_rprimitive, line)
ok, fail = BasicBlock(), BasicBlock()
self.compare_tagged_condition(length, Integer(0), ">=", ok, fail, line)
cond = self.binary_op(length, Integer(0), ">=", line)
self.add_bool_branch(cond, ok, fail)
self.activate_block(fail)
self.add(
RaiseStandardError(
Expand Down
20 changes: 20 additions & 0 deletions mypyc/lower/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,23 @@ def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Va
@lower_binary_op("int_ne")
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "!=", line)


@lower_binary_op("int_lt")
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<", line)


@lower_binary_op("int_le")
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<=", line)


@lower_binary_op("int_gt")
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">", line)


@lower_binary_op("int_ge")
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">=", line)
4 changes: 4 additions & 0 deletions mypyc/primitives/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def int_binary_primitive(

int_eq = int_binary_primitive(op="==", primitive_name="int_eq", return_type=bit_rprimitive)
int_ne = int_binary_primitive(op="!=", primitive_name="int_ne", return_type=bit_rprimitive)
int_lt = int_binary_primitive(op="<", primitive_name="int_lt", return_type=bit_rprimitive)
int_le = int_binary_primitive(op="<=", primitive_name="int_le", return_type=bit_rprimitive)
int_gt = int_binary_primitive(op=">", primitive_name="int_gt", return_type=bit_rprimitive)
int_ge = int_binary_primitive(op=">=", primitive_name="int_ge", return_type=bit_rprimitive)


def int_binary_op(
Expand Down
Loading

0 comments on commit afdd9d5

Please sign in to comment.