Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overflow checking for operations with mixed sign #9403

Merged
merged 3 commits into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions spec/std/overflow_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{% skip_file unless compare_versions(Crystal::VERSION, "0.35.0-0") > 0 %}

require "big"
require "spec"

{% for i in Int::Signed.union_types %}
struct {{i}}
TEST_CASES = [MIN, MIN &+ 1, MIN &+ 2, -1, 0, 1, MAX &- 2, MAX &- 1, MAX] of {{i}}
end
{% end %}

{% for i in Int::Unsigned.union_types %}
struct {{i}}
TEST_CASES = [MIN, MIN &+ 1, MIN &+ 2, MAX &- 2, MAX &- 1, MAX] of {{i}}
end
{% end %}

macro run_op_tests(t, u, op)
it "overflow test #{{{t}}} #{{{op}}} #{{{u}}}" do
{{t}}::TEST_CASES.each do |lhs|
{{u}}::TEST_CASES.each do |rhs|
result = lhs.to_big_i {{op.id}} rhs.to_big_i
passes = {{t}}::MIN <= result <= {{t}}::MAX
begin
if passes
(lhs {{op.id}} rhs).should eq(lhs &{{op.id}} rhs)
else
expect_raises(OverflowError) { lhs {{op.id}} rhs }
end
rescue e : Spec::AssertionFailed
raise Spec::AssertionFailed.new("#{e.message}: #{lhs} #{{{op}}} #{rhs}", e.file, e.line)
rescue e : OverflowError
raise OverflowError.new("#{e.message}: #{lhs} #{{{op}}} #{rhs}")
end
end
end
end
end

{% if flag?(:darwin) %}
private OVERFLOW_TEST_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128]
{% else %}
private OVERFLOW_TEST_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64]
{% end %}

describe "overflow" do
{% for t in OVERFLOW_TEST_TYPES %}
{% for u in OVERFLOW_TEST_TYPES %}
run_op_tests {{t}}, {{u}}, :+
run_op_tests {{t}}, {{u}}, :-
run_op_tests {{t}}, {{u}}, :*
{% end %}
{% end %}
end
214 changes: 79 additions & 135 deletions src/compiler/crystal/codegen/primitives.cr
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,15 @@ class Crystal::CodeGenVisitor
else # go on
end

case op
when "+", "-", "*"
return codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
else # go on
end

tmax, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2)

case op
when "+" then codegen_binary_op_add(tmax, t1, t2, p1, p2)
when "-" then codegen_binary_op_sub(tmax, t1, t2, p1, p2)
when "*" then codegen_binary_op_mul(tmax, t1, t2, p1, p2)
when "&+" then codegen_trunc_binary_op_result(t1, t2, builder.add(p1, p2))
when "&-" then codegen_trunc_binary_op_result(t1, t2, builder.sub(p1, p2))
when "&*" then codegen_trunc_binary_op_result(t1, t2, builder.mul(p1, p2))
Expand All @@ -151,6 +154,79 @@ class Crystal::CodeGenVisitor
end
end

def codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
if op == "*"
if t1.unsigned? && t2.signed?
return codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
elsif t1.signed? && t2.unsigned?
return codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
end
end

calc_signed = t1.signed? || t2.signed?
calc_width = {t1, t2}.map { |t| t.bytes * 8 + ((calc_signed && t.unsigned?) ? 1 : 0) }.max
calc_type = llvm_context.int(calc_width)

e1 = t1.signed? ? builder.sext(p1, calc_type) : builder.zext(p1, calc_type)
e2 = t2.signed? ? builder.sext(p2, calc_type) : builder.zext(p2, calc_type)

llvm_op =
case {calc_signed, op}
when {false, "+"} then "uadd"
when {false, "-"} then "usub"
when {false, "*"} then "umul"
when {true, "+"} then "sadd"
when {true, "-"} then "ssub"
when {true, "*"} then "smul"
else raise "BUG: unknown overflow op"
end

llvm_fun = binary_overflow_fun "llvm.#{llvm_op}.with.overflow.i#{calc_width}", calc_type
res_with_overflow = builder.call(llvm_fun, [e1, e2])

result = extract_value res_with_overflow, 0
overflow = extract_value res_with_overflow, 1

if calc_width > t1.bytes * 8
result_trunc = trunc result, llvm_type(t1)
result_trunc_ext = t1.signed? ? builder.sext(result_trunc, calc_type) : builder.zext(result_trunc, calc_type)
overflow = or(overflow, builder.icmp LLVM::IntPredicate::NE, result, result_trunc_ext)
end

codegen_raise_overflow_cond overflow

trunc result, llvm_type(t1)
end

def codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
overflow = and(
codegen_binary_op_ne(t1, t1, p1, int(0, t1)), # self != 0
codegen_binary_op_lt(t2, t2, p2, int(0, t2)) # other < 0
)
codegen_raise_overflow_cond overflow

return codegen_binary_op_with_overflow("*", t1, @program.int_type(false, t2.bytes), p1, p2)
end

def codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
negative = codegen_binary_op_lt(t1, t1, p1, int(0, t1)) # self < 0
minus_p1 = builder.sub int(0, t1), p1
abs = builder.select negative, minus_p1, p1
u1 = @program.int_type(false, t1.bytes)

# tmp is the abs value of the result
# there is overflow when |result| > max + (negative ? 1 : 0)
tmp = codegen_binary_op_with_overflow("*", u1, t2, abs, p2)
_, max = t1.range
max_result = builder.add(int(max, t1), builder.zext(negative, llvm_type(t1)))
overflow = codegen_binary_op_gt(u1, u1, tmp, max_result)
codegen_raise_overflow_cond overflow

# negate back the result if p1 was negative
minus_tmp = builder.sub int(0, t1), tmp
builder.select negative, minus_tmp, tmp
end

def codegen_binary_extend_int(t1, t2, p1, p2)
if t1.normal_rank == t2.normal_rank
# Nothing to do
Expand All @@ -176,138 +252,6 @@ class Crystal::CodeGenVisitor
end
end

def codegen_binary_op_add(t : IntegerType, t1, t2, p1, p2)
llvm_fun = case t.kind
when :i8
binary_overflow_fun "llvm.sadd.with.overflow.i8", llvm_context.int8
when :i16
binary_overflow_fun "llvm.sadd.with.overflow.i16", llvm_context.int16
when :i32
binary_overflow_fun "llvm.sadd.with.overflow.i32", llvm_context.int32
when :i64
binary_overflow_fun "llvm.sadd.with.overflow.i64", llvm_context.int64
when :i128
binary_overflow_fun "llvm.sadd.with.overflow.i128", llvm_context.int128
when :u8
binary_overflow_fun "llvm.uadd.with.overflow.i8", llvm_context.int8
when :u16
binary_overflow_fun "llvm.uadd.with.overflow.i16", llvm_context.int16
when :u32
binary_overflow_fun "llvm.uadd.with.overflow.i32", llvm_context.int32
when :u64
binary_overflow_fun "llvm.uadd.with.overflow.i64", llvm_context.int64
when :u128
binary_overflow_fun "llvm.uadd.with.overflow.i128", llvm_context.int128
else
raise "unreachable"
end

codegen_binary_overflow_check(llvm_fun, t, t1, t2, p1, p2)
end

def codegen_binary_op_sub(t : IntegerType, t1, t2, p1, p2)
llvm_fun = case t.kind
when :i8
binary_overflow_fun "llvm.ssub.with.overflow.i8", llvm_context.int8
when :i16
binary_overflow_fun "llvm.ssub.with.overflow.i16", llvm_context.int16
when :i32
binary_overflow_fun "llvm.ssub.with.overflow.i32", llvm_context.int32
when :i64
binary_overflow_fun "llvm.ssub.with.overflow.i64", llvm_context.int64
when :i128
binary_overflow_fun "llvm.ssub.with.overflow.i128", llvm_context.int128
when :u8
binary_overflow_fun "llvm.usub.with.overflow.i8", llvm_context.int8
when :u16
binary_overflow_fun "llvm.usub.with.overflow.i16", llvm_context.int16
when :u32
binary_overflow_fun "llvm.usub.with.overflow.i32", llvm_context.int32
when :u64
binary_overflow_fun "llvm.usub.with.overflow.i64", llvm_context.int64
when :u128
binary_overflow_fun "llvm.usub.with.overflow.i128", llvm_context.int128
else
raise "unreachable"
end

codegen_binary_overflow_check(llvm_fun, t, t1, t2, p1, p2)
end

def codegen_binary_op_mul(t : IntegerType, t1, t2, p1, p2)
llvm_fun = case t.kind
when :i8
binary_overflow_fun "llvm.smul.with.overflow.i8", llvm_context.int8
when :i16
binary_overflow_fun "llvm.smul.with.overflow.i16", llvm_context.int16
when :i32
binary_overflow_fun "llvm.smul.with.overflow.i32", llvm_context.int32
when :i64
binary_overflow_fun "llvm.smul.with.overflow.i64", llvm_context.int64
when :i128
binary_overflow_fun "llvm.smul.with.overflow.i128", llvm_context.int128
when :u8
binary_overflow_fun "llvm.umul.with.overflow.i8", llvm_context.int8
when :u16
binary_overflow_fun "llvm.umul.with.overflow.i16", llvm_context.int16
when :u32
binary_overflow_fun "llvm.umul.with.overflow.i32", llvm_context.int32
when :u64
binary_overflow_fun "llvm.umul.with.overflow.i64", llvm_context.int64
when :u128
binary_overflow_fun "llvm.umul.with.overflow.i128", llvm_context.int128
else
raise "unreachable"
end

codegen_binary_overflow_check(llvm_fun, t, t1, t2, p1, p2)
end

# Generates a call to llvm_fun(p1, p2).
# t1, t2 are the original types of p1, p2.
# t is the super type of t1 and t2 where the operation is performed.
# llvm_fun returns {res, o_bit} where the o_bit signals overflow.
# The generated code also performs a range check and truncation of res
# in order to fit in the original type t1 if needed.
#
# ```
# %res_with_overflow = call {T, i1} <llvm_fun>(T %p1, T %p2)
# %res = extractvalue {T, i1} %res, 0
# %o_bit = extractvalue {T, i1} %res, 1
# ;; if T != T1
# %out_of_range = %res < T1::MIN || %res > T1::MAX ;; compare T1.range and %res
# br i1 or(%o_bit, %out_of_range), label %overflow, label %normal
# ;; else
# br i1 %o_bit, label %overflow, label %normal
# ;; end
#
# overflow:
# ;; codegen: raise OverflowError.new with caller's location
#
# normal:
# ;; if T != T1
# ;; %res' is returned
# %res' = trunc T %res to T1
# ;; else
# ;; %res is returned
# ;; end
# ```
private def codegen_binary_overflow_check(llvm_fun, t : IntegerType, t1, t2, p1, p2)
res_with_overflow = builder.call(llvm_fun, [p1, p2])

res = extract_value res_with_overflow, 0
o_bit = extract_value res_with_overflow, 1

if t != t1
overflow = or(o_bit, codegen_out_of_range(t1, t, res))
else
overflow = o_bit
end

codegen_raise_overflow_cond overflow
codegen_trunc_binary_op_result(t1, t2, res)
end

private def codegen_out_of_range(target_type : IntegerType, arg_type : IntegerType, arg)
min_value, max_value = target_type.range
# arg < min_value || arg > max_value
Expand Down
24 changes: 24 additions & 0 deletions src/compiler/crystal/program.cr
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,30 @@ module Crystal
end
end

def int_type(signed, size)
if signed
case size
when 1 then int8
when 2 then int16
when 4 then int32
when 8 then int64
when 16 then int128
else
raise "BUG: Invalid int size: #{size}"
end
else
case size
when 1 then uint8
when 2 then uint16
when 4 then uint32
when 8 then uint64
when 16 then uint128
else
raise "BUG: Invalid int size: #{size}"
end
end
end

# Returns the `IntegerType` that matches the given Int value
def int?(int)
case int
Expand Down