Skip to content

Simplify codegen for niche-encoded enums in simple cases #102901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 114 additions & 13 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use rustc_middle::mir;
use rustc_middle::mir::tcx::PlaceTy;
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Ty};
use rustc_target::abi::{Abi, Align, FieldsShape, Int, TagEncoding};
use rustc_target::abi::{VariantIdx, Variants};
use rustc_target::abi::{Abi, Align, FieldsShape, Int, Integer, TagEncoding};
use rustc_target::abi::{Primitive, Scalar, VariantIdx, Variants};

#[derive(Copy, Clone, Debug)]
pub struct PlaceRef<'tcx, V> {
Expand Down Expand Up @@ -231,6 +231,9 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
let tag = self.project_field(bx, tag_field);
let tag = bx.load_operand(tag);

let niche_llty = bx.cx().immediate_backend_type(tag.layout);
let tag = tag.immediate();

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
Expand All @@ -242,13 +245,100 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag.immediate(), cast_to, signed)
bx.intcast(tag, cast_to, signed)
}
// Handle single-variant-niche separately.
// Note: this also covers pointer-based niches.
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start }
if niche_variants.end().as_u32() - niche_variants.start().as_u32() == 0 =>
{
let niche_start = if niche_start == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
bx.cx().const_null(niche_llty)
} else {
bx.cx().const_uint_big(niche_llty, niche_start)
};
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
bx.select(
is_niche,
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
)
}
// Handle other non-pathological cases (= without niche wraparound, and with
// niche_variants.end fitting in the tag type).
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start }
if is_niche_after_nonniche(tag_scalar, niche_start, niche_variants, bx.cx()) =>
{
if untagged_variant >= *niche_variants.start() {
// This assumption helps LLVM to compile `get_discr() == untagged_variant`
// to a single cmp.
if untagged_variant <= *niche_variants.end() {
let unused_tag = bx.cx().const_uint_big(
niche_llty,
niche_start + untagged_variant.as_u32() as u128,
);
let assumption = bx.icmp(IntPredicate::IntNE, tag, unused_tag);
bx.assume(assumption);
} else {
// This assumption is in theory implied by tag range, but doing it
// explicitly helps.
let last_tag = bx.cx().const_uint_big(
niche_llty,
niche_start + niche_variants.end().as_u32() as u128,
);
let assumption = bx.icmp(IntPredicate::IntULE, tag, last_tag);
bx.assume(assumption);
}
}

let (niche_llty, tag) = match tag_scalar.primitive() {
// Operations on u8/u16 directly result in some additional movzxs
// (https://github.com/llvm/llvm-project/issues/58338), pretend the tag is
// cast_to type (usize) instead.
// FIXME: this is a very x86-specific assumption
Int(Integer::I8 | Integer::I16, _) => {
(cast_to, bx.intcast(tag, cast_to, false))
}
_ => (niche_llty, tag),
};
// Thanks to the cast above, we can guarantee that all variant indexes fit in
// niche_llty (as they're u32s).

let is_untagged = bx.icmp(
IntPredicate::IntULT,
tag,
bx.cx().const_uint_big(niche_llty, niche_start),
);

let first_niche_variant = niche_variants.start().as_u32() as u128;
let niche_discr = match first_niche_variant.cmp(&niche_start) {
std::cmp::Ordering::Less => {
let diff = niche_start - first_niche_variant;
let diff = bx.cx().const_uint_big(niche_llty, diff);
bx.unchecked_usub(tag, diff)
}
std::cmp::Ordering::Equal => tag,
std::cmp::Ordering::Greater => {
let diff = first_niche_variant as u64 - niche_start as u64;
let diff = bx.cx().const_uint(niche_llty, diff);
bx.unchecked_uadd(tag, diff)
}
};
let untagged_discr =
bx.cx().const_uint(niche_llty, untagged_variant.as_u32() as u64);
let discr = bx.select(is_untagged, untagged_discr, niche_discr);

bx.intcast(discr, cast_to, false)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
if matches!(tag_scalar.primitive(), Primitive::Pointer) {
// arith and non-null constants won't work on pointers
bug!("pointer with more than one niche variant")
}

// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.
let niche_llty = bx.cx().immediate_backend_type(tag.layout);
let tag = tag.immediate();

// We first compute the "relative discriminant" (wrt `niche_variants`),
// that is, if `n = niche_variants.end() - niche_variants.start()`,
Expand All @@ -260,19 +350,13 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
// that might not fit in the same type, on top of needing an extra
// comparison (see also the comment on `let niche_discr`).
let relative_discr = if niche_start == 0 {
// Avoid subtracting `0`, which wouldn't work for pointers.
// FIXME(eddyb) check the actual primitive type here.
tag
} else {
// Note: subtracting `0` wouldn't work for pointers.
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
};
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
} else {
let is_niche = {
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
};
Expand Down Expand Up @@ -536,3 +620,20 @@ fn round_up_const_value_to_alignment<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
let offset = bx.and(neg_value, align_minus_1);
bx.add(value, offset)
}

/// Returns `true` when all niche tag values sit above the untagged variant's
/// tag values, flush against the top of the tag's valid range. In that case
/// the "relative discriminant" computation can be skipped entirely: a simple
/// `tag < niche_start` comparison decides whether the value is untagged.
fn is_niche_after_nonniche(
    tag: Scalar,
    niche_start: u128,
    niche_variants: &std::ops::RangeInclusive<VariantIdx>,
    cx: &impl rustc_target::abi::HasDataLayout,
) -> bool {
    let valid_range = tag.valid_range(cx);
    // A wrapping valid range has no single "top end" for the niches to abut.
    if valid_range.wraps() {
        return false;
    }
    // Number of niche variants beyond the first one.
    let extra_variants = niche_variants.end().as_u32() - niche_variants.start().as_u32();
    // The last niche value must coincide exactly with the end of the valid
    // range (checked_add also rejects overflow of the u128 arithmetic).
    niche_start.checked_add(u128::from(extra_variants)) == Some(valid_range.end)
}
5 changes: 5 additions & 0 deletions compiler/rustc_target/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,11 @@ impl WrappingRange {
debug_assert!(self.start <= max_value && self.end <= max_value);
self.start == (self.end.wrapping_add(1) & max_value)
}

/// Returns `true` if this range wraps around the end of the type's value
/// space, i.e. `start > end` (so the valid values are not a single
/// contiguous `start..=end` interval).
#[inline]
pub fn wraps(&self) -> bool {
    // Use a tail expression rather than an explicit `return` (idiomatic Rust,
    // and what clippy's `needless_return` lint expects).
    self.start > self.end
}
}

impl fmt::Debug for WrappingRange {
Expand Down