Skip to content

Commit bcf81e8

Browse files
committed
Implemented the missing AVX512BF16 intrinsics
1 parent 5a4089c commit bcf81e8

File tree

3 files changed

+186
-16
lines changed

3 files changed

+186
-16
lines changed

crates/core_arch/missing-x86.md

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,12 @@
147147
</p></details>
148148

149149

150-
<details><summary>["AVX512_BF16", "AVX512F"]</summary><p>
151-
152-
* [ ] [`_mm512_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_cvtpbh_ps)
153-
* [ ] [`_mm512_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_cvtpbh_ps)
154-
* [ ] [`_mm512_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_maskz_cvtpbh_ps)
155-
* [ ] [`_mm_cvtsbh_ss`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
156-
</p></details>
157-
158-
159150
<details><summary>["AVX512_BF16", "AVX512VL"]</summary><p>
160151

161-
* [ ] [`_mm256_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtpbh_ps)
162-
* [ ] [`_mm256_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mask_cvtpbh_ps)
163-
* [ ] [`_mm256_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_maskz_cvtpbh_ps)
164152
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
165153
* [ ] [`_mm_cvtness_sbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
166-
* [ ] [`_mm_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtpbh_ps)
167154
* [ ] [`_mm_mask_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
168-
* [ ] [`_mm_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtpbh_ps)
169155
* [ ] [`_mm_maskz_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
170-
* [ ] [`_mm_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtpbh_ps)
171156
</p></details>
172157

173158

crates/core_arch/src/x86/avx512bf16.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,85 @@ pub unsafe fn _mm512_maskz_dpbf16_ps(
365365
transmute(simd_select_bitmask(k, rst, zero))
366366
}
367367

368+
/// Converts the 16 packed bf16 elements of `a` to 16 packed `f32` elements.
///
/// Each 16-bit element is widened to a 32-bit lane and shifted left by 16 bits,
/// so the bf16 bit pattern ends up in the upper half of the lane with zeros
/// below it; the lanes are then reinterpreted as `f32`. Any bits the widening
/// step placed in the upper half are discarded by the shift.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512f` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
369+
#[target_feature(enable = "avx512bf16,avx512f")]
370+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
371+
pub unsafe fn _mm512_cvtpbh_ps(a: __m256bh) -> __m512 {
372+
_mm512_castsi512_ps(_mm512_slli_epi32::<16>(_mm512_cvtepi16_epi32(transmute(a))))
373+
}
374+
375+
/// Converts the 16 packed bf16 elements of `a` to `f32` and merges the result
/// with `src` under control of the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the corresponding lane of `src` is
/// returned unchanged.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512f` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
376+
#[target_feature(enable = "avx512bf16,avx512f")]
377+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
378+
pub unsafe fn _mm512_mask_cvtpbh_ps(src: __m512, k: __mmask16, a: __m256bh) -> __m512 {
379+
let cvt = _mm512_cvtpbh_ps(a);
380+
transmute(simd_select_bitmask(k, cvt.as_f32x16(), src.as_f32x16()))
381+
}
382+
383+
/// Converts the 16 packed bf16 elements of `a` to `f32`, zeroing the lanes
/// that are not selected by the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the lane is taken from an all-zero
/// vector.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512f` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
384+
#[target_feature(enable = "avx512bf16,avx512f")]
385+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
386+
pub unsafe fn _mm512_maskz_cvtpbh_ps(k: __mmask16, a: __m256bh) -> __m512 {
387+
let cvt = _mm512_cvtpbh_ps(a);
388+
let zero = _mm512_setzero_ps();
389+
transmute(simd_select_bitmask(k, cvt.as_f32x16(), zero.as_f32x16()))
390+
}
391+
392+
/// Converts the 8 packed bf16 elements of `a` to 8 packed `f32` elements.
///
/// Each 16-bit element is widened to a 32-bit lane and shifted left by 16 bits,
/// so the bf16 bit pattern ends up in the upper half of the lane with zeros
/// below it; the lanes are then reinterpreted as `f32`. Any bits the widening
/// step placed in the upper half are discarded by the shift.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
393+
#[target_feature(enable = "avx512bf16,avx512vl")]
394+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
395+
pub unsafe fn _mm256_cvtpbh_ps(a: __m128bh) -> __m256 {
396+
_mm256_castsi256_ps(_mm256_slli_epi32::<16>(_mm256_cvtepi16_epi32(transmute(a))))
397+
}
398+
399+
/// Converts the 8 packed bf16 elements of `a` to `f32` and merges the result
/// with `src` under control of the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the corresponding lane of `src` is
/// returned unchanged.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
400+
#[target_feature(enable = "avx512bf16,avx512vl")]
401+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
402+
pub unsafe fn _mm256_mask_cvtpbh_ps(src: __m256, k: __mmask8, a: __m128bh) -> __m256 {
403+
let cvt = _mm256_cvtpbh_ps(a);
404+
transmute(simd_select_bitmask(k, cvt.as_f32x8(), src.as_f32x8()))
405+
}
406+
407+
/// Converts the 8 packed bf16 elements of `a` to `f32`, zeroing the lanes
/// that are not selected by the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the lane is taken from an all-zero
/// vector.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
408+
#[target_feature(enable = "avx512bf16,avx512vl")]
409+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
410+
pub unsafe fn _mm256_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m256 {
411+
let cvt = _mm256_cvtpbh_ps(a);
412+
let zero = _mm256_setzero_ps();
413+
transmute(simd_select_bitmask(k, cvt.as_f32x8(), zero.as_f32x8()))
414+
}
415+
416+
/// Converts four packed bf16 elements of `a` to 4 packed `f32` elements.
///
/// `a` holds eight bf16 values but the result has only four lanes; the unit
/// test in this file shows the first four elements being converted.
/// NOTE(review): presumably these are the low 64 bits of `a`, as consumed by
/// `_mm_cvtepi16_epi32` — confirm against its documentation.
///
/// Each 16-bit element is widened to a 32-bit lane and shifted left by 16 bits,
/// so the bf16 bit pattern ends up in the upper half of the lane with zeros
/// below it; the lanes are then reinterpreted as `f32`.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
417+
#[target_feature(enable = "avx512bf16,avx512vl")]
418+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
419+
pub unsafe fn _mm_cvtpbh_ps(a: __m128bh) -> __m128 {
420+
_mm_castsi128_ps(_mm_slli_epi32::<16>(_mm_cvtepi16_epi32(transmute(a))))
421+
}
422+
423+
/// Converts four packed bf16 elements of `a` to `f32` and merges the result
/// with `src` under control of the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the corresponding lane of `src` is
/// returned unchanged.
/// NOTE(review): `k` is 8 bits wide but the vectors have 4 lanes; presumably
/// only the low four bits are consulted — confirm against `simd_select_bitmask`.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
424+
#[target_feature(enable = "avx512bf16,avx512vl")]
425+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
426+
pub unsafe fn _mm_mask_cvtpbh_ps(src: __m128, k: __mmask8, a: __m128bh) -> __m128 {
427+
let cvt = _mm_cvtpbh_ps(a);
428+
transmute(simd_select_bitmask(k, cvt.as_f32x4(), src.as_f32x4()))
429+
}
430+
431+
/// Converts four packed bf16 elements of `a` to `f32`, zeroing the lanes that
/// are not selected by the mask `k`.
///
/// For every lane whose bit in `k` is set, the converted value is returned;
/// for every lane whose bit is clear, the lane is taken from an all-zero
/// vector.
/// NOTE(review): `k` is 8 bits wide but the vectors have 4 lanes; presumably
/// only the low four bits are consulted — confirm against `simd_select_bitmask`.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512vl` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
432+
#[target_feature(enable = "avx512bf16,avx512vl")]
433+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
434+
pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
435+
let cvt = _mm_cvtpbh_ps(a);
436+
let zero = _mm_setzero_ps();
437+
transmute(simd_select_bitmask(k, cvt.as_f32x4(), zero.as_f32x4()))
438+
}
439+
440+
/// Converts a single bf16 value, given as its raw 16-bit pattern `a`, to `f32`.
///
/// The 16 bits of `a` become the upper 16 bits of the `f32` bit pattern and the
/// lower 16 bits are zero, so the conversion is exact.
///
/// # Safety
///
/// Compiled with the `avx512bf16` and `avx512f` target features enabled; the
/// caller must ensure the running CPU supports them.
#[inline]
441+
#[target_feature(enable = "avx512bf16,avx512f")]
442+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
443+
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
444+
f32::from_bits((a as u32) << 16)
445+
}
446+
368447
#[cfg(test)]
369448
mod tests {
370449
use crate::{core_arch::x86::*, mem::transmute};
@@ -1592,4 +1671,110 @@ mod tests {
15921671
];
15931672
assert_eq!(result, expected_result);
15941673
}
1674+
1675+
// Raw bf16 bit patterns for the values 1.0 through 8.0, used as inputs by the
// conversion tests below. The literals are grouped as sign (1 bit), exponent
// (8 bits) and mantissa (7 bits).
const BF16_ONE: u16 = 0b0_01111111_0000000;
1676+
const BF16_TWO: u16 = 0b0_10000000_0000000;
1677+
const BF16_THREE: u16 = 0b0_10000000_1000000;
1678+
const BF16_FOUR: u16 = 0b0_10000001_0000000;
1679+
const BF16_FIVE: u16 = 0b0_10000001_0100000;
1680+
const BF16_SIX: u16 = 0b0_10000001_1000000;
1681+
const BF16_SEVEN: u16 = 0b0_10000001_1100000;
1682+
const BF16_EIGHT: u16 = 0b0_10000010_0000000;
1683+
1684+
// Converts two copies of the bf16 encodings of 1..=8 and expects the sixteen
// f32 lanes 1.0..=8.0, 1.0..=8.0 in the same order.
#[simd_test(enable = "avx512bf16")]
1685+
unsafe fn test_mm512_cvtpbh_ps() {
1686+
let a = __m256bh(
1687+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1688+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1689+
);
1690+
let r = _mm512_cvtpbh_ps(a);
1691+
let e = _mm512_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
1692+
assert_eq_m512(r, e);
1693+
}
1694+
1695+
// Merge-masked conversion: the mask has every odd-numbered bit set, so the
// odd-indexed lanes must hold the converted values (2, 4, 6, 8) and the
// even-indexed lanes must keep the values from `src` (9, 11, 13, 15).
#[simd_test(enable = "avx512bf16")]
1696+
unsafe fn test_mm512_mask_cvtpbh_ps() {
1697+
let a = __m256bh(
1698+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1699+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1700+
);
1701+
let src = _mm512_setr_ps(9., 10., 11., 12., 13., 14., 15., 16., 9., 10., 11., 12., 13., 14., 15., 16.);
1702+
let k = 0b1010_1010_1010_1010;
1703+
let r = _mm512_mask_cvtpbh_ps(src, k, a);
1704+
let e = _mm512_setr_ps(9., 2., 11., 4., 13., 6., 15., 8., 9., 2., 11., 4., 13., 6., 15., 8.);
1705+
assert_eq_m512(r, e);
1706+
}
1707+
1708+
// Zero-masked conversion: the mask has every odd-numbered bit set, so the
// odd-indexed lanes must hold the converted values (2, 4, 6, 8) and the
// even-indexed lanes must be 0.0.
#[simd_test(enable = "avx512bf16")]
1709+
unsafe fn test_mm512_maskz_cvtpbh_ps() {
1710+
let a = __m256bh(
1711+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1712+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1713+
);
1714+
let k = 0b1010_1010_1010_1010;
1715+
let r = _mm512_maskz_cvtpbh_ps(k, a);
1716+
let e = _mm512_setr_ps(0., 2., 0., 4., 0., 6., 0., 8., 0., 2., 0., 4., 0., 6., 0., 8.);
1717+
assert_eq_m512(r, e);
1718+
}
1719+
1720+
// Converts the bf16 encodings of 1..=8 and expects the eight f32 lanes
// 1.0..=8.0 in the same order.
#[simd_test(enable = "avx512bf16,avx512vl")]
1721+
unsafe fn test_mm256_cvtpbh_ps() {
1722+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
1723+
let r = _mm256_cvtpbh_ps(a);
1724+
let e = _mm256_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
1725+
assert_eq_m256(r, e);
1726+
}
1727+
1728+
// Merge-masked conversion: the mask has every odd-numbered bit set, so the
// odd-indexed lanes must hold the converted values (2, 4, 6, 8) and the
// even-indexed lanes must keep the values from `src` (9, 11, 13, 15).
#[simd_test(enable = "avx512bf16,avx512vl")]
1729+
unsafe fn test_mm256_mask_cvtpbh_ps() {
1730+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
1731+
let src = _mm256_setr_ps(9., 10., 11., 12., 13., 14., 15., 16.);
1732+
let k = 0b1010_1010;
1733+
let r = _mm256_mask_cvtpbh_ps(src, k, a);
1734+
let e = _mm256_setr_ps(9., 2., 11., 4., 13., 6., 15., 8.);
1735+
assert_eq_m256(r, e);
1736+
}
1737+
1738+
// Zero-masked conversion: the mask has every odd-numbered bit set, so the
// odd-indexed lanes must hold the converted values (2, 4, 6, 8) and the
// even-indexed lanes must be 0.0.
#[simd_test(enable = "avx512bf16,avx512vl")]
1739+
unsafe fn test_mm256_maskz_cvtpbh_ps() {
1740+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
1741+
let k = 0b1010_1010;
1742+
let r = _mm256_maskz_cvtpbh_ps(k, a);
1743+
let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.);
1744+
assert_eq_m256(r, e);
1745+
}
1746+
1747+
// Only the first four bf16 elements are expected to be converted; the
// remaining four input elements are set to 0 and do not appear in the result.
#[simd_test(enable = "avx512bf16,avx512vl")]
1748+
unsafe fn test_mm_cvtpbh_ps() {
1749+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1750+
let r = _mm_cvtpbh_ps(a);
1751+
let e = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
1752+
assert_eq_m128(r, e);
1753+
}
1754+
1755+
// Merge-masked conversion of the first four bf16 elements: mask bits 1 and 3
// are set, so lanes 1 and 3 must hold the converted values (2, 4) and lanes
// 0 and 2 must keep the values from `src` (9, 11).
#[simd_test(enable = "avx512bf16,avx512vl")]
1756+
unsafe fn test_mm_mask_cvtpbh_ps() {
1757+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1758+
let src = _mm_setr_ps(9., 10., 11., 12.);
1759+
let k = 0b1010;
1760+
let r = _mm_mask_cvtpbh_ps(src, k, a);
1761+
let e = _mm_setr_ps(9., 2., 11., 4.);
1762+
assert_eq_m128(r, e);
1763+
}
1764+
1765+
// Zero-masked conversion of the first four bf16 elements: mask bits 1 and 3
// are set, so lanes 1 and 3 must hold the converted values (2, 4) and lanes
// 0 and 2 must be 0.0.
#[simd_test(enable = "avx512bf16,avx512vl")]
1766+
unsafe fn test_mm_maskz_cvtpbh_ps() {
1767+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1768+
let k = 0b1010;
1769+
let r = _mm_maskz_cvtpbh_ps(k, a);
1770+
let e = _mm_setr_ps(0., 2., 0., 4.);
1771+
assert_eq_m128(r, e);
1772+
}
1773+
1774+
// The scalar conversion of the bf16 encoding of 1.0 must yield exactly 1.0f32.
#[simd_test(enable = "avx512bf16")]
1775+
unsafe fn test_mm_cvtsbh_ss() {
1776+
let r = _mm_cvtsbh_ss(BF16_ONE);
1777+
assert_eq!(r, 1.);
1778+
}
1779+
15951780
}

crates/stdarch-verify/tests/x86-intel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ fn equate(
699699
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
700700
(&Type::PrimSigned(64), "__int64" | "long long") => {}
701701
(&Type::PrimUnsigned(8), "unsigned char") => {}
702-
(&Type::PrimUnsigned(16), "unsigned short") => {}
702+
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
703703
(
704704
&Type::PrimUnsigned(32),
705705
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",

0 commit comments

Comments
 (0)