Skip to content

Commit 51f7643

Browse files
committed
share permute shim
1 parent 760047c commit 51f7643

File tree

3 files changed

+43
-34
lines changed

3 files changed

+43
-34
lines changed

src/shims/x86/avx2.rs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
66

77
use super::{
88
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
9-
packuswb, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
9+
packuswb, permute, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
1010
};
1111
use crate::*;
1212

@@ -189,28 +189,12 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
189189

190190
packusdw(this, left, right, dest)?;
191191
}
192-
// Used to implement the _mm256_permutevar8x32_epi32 and
193-
// _mm256_permutevar8x32_ps function.
194-
// Shuffles `left` using the three low bits of each element of `right`
195-
// as indices.
192+
// Used to implement _mm256_permutevar8x32_epi32 and _mm256_permutevar8x32_ps.
196193
"permd" | "permps" => {
197194
let [left, right] =
198195
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
199196

200-
let (left, left_len) = this.project_to_simd(left)?;
201-
let (right, right_len) = this.project_to_simd(right)?;
202-
let (dest, dest_len) = this.project_to_simd(dest)?;
203-
204-
assert_eq!(dest_len, left_len);
205-
assert_eq!(dest_len, right_len);
206-
207-
for i in 0..dest_len {
208-
let dest = this.project_index(&dest, i)?;
209-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
210-
let left = this.project_index(&left, (right & 0b111).into())?;
211-
212-
this.copy_op(&left, &dest)?;
213-
}
197+
permute(this, left, right, dest)?;
214198
}
215199
// Used to implement the _mm256_sad_epu8 function.
216200
"psad.bw" => {

src/shims/x86/avx512.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
33
use rustc_span::Symbol;
44
use rustc_target::callconv::FnAbi;
55

6-
use super::{pmaddbw, psadbw};
6+
use super::{permute, pmaddbw, psadbw};
77
use crate::*;
88

99
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -100,20 +100,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
100100
let [left, right] =
101101
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
102102

103-
let (left, left_len) = this.project_to_simd(left)?;
104-
let (right, right_len) = this.project_to_simd(right)?;
105-
let (dest, dest_len) = this.project_to_simd(dest)?;
106-
107-
assert_eq!(dest_len, left_len);
108-
assert_eq!(dest_len, right_len);
109-
110-
for i in 0..dest_len {
111-
let dest = this.project_index(&dest, i)?;
112-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
113-
let left = this.project_index(&left, (right & 0b1111).into())?;
114-
115-
this.copy_op(&left, &dest)?;
116-
}
103+
permute(this, left, right, dest)?;
117104
}
118105
_ => return interp_ok(EmulateItemResult::NotSupported),
119106
}

src/shims/x86/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,44 @@ fn pmaddbw<'tcx>(
11321132
interp_ok(())
11331133
}
11341134

1135+
/// Shuffle 32-bit elements of `values` across lanes: lane `i` of `dest` is a copy of the lane of
1136+
/// `values` selected by lane `i` of `indices`, with the index masked by `lane count - 1`.
1137+
///
1138+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_epi32>
1139+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_ps>
1140+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_permutexvar_epi32>
1141+
fn permute<'tcx>(
1142+
ecx: &mut crate::MiriInterpCx<'tcx>,
1143+
values: &OpTy<'tcx>,
1144+
indices: &OpTy<'tcx>,
1145+
dest: &MPlaceTy<'tcx>,
1146+
) -> InterpResult<'tcx, ()> {
1147+
let (values, values_len) = ecx.project_to_simd(values)?;
1148+
let (indices, indices_len) = ecx.project_to_simd(indices)?;
1149+
let (dest, dest_len) = ecx.project_to_simd(dest)?;
1150+
1151+
// fn permd(a: u32x8, b: u32x8) -> u32x8;
1152+
// fn permps(a: __m256, b: i32x8) -> __m256;
1153+
// fn vpermd(a: i32x16, idx: i32x16) -> i32x16;
1154+
assert_eq!(dest_len, values_len);
1155+
assert_eq!(dest_len, indices_len);
1156+
1157+
// Only use the lower 3 bits to index into a vector with 8 lanes,
1158+
// or the lower 4 bits when indexing into a 16-lane vector. `dest_len - 1` is only a valid bit mask for a power-of-two lane count, hence the assert.
1159+
assert!(dest_len.is_power_of_two());
1160+
let mask = u32::try_from(dest_len).unwrap().strict_sub(1);
1161+
1162+
for i in 0..dest_len {
1163+
let dest = ecx.project_index(&dest, i)?;
1164+
let index = ecx.read_scalar(&ecx.project_index(&indices, i)?)?.to_u32()?;
1165+
let element = ecx.project_index(&values, (index & mask).into())?;
1166+
1167+
ecx.copy_op(&element, &dest)?;
1168+
}
1169+
1170+
interp_ok(())
1171+
}
1172+
11351173
/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
11361174
/// product to the 18 most significant bits by right-shifting, and then
11371175
/// divides the 18-bit value by 2 (rounding to nearest) by first adding

0 commit comments

Comments
 (0)