
Commit a299d29

Use SIMD intrinsics for vector shifts
1 parent e11af1c commit a299d29
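This commit replaces the LLVM psllv/psrlv/psrav intrinsic calls with generic simd_lt/simd_select/simd_shl/simd_shr operations, so the out-of-range-count behaviour of the hardware instructions now has to be reproduced explicitly. A minimal scalar sketch of the per-lane semantics the new bodies preserve (illustrative only, not code from this commit):

// Per-lane reference model for the variable shifts, shown for 32-bit lanes.
// Logical shifts (sllv/srlv) produce 0 when the count is >= the lane width;
// the arithmetic shift (srav) behaves as if the count were clamped to width - 1,
// i.e. the lane fills with copies of the sign bit.
fn sllv_lane(a: u32, count: u32) -> u32 {
    if count < 32 { a << count } else { 0 }
}

fn srlv_lane(a: u32, count: u32) -> u32 {
    if count < 32 { a >> count } else { 0 }
}

fn srav_lane(a: i32, count: u32) -> i32 {
    a >> count.min(31)
}

fn main() {
    assert_eq!(sllv_lane(1, 31), 1 << 31);
    assert_eq!(sllv_lane(1, 32), 0); // out-of-range count zeroes the lane
    assert_eq!(srlv_lane(u32::MAX, 40), 0);
    assert_eq!(srav_lane(-8, 100), -1); // sign bit is replicated
}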

File tree

3 files changed: +162 -87 lines changed


crates/core_arch/src/x86/avx2.rs

Lines changed: 60 additions & 30 deletions
@@ -2778,7 +2778,12 @@ pub fn _mm256_bslli_epi128<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psllvd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_u32x4();
+        let good: u32x4 = simd_lt(count, u32x4::splat(32));
+        let count = simd_select(good, count, u32x4::ZERO);
+        simd_select(good, simd_shl(a.as_u32x4(), count), u32x4::ZERO).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` left by the amount
@@ -2791,7 +2796,12 @@ pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psllvd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_u32x8();
+        let good: u32x8 = simd_lt(count, u32x8::splat(32));
+        let count = simd_select(good, count, u32x8::ZERO);
+        simd_select(good, simd_shl(a.as_u32x8(), count), u32x8::ZERO).as_m256i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2804,7 +2814,12 @@ pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psllvq(a.as_i64x2(), count.as_i64x2())) }
+    unsafe {
+        let count = count.as_u64x2();
+        let good: u64x2 = simd_lt(count, u64x2::splat(64));
+        let count = simd_select(good, count, u64x2::ZERO);
+        simd_select(good, simd_shl(a.as_u64x2(), count), u64x2::ZERO).as_m128i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2817,7 +2832,12 @@ pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psllvq256(a.as_i64x4(), count.as_i64x4())) }
+    unsafe {
+        let count = count.as_u64x4();
+        let good: u64x4 = simd_lt(count, u64x4::splat(64));
+        let count = simd_select(good, count, u64x4::ZERO);
+        simd_select(good, simd_shl(a.as_u64x4(), count), u64x4::ZERO).as_m256i()
+    }
 }

 /// Shifts packed 16-bit integers in `a` right by `count` while
@@ -2881,7 +2901,12 @@ pub fn _mm256_srai_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psravd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_u32x4();
+        let good: u32x4 = simd_lt(count, u32x4::splat(32));
+        let count = simd_select(good, transmute(count), i32x4::splat(31));
+        simd_shr(a.as_i32x4(), count).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by the
@@ -2893,7 +2918,12 @@ pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srav_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psravd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_u32x8();
+        let good: u32x8 = simd_lt(count, u32x8::splat(32));
+        let count = simd_select(good, transmute(count), i32x8::splat(31));
+        simd_shr(a.as_i32x8(), count).as_m256i()
+    }
 }

 /// Shifts 128-bit lanes in `a` right by `imm8` bytes while shifting in zeros.
@@ -3076,7 +3106,12 @@ pub fn _mm256_srli_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psrlvd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_u32x4();
+        let good: u32x4 = simd_lt(count, u32x4::splat(32));
+        let count = simd_select(good, count, u32x4::ZERO);
+        simd_select(good, simd_shr(a.as_u32x4(), count), u32x4::ZERO).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by
@@ -3088,7 +3123,12 @@ pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psrlvd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_u32x8();
+        let good: u32x8 = simd_lt(count, u32x8::splat(32));
+        let count = simd_select(good, count, u32x8::ZERO);
+        simd_select(good, simd_shr(a.as_u32x8(), count), u32x8::ZERO).as_m256i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3100,7 +3140,12 @@ pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psrlvq(a.as_i64x2(), count.as_i64x2())) }
+    unsafe {
+        let count = count.as_u64x2();
+        let good: u64x2 = simd_lt(count, u64x2::splat(64));
+        let count = simd_select(good, count, u64x2::ZERO);
+        simd_select(good, simd_shr(a.as_u64x2(), count), u64x2::ZERO).as_m128i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3112,7 +3157,12 @@ pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psrlvq256(a.as_i64x4(), count.as_i64x4())) }
+    unsafe {
+        let count = count.as_u64x4();
+        let good: u64x4 = simd_lt(count, u64x4::splat(64));
+        let count = simd_select(good, count, u64x4::ZERO);
+        simd_select(good, simd_shr(a.as_u64x4(), count), u64x4::ZERO).as_m256i()
+    }
 }

 /// Load 256-bits of integer data from memory into dst using a non-temporal memory hint. mem_addr
@@ -3687,36 +3737,16 @@ unsafe extern "C" {
     fn pslld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psll.q"]
     fn psllq(a: i64x4, count: i64x2) -> i64x4;
-    #[link_name = "llvm.x86.avx2.psllv.d"]
-    fn psllvd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psllv.d.256"]
-    fn psllvd256(a: i32x8, count: i32x8) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psllv.q"]
-    fn psllvq(a: i64x2, count: i64x2) -> i64x2;
-    #[link_name = "llvm.x86.avx2.psllv.q.256"]
-    fn psllvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.psra.w"]
    fn psraw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psra.d"]
     fn psrad(a: i32x8, count: i32x4) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrav.d"]
-    fn psravd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psrav.d.256"]
-    fn psravd256(a: i32x8, count: i32x8) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.w"]
     fn psrlw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psrl.d"]
     fn psrld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.q"]
     fn psrlq(a: i64x4, count: i64x2) -> i64x4;
-    #[link_name = "llvm.x86.avx2.psrlv.d"]
-    fn psrlvd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psrlv.d.256"]
-    fn psrlvd256(a: i32x8, count: i32x8) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrlv.q"]
-    fn psrlvq(a: i64x2, count: i64x2) -> i64x2;
-    #[link_name = "llvm.x86.avx2.psrlv.q.256"]
-    fn psrlvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.pshuf.b"]
     fn pshufb(a: u8x32, b: u8x32) -> u8x32;
     #[link_name = "llvm.x86.avx2.permd"]

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 54 additions & 30 deletions
@@ -6852,7 +6852,12 @@ pub fn _mm_maskz_slli_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm512_sllv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsllvw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_u16x32();
+        let good: u16x32 = simd_lt(count, u16x32::splat(16));
+        let count = simd_select(good, count, u16x32::ZERO);
+        simd_select(good, simd_shl(a.as_u16x32(), count), u16x32::ZERO).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6891,7 +6896,12 @@ pub fn _mm512_maskz_sllv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm256_sllv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsllvw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_u16x16();
+        let good: u16x16 = simd_lt(count, u16x16::splat(16));
+        let count = simd_select(good, count, u16x16::ZERO);
+        simd_select(good, simd_shl(a.as_u16x16(), count), u16x16::ZERO).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6930,7 +6940,12 @@ pub fn _mm256_maskz_sllv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm_sllv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsllvw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_u16x8();
+        let good: u16x8 = simd_lt(count, u16x8::splat(16));
+        let count = simd_select(good, count, u16x8::ZERO);
+        simd_select(good, simd_shl(a.as_u16x8(), count), u16x8::ZERO).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7188,7 +7203,12 @@ pub fn _mm_maskz_srli_epi16<const IMM8: i32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm512_srlv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsrlvw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_u16x32();
+        let good: u16x32 = simd_lt(count, u16x32::splat(16));
+        let count = simd_select(good, count, u16x32::ZERO);
+        simd_select(good, simd_shr(a.as_u16x32(), count), u16x32::ZERO).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7227,7 +7247,12 @@ pub fn _mm512_maskz_srlv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm256_srlv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsrlvw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_u16x16();
+        let good: u16x16 = simd_lt(count, u16x16::splat(16));
+        let count = simd_select(good, count, u16x16::ZERO);
+        simd_select(good, simd_shr(a.as_u16x16(), count), u16x16::ZERO).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7266,7 +7291,12 @@ pub fn _mm256_maskz_srlv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm_srlv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsrlvw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_u16x8();
+        let good: u16x8 = simd_lt(count, u16x8::splat(16));
+        let count = simd_select(good, count, u16x8::ZERO);
+        simd_select(good, simd_shr(a.as_u16x8(), count), u16x8::ZERO).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7511,7 +7541,12 @@ pub fn _mm_maskz_srai_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm512_srav_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsravw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_u16x32();
+        let good: u16x32 = simd_lt(count, u16x32::splat(16));
+        let count = simd_select(good, transmute(count), i16x32::splat(15));
+        simd_shr(a.as_i16x32(), count).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7550,7 +7585,12 @@ pub fn _mm512_maskz_srav_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm256_srav_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsravw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_u16x16();
+        let good: u16x16 = simd_lt(count, u16x16::splat(16));
+        let count = simd_select(good, transmute(count), i16x16::splat(15));
+        simd_shr(a.as_i16x16(), count).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7589,7 +7629,12 @@ pub fn _mm256_maskz_srav_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm_srav_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsravw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_u16x8();
+        let good: u16x8 = simd_lt(count, u16x8::splat(16));
+        let count = simd_select(good, transmute(count), i16x8::splat(15));
+        simd_shr(a.as_i16x8(), count).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -11645,33 +11690,12 @@ unsafe extern "C" {
     #[link_name = "llvm.x86.avx512.psll.w.512"]
     fn vpsllw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psllv.w.512"]
-    fn vpsllvw(a: i16x32, b: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psllv.w.256"]
-    fn vpsllvw256(a: i16x16, b: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psllv.w.128"]
-    fn vpsllvw128(a: i16x8, b: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.psrl.w.512"]
     fn vpsrlw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psrlv.w.512"]
-    fn vpsrlvw(a: i16x32, b: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psrlv.w.256"]
-    fn vpsrlvw256(a: i16x16, b: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psrlv.w.128"]
-    fn vpsrlvw128(a: i16x8, b: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.psra.w.512"]
     fn vpsraw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psrav.w.512"]
-    fn vpsravw(a: i16x32, count: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psrav.w.256"]
-    fn vpsravw256(a: i16x16, count: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psrav.w.128"]
-    fn vpsravw128(a: i16x8, count: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.512"]
     fn vpermi2w(a: i16x32, idx: i16x32, b: i16x32) -> i16x32;
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.256"]

0 commit comments
