
Commit 8e0efaf

Use SIMD intrinsics for vector shifts
1 parent e11af1c commit 8e0efaf
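
This commit replaces the variable-count shift intrinsics' calls into LLVM `#[link_name]` builtins with portable generic SIMD operations: `simd_shl`/`simd_shr` plus a `simd_ge`/`simd_select` guard for lanes whose per-lane shift count reaches the lane width, which a bare generic shift would not handle. The observable behaviour of the public intrinsics is meant to stay the same. As a user-level illustration of the semantics involved (a minimal sketch, not part of the commit; the helper name `sllv_demo` is made up, and it assumes std on an x86_64 target with runtime AVX2 detection):

#[cfg(target_arch = "x86_64")]
fn sllv_demo() {
    if is_x86_feature_detected!("avx2") {
        use core::arch::x86_64::*;
        unsafe {
            // _mm_set_epi32 lists lanes high-to-low, so a = [1, 2, 4, 8] and
            // count = [0, 1, 2, 32] when read from the low lane upwards.
            let a = _mm_set_epi32(8, 4, 2, 1);
            let count = _mm_set_epi32(32, 2, 1, 0);
            let r = _mm_sllv_epi32(a, count);
            let mut out = [0i32; 4];
            _mm_storeu_si128(out.as_mut_ptr().cast(), r);
            // Each lane is shifted by its own count; the lane shifted by 32 is zeroed.
            assert_eq!(out, [1, 4, 16, 0]);
        }
    }
}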


3 files changed: +135 −87 lines changed


crates/core_arch/src/x86/avx2.rs

Lines changed: 50 additions & 30 deletions
@@ -2778,7 +2778,11 @@ pub fn _mm256_bslli_epi128<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psllvd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_u32x4();
+        let overflow: u32x4 = simd_ge(count, u32x4::splat(32));
+        simd_select(overflow, u32x4::ZERO, simd_shl(a.as_u32x4(), count)).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` left by the amount
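
For reference (not part of the commit): VPSLLVD zeroes any lane whose shift count is 32 or more, while the generic `simd_shl` on its own does not define over-wide shifts, which is why the new body masks those lanes with `simd_ge` + `simd_select` first. A hypothetical scalar model of the same per-lane logic:

// Scalar model of _mm_sllv_epi32's per-lane behaviour (illustrative only).
fn sllv_epi32_model(a: [u32; 4], count: [u32; 4]) -> [u32; 4] {
    let mut r = [0u32; 4];
    for i in 0..4 {
        // Guard first, like the simd_ge/simd_select pair above: a plain `<<` by
        // 32 or more would not be a valid shift, so those lanes become zero.
        r[i] = if count[i] >= 32 { 0 } else { a[i] << count[i] };
    }
    r
}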
@@ -2791,7 +2795,11 @@ pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psllvd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_u32x8();
+        let overflow: u32x8 = simd_ge(count, u32x8::splat(32));
+        simd_select(overflow, u32x8::ZERO, simd_shl(a.as_u32x8(), count)).as_m256i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2804,7 +2812,11 @@ pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psllvq(a.as_i64x2(), count.as_i64x2())) }
+    unsafe {
+        let count = count.as_u64x2();
+        let overflow: u64x2 = simd_ge(count, u64x2::splat(64));
+        simd_select(overflow, u64x2::ZERO, simd_shl(a.as_u64x2(), count)).as_m128i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2817,7 +2829,11 @@ pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psllvq256(a.as_i64x4(), count.as_i64x4())) }
+    unsafe {
+        let count = count.as_u64x4();
+        let overflow: u64x4 = simd_ge(count, u64x4::splat(64));
+        simd_select(overflow, u64x4::ZERO, simd_shl(a.as_u64x4(), count)).as_m256i()
+    }
 }

 /// Shifts packed 16-bit integers in `a` right by `count` while
@@ -2881,7 +2897,11 @@ pub fn _mm256_srai_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psravd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_i32x4();
+        let overflow: i32x4 = simd_ge(count, i32x4::splat(32));
+        simd_select(overflow, i32x4::ZERO, simd_shr(a.as_i32x4(), count)).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by the
@@ -2893,7 +2913,11 @@ pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srav_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psravd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_i32x8();
+        let overflow: i32x8 = simd_ge(count, i32x8::splat(32));
+        simd_select(overflow, i32x8::ZERO, simd_shr(a.as_i32x8(), count)).as_m256i()
+    }
 }

 /// Shifts 128-bit lanes in `a` right by `imm8` bytes while shifting in zeros.
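
The `srav` rewrites above keep the count as a signed vector and apply `simd_shr` to a signed type, i.e. an arithmetic shift, so negative lanes are sign-filled rather than zero-filled. A usage sketch (not part of the commit; hypothetical helper name, same x86_64 + runtime AVX2 assumptions as before, with shift counts kept below 32):

#[cfg(target_arch = "x86_64")]
fn srav_demo() {
    if is_x86_feature_detected!("avx2") {
        use core::arch::x86_64::*;
        unsafe {
            // Lanes read low-to-high: a = [8, -1, 16, -16], count = [3, 4, 2, 2].
            let a = _mm_set_epi32(-16, 16, -1, 8);
            let count = _mm_set_epi32(2, 2, 4, 3);
            let r = _mm_srav_epi32(a, count);
            let mut out = [0i32; 4];
            _mm_storeu_si128(out.as_mut_ptr().cast(), r);
            // Arithmetic shift: negative lanes keep their sign.
            assert_eq!(out, [1, -1, 4, -4]);
        }
    }
}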
@@ -3076,7 +3100,11 @@ pub fn _mm256_srli_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psrlvd(a.as_i32x4(), count.as_i32x4())) }
+    unsafe {
+        let count = count.as_u32x4();
+        let overflow: u32x4 = simd_ge(count, u32x4::splat(32));
+        simd_select(overflow, u32x4::ZERO, simd_shr(a.as_u32x4(), count)).as_m128i()
+    }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by
@@ -3088,7 +3116,11 @@ pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psrlvd256(a.as_i32x8(), count.as_i32x8())) }
+    unsafe {
+        let count = count.as_u32x8();
+        let overflow: u32x8 = simd_ge(count, u32x8::splat(32));
+        simd_select(overflow, u32x8::ZERO, simd_shr(a.as_u32x8(), count)).as_m256i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3100,7 +3132,11 @@ pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(psrlvq(a.as_i64x2(), count.as_i64x2())) }
+    unsafe {
+        let count = count.as_u64x2();
+        let overflow: u64x2 = simd_ge(count, u64x2::splat(64));
+        simd_select(overflow, u64x2::ZERO, simd_shr(a.as_u64x2(), count)).as_m128i()
+    }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3112,7 +3148,11 @@ pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(psrlvq256(a.as_i64x4(), count.as_i64x4())) }
+    unsafe {
+        let count = count.as_u64x4();
+        let overflow: u64x4 = simd_ge(count, u64x4::splat(64));
+        simd_select(overflow, u64x4::ZERO, simd_shr(a.as_u64x4(), count)).as_m256i()
+    }
 }

 /// Load 256-bits of integer data from memory into dst using a non-temporal memory hint. mem_addr
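
The `srlv` rewrites above use `simd_shr` on unsigned types, i.e. a logical shift, and like the left shifts they zero any lane whose count reaches the lane width. A sketch for the 64-bit form (not part of the commit; hypothetical helper name, same x86_64 + runtime AVX2 assumptions):

#[cfg(target_arch = "x86_64")]
fn srlv_demo() {
    if is_x86_feature_detected!("avx2") {
        use core::arch::x86_64::*;
        unsafe {
            // Lanes read low-to-high: a = [all-ones, 256], count = [1, 64].
            let a = _mm_set_epi64x(256, -1);
            let count = _mm_set_epi64x(64, 1);
            let r = _mm_srlv_epi64(a, count);
            let mut out = [0u64; 2];
            _mm_storeu_si128(out.as_mut_ptr().cast(), r);
            // Logical shift: the all-ones lane becomes u64::MAX >> 1; the lane
            // shifted by 64 is zeroed.
            assert_eq!(out, [u64::MAX >> 1, 0]);
        }
    }
}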
@@ -3687,36 +3727,16 @@ unsafe extern "C" {
     fn pslld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psll.q"]
     fn psllq(a: i64x4, count: i64x2) -> i64x4;
-    #[link_name = "llvm.x86.avx2.psllv.d"]
-    fn psllvd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psllv.d.256"]
-    fn psllvd256(a: i32x8, count: i32x8) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psllv.q"]
-    fn psllvq(a: i64x2, count: i64x2) -> i64x2;
-    #[link_name = "llvm.x86.avx2.psllv.q.256"]
-    fn psllvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.psra.w"]
     fn psraw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psra.d"]
     fn psrad(a: i32x8, count: i32x4) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrav.d"]
-    fn psravd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psrav.d.256"]
-    fn psravd256(a: i32x8, count: i32x8) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.w"]
     fn psrlw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psrl.d"]
     fn psrld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.q"]
     fn psrlq(a: i64x4, count: i64x2) -> i64x4;
-    #[link_name = "llvm.x86.avx2.psrlv.d"]
-    fn psrlvd(a: i32x4, count: i32x4) -> i32x4;
-    #[link_name = "llvm.x86.avx2.psrlv.d.256"]
-    fn psrlvd256(a: i32x8, count: i32x8) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrlv.q"]
-    fn psrlvq(a: i64x2, count: i64x2) -> i64x2;
-    #[link_name = "llvm.x86.avx2.psrlv.q.256"]
-    fn psrlvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.pshuf.b"]
     fn pshufb(a: u8x32, b: u8x32) -> u8x32;
     #[link_name = "llvm.x86.avx2.permd"]

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 45 additions & 30 deletions
@@ -6852,7 +6852,11 @@ pub fn _mm_maskz_slli_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm512_sllv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsllvw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_u16x32();
+        let overflow: u16x32 = simd_ge(count, u16x32::splat(16));
+        simd_select(overflow, u16x32::ZERO, simd_shl(a.as_u16x32(), count)).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6891,7 +6895,11 @@ pub fn _mm512_maskz_sllv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm256_sllv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsllvw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_u16x16();
+        let overflow: u16x16 = simd_ge(count, u16x16::splat(16));
+        simd_select(overflow, u16x16::ZERO, simd_shl(a.as_u16x16(), count)).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6930,7 +6938,11 @@ pub fn _mm256_maskz_sllv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm_sllv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsllvw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_u16x8();
+        let overflow: u16x8 = simd_ge(count, u16x8::splat(16));
+        simd_select(overflow, u16x8::ZERO, simd_shl(a.as_u16x8(), count)).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
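
The 16-bit variants follow the same pattern with a lane width of 16, so any lane whose count is 16 or more is cleared. A usage sketch for `_mm_sllv_epi16` (not part of the commit; hypothetical helper name, assumes an x86_64 target with runtime AVX-512BW and AVX-512VL detection and a toolchain where these intrinsics are stable):

#[cfg(target_arch = "x86_64")]
fn sllv_epi16_demo() {
    if is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512vl") {
        use core::arch::x86_64::*;
        unsafe {
            // All lanes of `a` are 1; counts read low-to-high are [0, 1, 2, 3, 14, 15, 16, 17].
            let a = _mm_set1_epi16(1);
            let count = _mm_set_epi16(17, 16, 15, 14, 3, 2, 1, 0);
            let r = _mm_sllv_epi16(a, count);
            let mut out = [0u16; 8];
            _mm_storeu_si128(out.as_mut_ptr().cast(), r);
            // Counts of 16 or more zero the 16-bit lane.
            assert_eq!(out, [1, 2, 4, 8, 1 << 14, 1 << 15, 0, 0]);
        }
    }
}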
@@ -7188,7 +7200,11 @@ pub fn _mm_maskz_srli_epi16<const IMM8: i32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm512_srlv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsrlvw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_u16x32();
+        let overflow: u16x32 = simd_ge(count, u16x32::splat(16));
+        simd_select(overflow, u16x32::ZERO, simd_shr(a.as_u16x32(), count)).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7227,7 +7243,11 @@ pub fn _mm512_maskz_srlv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm256_srlv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsrlvw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_u16x16();
+        let overflow: u16x16 = simd_ge(count, u16x16::splat(16));
+        simd_select(overflow, u16x16::ZERO, simd_shr(a.as_u16x16(), count)).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7266,7 +7286,11 @@ pub fn _mm256_maskz_srlv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm_srlv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsrlvw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_u16x8();
+        let overflow: u16x8 = simd_ge(count, u16x8::splat(16));
+        simd_select(overflow, u16x8::ZERO, simd_shr(a.as_u16x8(), count)).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7511,7 +7535,11 @@ pub fn _mm_maskz_srai_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm512_srav_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(vpsravw(a.as_i16x32(), count.as_i16x32())) }
+    unsafe {
+        let count = count.as_i16x32();
+        let overflow: i16x32 = simd_ge(count, i16x32::splat(16));
+        simd_select(overflow, i16x32::ZERO, simd_shr(a.as_i16x32(), count)).as_m512i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7550,7 +7578,11 @@ pub fn _mm512_maskz_srav_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm256_srav_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(vpsravw256(a.as_i16x16(), count.as_i16x16())) }
+    unsafe {
+        let count = count.as_i16x16();
+        let overflow: i16x16 = simd_ge(count, i16x16::splat(16));
+        simd_select(overflow, i16x16::ZERO, simd_shr(a.as_i16x16(), count)).as_m256i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7589,7 +7621,11 @@ pub fn _mm256_maskz_srav_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm_srav_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(vpsravw128(a.as_i16x8(), count.as_i16x8())) }
+    unsafe {
+        let count = count.as_i16x8();
+        let overflow: i16x8 = simd_ge(count, i16x8::splat(16));
+        simd_select(overflow, i16x8::ZERO, simd_shr(a.as_i16x8(), count)).as_m128i()
+    }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -11645,33 +11681,12 @@ unsafe extern "C" {
     #[link_name = "llvm.x86.avx512.psll.w.512"]
     fn vpsllw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psllv.w.512"]
-    fn vpsllvw(a: i16x32, b: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psllv.w.256"]
-    fn vpsllvw256(a: i16x16, b: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psllv.w.128"]
-    fn vpsllvw128(a: i16x8, b: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.psrl.w.512"]
     fn vpsrlw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psrlv.w.512"]
-    fn vpsrlvw(a: i16x32, b: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psrlv.w.256"]
-    fn vpsrlvw256(a: i16x16, b: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psrlv.w.128"]
-    fn vpsrlvw128(a: i16x8, b: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.psra.w.512"]
     fn vpsraw(a: i16x32, count: i16x8) -> i16x32;

-    #[link_name = "llvm.x86.avx512.psrav.w.512"]
-    fn vpsravw(a: i16x32, count: i16x32) -> i16x32;
-    #[link_name = "llvm.x86.avx512.psrav.w.256"]
-    fn vpsravw256(a: i16x16, count: i16x16) -> i16x16;
-    #[link_name = "llvm.x86.avx512.psrav.w.128"]
-    fn vpsravw128(a: i16x8, count: i16x8) -> i16x8;
-
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.512"]
     fn vpermi2w(a: i16x32, idx: i16x32, b: i16x32) -> i16x32;
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.256"]
