
Optimize test_ptr_try_cast_into_soundness #1907

Open

Description

The following tests consume the majority of the time a PR spends in the merge queue (currently ~25 min/PR). In particular, the cost is dominated by running them under Miri.
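For reference, this module's cases can be run in isolation by filtering on its name, e.g. `cargo +nightly miri test test_ptr_try_cast_into_soundness` (assuming a nightly toolchain with the Miri component installed, invoked from the zerocopy crate directory).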

zerocopy/src/pointer/ptr.rs, lines 1255 to 1426 at b43c510:

mod test_ptr_try_cast_into_soundness {
    use super::*;

    // This test is designed so that if `Ptr::try_cast_into_xxx` are
    // buggy, it will manifest as unsoundness that Miri can detect.
    //
    // The buffer size `N` is chosen by the `test!` macro below:
    // - If `size_of::<T>() == 0`, `N == 4`
    // - Else, `N == 4 * size_of::<T>()`
    //
    // Each test will be run for each metadata in `metas`.
    fn test<T, I, const N: usize>(metas: I)
    where
        T: ?Sized + KnownLayout + Immutable + FromBytes,
        I: IntoIterator<Item = Option<T::PointerMetadata>> + Clone,
    {
        let mut bytes = [MaybeUninit::<u8>::uninit(); N];
        let initialized = [MaybeUninit::new(0u8); N];
        for start in 0..=bytes.len() {
            for end in start..=bytes.len() {
                // Set all bytes to uninitialized other than those in
                // the range we're going to pass to `try_cast_into`.
                // This allows Miri to detect out-of-bounds reads
                // because they read uninitialized memory. Without this,
                // some out-of-bounds reads would still be in-bounds of
                // `bytes`, and so might spuriously be accepted.
                bytes = [MaybeUninit::<u8>::uninit(); N];
                let bytes = &mut bytes[start..end];
                // Initialize only the byte range we're going to pass to
                // `try_cast_into`.
                bytes.copy_from_slice(&initialized[start..end]);

                let bytes = {
                    let bytes: *const [MaybeUninit<u8>] = bytes;
                    #[allow(clippy::as_conversions)]
                    let bytes = bytes as *const [u8];
                    // SAFETY: We just initialized these bytes to valid
                    // `u8`s.
                    unsafe { &*bytes }
                };
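                // Reads the `T` that `slf` points to, asserts that all
                // of its bytes are initialized and zero, and returns
                // its size in bytes.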
                // SAFETY: The bytes in `slf` must be initialized.
                unsafe fn validate_and_get_len<T: ?Sized + KnownLayout + FromBytes>(
                    slf: Ptr<'_, T, (Shared, Aligned, Initialized)>,
                ) -> usize {
                    let t = slf.bikeshed_recall_valid().as_ref();

                    let bytes = {
                        let len = mem::size_of_val(t);
                        let t: *const T = t;
                        // SAFETY:
                        // - We know `t`'s bytes are all initialized
                        //   because we just read it from `slf`, which
                        //   points to an initialized range of bytes. If
                        //   there's a bug and this doesn't hold, then
                        //   that's exactly what we're hoping Miri will
                        //   catch!
                        // - Since `T: FromBytes`, `T` doesn't contain
                        //   any `UnsafeCell`s, so it's okay for `t: T`
                        //   and a `&[u8]` to the same memory to be
                        //   alive concurrently.
                        unsafe { core::slice::from_raw_parts(t.cast::<u8>(), len) }
                    };

                    // This assertion ensures that `t`'s bytes are read
                    // and compared to another value, which in turn
                    // ensures that Miri gets a chance to notice if any
                    // of `t`'s bytes are uninitialized, which they
                    // shouldn't be (see the comment above).
                    assert_eq!(bytes, vec![0u8; bytes.len()]);

                    mem::size_of_val(t)
                }

                for meta in metas.clone().into_iter() {
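                    // Exercise both cast types: `Prefix` places `T` at
                    // the start of the byte range, `Suffix` at the end.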
                    for cast_type in [CastType::Prefix, CastType::Suffix] {
                        if let Ok((slf, remaining)) = Ptr::from_ref(bytes)
                            .try_cast_into::<T, BecauseImmutable>(cast_type, meta)
                        {
                            // SAFETY: All bytes in `bytes` have been
                            // initialized.
                            let len = unsafe { validate_and_get_len(slf) };

                            assert_eq!(remaining.len(), bytes.len() - len);

                            #[allow(unstable_name_collisions)]
                            let bytes_addr = bytes.as_ptr().addr();
                            #[allow(unstable_name_collisions)]
                            let remaining_addr =
                                remaining.as_inner().as_non_null().as_ptr().addr();
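                            // A prefix cast consumes the first `len`
                            // bytes, so `remaining` starts `len` bytes
                            // in; a suffix cast consumes the last `len`
                            // bytes, so `remaining` starts at the
                            // beginning of `bytes`.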
                            match cast_type {
                                CastType::Prefix => {
                                    assert_eq!(remaining_addr, bytes_addr + len)
                                }
                                CastType::Suffix => assert_eq!(remaining_addr, bytes_addr),
                            }

                            if let Some(want) = meta {
                                let got = KnownLayout::pointer_to_metadata(
                                    slf.as_inner().as_non_null().as_ptr(),
                                );
                                assert_eq!(got, want);
                            }
                        }
                    }
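                    // Also exercise the no-leftover variant, which
                    // succeeds only if the cast consumes the entire
                    // byte range.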
                    if let Ok(slf) = Ptr::from_ref(bytes)
                        .try_cast_into_no_leftover::<T, BecauseImmutable>(meta)
                    {
                        // SAFETY: All bytes in `bytes` have been
                        // initialized.
                        let len = unsafe { validate_and_get_len(slf) };
                        assert_eq!(len, bytes.len());

                        if let Some(want) = meta {
                            let got = KnownLayout::pointer_to_metadata(
                                slf.as_inner().as_non_null().as_ptr(),
                            );
                            assert_eq!(got, want);
                        }
                    }
                }
            }
        }
    }

    #[derive(FromBytes, KnownLayout, Immutable)]
    #[repr(C)]
    struct SliceDst<T> {
        a: u8,
        trailing: [T],
    }

    // Each test case becomes its own `#[test]` function. We do this because
    // this test in particular takes far, far longer to execute under Miri
    // than all of our other tests combined. Previously, we had these
    // execute sequentially in a single test function. We run Miri tests in
    // parallel in CI, but this test being sequential meant that most of
    // that parallelism was wasted, as all other tests would finish in a
    // fraction of the total execution time, leaving this test to execute on
    // a single thread for the remainder of the test run. By putting each
    // test case in its own function, we permit better use of available
    // parallelism.
    macro_rules! test {
        ($test_name:ident: $ty:ty) => {
            #[test]
            #[allow(non_snake_case)]
            fn $test_name() {
                const S: usize = core::mem::size_of::<$ty>();
                const N: usize = if S == 0 { 4 } else { S * 4 };
                test::<$ty, _, N>([None]);
                // If `$ty` is a ZST, then we can't pass `None` as the
                // pointer metadata, or else computing the correct trailing
                // slice length will panic.
                if S == 0 {
                    test::<[$ty], _, N>([Some(0), Some(1), Some(2), Some(3)]);
                    test::<SliceDst<$ty>, _, N>([Some(0), Some(1), Some(2), Some(3)]);
                } else {
                    test::<[$ty], _, N>([None, Some(0), Some(1), Some(2), Some(3)]);
                    test::<SliceDst<$ty>, _, N>([None, Some(0), Some(1), Some(2), Some(3)]);
                }
            }
        };
        ($ty:ident) => {
            test!($ty: $ty);
        };
        ($($ty:ident),*) => { $(test!($ty);)* }
    }

    test!(empty_tuple: ());
    test!(u8, u16, u32, u64, u128, usize, AU64);
    test!(i8, i16, i32, i64, i128, isize);
    test!(f32, f64);
}
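For a rough sense of where the time goes: each call to `test::<T, _, N>` tries every `(start, end)` subrange of an `N`-byte buffer, and for each subrange attempts a prefix cast, a suffix cast, and a no-leftover cast per metadata value. A back-of-envelope sketch (the `approx_cast_attempts` helper below is hypothetical, not part of the codebase):

// Hypothetical sketch: approximate number of cast attempts made by one
// call to `test::<T, _, N>(metas)` above.
fn approx_cast_attempts(n: usize, num_metas: usize) -> usize {
    // Number of (start, end) pairs with 0 <= start <= end <= n.
    let ranges = (n + 1) * (n + 2) / 2;
    // Each range attempts 3 casts (prefix, suffix, no-leftover) per
    // metadata value.
    ranges * num_metas * 3
}

fn main() {
    // For `u128`, N = 4 * size_of::<u128>() = 64, and the `[u128]` and
    // `SliceDst<u128>` cases each use 5 metadata values.
    assert_eq!(approx_cast_attempts(64, 5), 32_175);
    // Each successful cast additionally pays for the byte-by-byte
    // validation reads in `validate_and_get_len`, all executed under
    // Miri's interpreter rather than natively.
    println!("~{} cast attempts for [u128] alone", approx_cast_attempts(64, 5));
}

Multiplied across the 16 generated `#[test]` functions, each of which exercises `$ty`, `[$ty]`, and `SliceDst<$ty>`, this is why the module dominates merge-queue time under Miri.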
