Skip to content

Commit

Permalink
Auto merge of #124113 - RalfJung:interpret-scalar-ops, r=oli-obk
Browse files Browse the repository at this point in the history
interpret: use ScalarInt for bin-ops; avoid PartialOrd for ScalarInt

Best reviewed commit-by-commit

r? `@oli-obk`
  • Loading branch information
bors committed Apr 19, 2024
2 parents d1a0fa5 + d3f927d commit ce3263e
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 152 deletions.
41 changes: 19 additions & 22 deletions compiler/rustc_codegen_cranelift/src/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub(crate) fn codegen_const_value<'tcx>(
if fx.clif_type(layout.ty).is_some() {
return CValue::const_val(fx, layout, int);
} else {
let raw_val = int.size().truncate(int.to_bits(int.size()).unwrap());
let raw_val = int.size().truncate(int.assert_bits(int.size()));
let val = match int.size().bytes() {
1 => fx.bcx.ins().iconst(types::I8, raw_val as i64),
2 => fx.bcx.ins().iconst(types::I16, raw_val as i64),
Expand Down Expand Up @@ -491,27 +491,24 @@ pub(crate) fn mir_operand_get_const_val<'tcx>(
return None;
}
let scalar_int = mir_operand_get_const_val(fx, operand)?;
let scalar_int = match fx
.layout_of(*ty)
.size
.cmp(&scalar_int.size())
{
Ordering::Equal => scalar_int,
Ordering::Less => match ty.kind() {
ty::Uint(_) => ScalarInt::try_from_uint(
scalar_int.try_to_uint(scalar_int.size()).unwrap(),
fx.layout_of(*ty).size,
)
.unwrap(),
ty::Int(_) => ScalarInt::try_from_int(
scalar_int.try_to_int(scalar_int.size()).unwrap(),
fx.layout_of(*ty).size,
)
.unwrap(),
_ => unreachable!(),
},
Ordering::Greater => return None,
};
let scalar_int =
match fx.layout_of(*ty).size.cmp(&scalar_int.size()) {
Ordering::Equal => scalar_int,
Ordering::Less => match ty.kind() {
ty::Uint(_) => ScalarInt::try_from_uint(
scalar_int.assert_uint(scalar_int.size()),
fx.layout_of(*ty).size,
)
.unwrap(),
ty::Int(_) => ScalarInt::try_from_int(
scalar_int.assert_int(scalar_int.size()),
fx.layout_of(*ty).size,
)
.unwrap(),
_ => unreachable!(),
},
Ordering::Greater => return None,
};
computed_scalar_int = Some(scalar_int);
}
Rvalue::Use(operand) => {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_cranelift/src/value_and_place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ impl<'tcx> CValue<'tcx> {

let val = match layout.ty.kind() {
ty::Uint(UintTy::U128) | ty::Int(IntTy::I128) => {
let const_val = const_val.to_bits(layout.size).unwrap();
let const_val = const_val.assert_bits(layout.size);
let lsb = fx.bcx.ins().iconst(types::I64, const_val as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (const_val >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
Expand All @@ -338,7 +338,7 @@ impl<'tcx> CValue<'tcx> {
| ty::Ref(..)
| ty::RawPtr(..)
| ty::FnPtr(..) => {
let raw_val = const_val.size().truncate(const_val.to_bits(layout.size).unwrap());
let raw_val = const_val.size().truncate(const_val.assert_bits(layout.size));
fx.bcx.ins().iconst(clif_ty, raw_val as i64)
}
ty::Float(FloatTy::F32) => {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_const_eval/src/interpret/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
&niche_start_val,
)?
.to_scalar()
.try_to_int()
.unwrap();
.assert_int();
Ok(Some((tag, tag_field)))
}
}
Expand Down
24 changes: 22 additions & 2 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use std::assert_matches::assert_matches;
use either::{Either, Left, Right};

use rustc_hir::def::Namespace;
use rustc_middle::mir::interpret::ScalarSizeMismatch;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::print::{FmtPrinter, PrettyPrinter};
use rustc_middle::ty::{ConstInt, Ty, TyCtxt};
use rustc_middle::ty::{ConstInt, ScalarInt, Ty, TyCtxt};
use rustc_middle::{mir, ty};
use rustc_target::abi::{self, Abi, HasDataLayout, Size};

Expand Down Expand Up @@ -210,6 +211,12 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
ImmTy { imm: Immediate::Uninit, layout }
}

#[inline]
pub fn from_scalar_int(s: ScalarInt, layout: TyAndLayout<'tcx>) -> Self {
assert_eq!(s.size(), layout.size);
Self::from_scalar(Scalar::from(s), layout)
}

#[inline]
pub fn try_from_uint(i: impl Into<u128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
Some(Self::from_scalar(Scalar::try_from_uint(i, layout.size)?, layout))
Expand All @@ -223,7 +230,6 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
pub fn try_from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
Some(Self::from_scalar(Scalar::try_from_int(i, layout.size)?, layout))
}

#[inline]
pub fn from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Self {
Self::from_scalar(Scalar::from_int(i, layout.size), layout)
Expand All @@ -242,6 +248,20 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
Self::from_scalar(Scalar::from_i8(c as i8), layout)
}

/// Return the immediate as a `ScalarInt`. Ensures that it has the size that the layout of the
/// immediate indicates.
#[inline]
pub fn to_scalar_int(&self) -> InterpResult<'tcx, ScalarInt> {
let s = self.to_scalar().to_scalar_int()?;
if s.size() != self.layout.size {
throw_ub!(ScalarSizeMismatch(ScalarSizeMismatch {
target_size: self.layout.size.bytes(),
data_size: s.size().bytes(),
}));
}
Ok(s)
}

#[inline]
pub fn to_const_int(self) -> ConstInt {
assert!(self.layout.ty.is_integral());
Expand Down
113 changes: 59 additions & 54 deletions compiler/rustc_const_eval/src/interpret/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_apfloat::{Float, FloatConvert};
use rustc_middle::mir;
use rustc_middle::mir::interpret::{InterpResult, Scalar};
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, FloatTy, Ty};
use rustc_middle::ty::{self, FloatTy, ScalarInt, Ty};
use rustc_span::symbol::sym;
use rustc_target::abi::Abi;

Expand Down Expand Up @@ -146,14 +146,20 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
fn binary_int_op(
&self,
bin_op: mir::BinOp,
// passing in raw bits
l: u128,
left_layout: TyAndLayout<'tcx>,
r: u128,
right_layout: TyAndLayout<'tcx>,
left: &ImmTy<'tcx, M::Provenance>,
right: &ImmTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, (ImmTy<'tcx, M::Provenance>, bool)> {
use rustc_middle::mir::BinOp::*;

// This checks the size, so that we can just assert it below.
let l = left.to_scalar_int()?;
let r = right.to_scalar_int()?;
// Prepare to convert the values to signed or unsigned form.
let l_signed = || l.assert_int(left.layout.size);
let l_unsigned = || l.assert_uint(left.layout.size);
let r_signed = || r.assert_int(right.layout.size);
let r_unsigned = || r.assert_uint(right.layout.size);

let throw_ub_on_overflow = match bin_op {
AddUnchecked => Some(sym::unchecked_add),
SubUnchecked => Some(sym::unchecked_sub),
Expand All @@ -165,69 +171,72 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

// Shift ops can have an RHS with a different numeric type.
if matches!(bin_op, Shl | ShlUnchecked | Shr | ShrUnchecked) {
let size = left_layout.size.bits();
let size = left.layout.size.bits();
// The shift offset is implicitly masked to the type size. (This is the one MIR operator
// that does *not* directly map to a single LLVM operation.) Compute how much we
// actually shift and whether there was an overflow due to shifting too much.
let (shift_amount, overflow) = if right_layout.abi.is_signed() {
let shift_amount = self.sign_extend(r, right_layout) as i128;
let (shift_amount, overflow) = if right.layout.abi.is_signed() {
let shift_amount = r_signed();
let overflow = shift_amount < 0 || shift_amount >= i128::from(size);
// Deliberately wrapping `as` casts: shift_amount *can* be negative, but the result
// of the `as` will be equal modulo `size` (since it is a power of two).
let masked_amount = (shift_amount as u128) % u128::from(size);
debug_assert_eq!(overflow, shift_amount != (masked_amount as i128));
assert_eq!(overflow, shift_amount != (masked_amount as i128));
(masked_amount, overflow)
} else {
let shift_amount = r;
let shift_amount = r_unsigned();
let masked_amount = shift_amount % u128::from(size);
(masked_amount, shift_amount != masked_amount)
};
let shift_amount = u32::try_from(shift_amount).unwrap(); // we masked so this will always fit
// Compute the shifted result.
let result = if left_layout.abi.is_signed() {
let l = self.sign_extend(l, left_layout) as i128;
let result = if left.layout.abi.is_signed() {
let l = l_signed();
let result = match bin_op {
Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
_ => bug!(),
};
result as u128
ScalarInt::truncate_from_int(result, left.layout.size).0
} else {
match bin_op {
let l = l_unsigned();
let result = match bin_op {
Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
_ => bug!(),
}
};
ScalarInt::truncate_from_uint(result, left.layout.size).0
};
let truncated = self.truncate(result, left_layout);

if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
throw_ub_custom!(
fluent::const_eval_overflow_shift,
val = if right_layout.abi.is_signed() {
(self.sign_extend(r, right_layout) as i128).to_string()
val = if right.layout.abi.is_signed() {
r_signed().to_string()
} else {
r.to_string()
r_unsigned().to_string()
},
name = intrinsic_name
);
}

return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
}

// For the remaining ops, the types must be the same on both sides
if left_layout.ty != right_layout.ty {
if left.layout.ty != right.layout.ty {
span_bug!(
self.cur_span(),
"invalid asymmetric binary op {bin_op:?}: {l:?} ({l_ty}), {r:?} ({r_ty})",
l_ty = left_layout.ty,
r_ty = right_layout.ty,
l_ty = left.layout.ty,
r_ty = right.layout.ty,
)
}

let size = left_layout.size;
let size = left.layout.size;

// Operations that need special treatment for signed integers
if left_layout.abi.is_signed() {
if left.layout.abi.is_signed() {
let op: Option<fn(&i128, &i128) -> bool> = match bin_op {
Lt => Some(i128::lt),
Le => Some(i128::le),
Expand All @@ -236,18 +245,14 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
_ => None,
};
if let Some(op) = op {
let l = self.sign_extend(l, left_layout) as i128;
let r = self.sign_extend(r, right_layout) as i128;
return Ok((ImmTy::from_bool(op(&l, &r), *self.tcx), false));
return Ok((ImmTy::from_bool(op(&l_signed(), &r_signed()), *self.tcx), false));
}
if bin_op == Cmp {
let l = self.sign_extend(l, left_layout) as i128;
let r = self.sign_extend(r, right_layout) as i128;
return Ok(self.three_way_compare(l, r));
return Ok(self.three_way_compare(l_signed(), r_signed()));
}
let op: Option<fn(i128, i128) -> (i128, bool)> = match bin_op {
Div if r == 0 => throw_ub!(DivisionByZero),
Rem if r == 0 => throw_ub!(RemainderByZero),
Div if r.is_null() => throw_ub!(DivisionByZero),
Rem if r.is_null() => throw_ub!(RemainderByZero),
Div => Some(i128::overflowing_div),
Rem => Some(i128::overflowing_rem),
Add | AddUnchecked => Some(i128::overflowing_add),
Expand All @@ -256,8 +261,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
_ => None,
};
if let Some(op) = op {
let l = self.sign_extend(l, left_layout) as i128;
let r = self.sign_extend(r, right_layout) as i128;
let l = l_signed();
let r = r_signed();

// We need a special check for overflowing Rem and Div since they are *UB*
// on overflow, which can happen with "int_min $OP -1".
Expand All @@ -272,17 +277,19 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}

let (result, oflo) = op(l, r);
// This may be out-of-bounds for the result type, so we have to truncate ourselves.
// This may be out-of-bounds for the result type, so we have to truncate.
// If that truncation loses any information, we have an overflow.
let result = result as u128;
let truncated = self.truncate(result, left_layout);
let overflow = oflo || self.sign_extend(truncated, left_layout) != result;
let (result, lossy) = ScalarInt::truncate_from_int(result, left.layout.size);
let overflow = oflo || lossy;
if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
}
return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
}
}
// From here on it's okay to treat everything as unsigned.
let l = l_unsigned();
let r = r_unsigned();

if bin_op == Cmp {
return Ok(self.three_way_compare(l, r));
Expand All @@ -297,12 +304,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Gt => ImmTy::from_bool(l > r, *self.tcx),
Ge => ImmTy::from_bool(l >= r, *self.tcx),

BitOr => ImmTy::from_uint(l | r, left_layout),
BitAnd => ImmTy::from_uint(l & r, left_layout),
BitXor => ImmTy::from_uint(l ^ r, left_layout),
BitOr => ImmTy::from_uint(l | r, left.layout),
BitAnd => ImmTy::from_uint(l & r, left.layout),
BitXor => ImmTy::from_uint(l ^ r, left.layout),

Add | AddUnchecked | Sub | SubUnchecked | Mul | MulUnchecked | Rem | Div => {
assert!(!left_layout.abi.is_signed());
assert!(!left.layout.abi.is_signed());
let op: fn(u128, u128) -> (u128, bool) = match bin_op {
Add | AddUnchecked => u128::overflowing_add,
Sub | SubUnchecked => u128::overflowing_sub,
Expand All @@ -316,21 +323,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let (result, oflo) = op(l, r);
// Truncate to target type.
// If that truncation loses any information, we have an overflow.
let truncated = self.truncate(result, left_layout);
let overflow = oflo || truncated != result;
let (result, lossy) = ScalarInt::truncate_from_uint(result, left.layout.size);
let overflow = oflo || lossy;
if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
}
return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
}

_ => span_bug!(
self.cur_span(),
"invalid binary op {:?}: {:?}, {:?} (both {})",
bin_op,
l,
r,
right_layout.ty,
left,
right,
right.layout.ty,
),
};

Expand Down Expand Up @@ -427,9 +434,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
right.layout.ty
);

let l = left.to_scalar().to_bits(left.layout.size)?;
let r = right.to_scalar().to_bits(right.layout.size)?;
self.binary_int_op(bin_op, l, left.layout, r, right.layout)
self.binary_int_op(bin_op, left, right)
}
_ if left.layout.ty.is_any_ptr() => {
// The RHS type must be a `pointer` *or an integer type* (for `Offset`).
Expand Down
Loading

0 comments on commit ce3263e

Please sign in to comment.