Skip to content

Commit 9f2b7bd

Browse files
[Matrix][SYCL] Add support for bf16's wi_element (#5397)
1 parent 28aa398 commit 9f2b7bd

File tree

1 file changed

+280
-0
lines changed

1 file changed

+280
-0
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,286 @@ class wi_element {
451451
}
452452
};
453453

454+
// Note that similarly to the other matrix functions, uint16_t is used here to
455+
// represent bf16 type. Since the AMX and DPAS implementations don't support
456+
// uint16_t, this interpretation is possible. This design choice was made before
457+
// the introduction of SYCL experimental bfloat16 type. Our plan is to move
458+
// towards using the SYCL bfloat16. But since it is still experimental, we will
459+
// probably keep both uint16 interpretation and SYCL bfloat16.
460+
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
461+
class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
462+
joint_matrix<uint16_t, NumRows, NumCols, Layout, Group> &M;
463+
std::size_t idx;
464+
465+
public:
466+
wi_element(joint_matrix<uint16_t, NumRows, NumCols, Layout, Group> &Mat,
467+
std::size_t i)
468+
: M(Mat), idx(i) {}
469+
// Implicit read access: returns the raw 16-bit value stored at index `idx` of
// the referenced matrix. Only available in device code; the host path throws
// runtime_error with PI_INVALID_DEVICE.
operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
  const uint16_t element = __spirv_VectorExtractDynamic(M.spvm, idx);
  return element;
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
477+
478+
// Explicit truthiness test: true when the stored 16-bit pattern is non-zero.
// Only available in device code; the host path throws runtime_error with
// PI_INVALID_DEVICE.
// NOTE(review): the test is on the raw bits, not on the fp32 value, so a
// pattern with only the sign bit set (0x8000) evaluates to true - confirm
// this is intended.
explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
  const auto element = __spirv_VectorExtractDynamic(M.spvm, idx);
  return element != static_cast<uint16_t>(0);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
487+
488+
// Write access: stores the raw 16-bit value `rhs` at index `idx` of the
// referenced matrix and returns this proxy. Only available in device code;
// the host path throws runtime_error with PI_INVALID_DEVICE.
wi_element &operator=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  // The insert intrinsic returns the updated matrix handle, which replaces
  // the one held by the joint matrix.
  auto updated = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
  M.spvm = updated;
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
498+
499+
// Element-to-element assignment: copies the value addressed by `rhs` (its
// matrix and index) into the slot addressed by this proxy. The proxy itself
// is not rebound. Only available in device code; the host path throws
// runtime_error with PI_INVALID_DEVICE.
wi_element &
operator=(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const auto source = __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx);
  M.spvm = __spirv_VectorInsertDynamic(M.spvm, source, idx);
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
511+
512+
// We use here the following functions for conversion (bf16=>fp32 and
513+
// fp32=>bf16). This is a workaround until we are able to use
514+
// __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are
515+
// supported in the CPU backend
516+
// Widens a bf16 bit pattern to fp32.
//
// The 16 input bits become the upper half of a 32-bit word whose lower half
// is zero; that word is then reinterpreted as a float.
//
// The bits are copied with std::memcpy (from <cstring>) rather than read
// through a reinterpret_cast'ed pointer: accessing an integer object through
// a `float *` violates the strict-aliasing rules and is undefined behaviour.
//
// \param x raw bf16 bit pattern.
// \return the float whose object representation is `x << 16`.
static float make_fp32(uint16_t x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "make_fp32 requires a 32-bit float");
  const uint32_t bits = static_cast<uint32_t>(x) << 16;
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}
522+
523+
// Narrows an fp32 value to a bf16 bit pattern by keeping the upper 16 bits of
// its object representation. The lower 16 bits are discarded (truncation, no
// rounding).
//
// The bits are copied with std::memcpy (from <cstring>) into an unsigned
// 32-bit word. This avoids reading and writing the float through an `int *`
// (a strict-aliasing violation, undefined behaviour) and avoids right-shifting
// a negative signed value (implementation-defined before C++20).
//
// \param x fp32 value to narrow.
// \return the upper 16 bits of `x`'s representation.
static uint16_t make_bf16(float x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "make_bf16 requires a 32-bit float");
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}
528+
529+
// Addition of a matrix element and a raw bf16 value. Both operands are
// widened to fp32, added, and the sum is narrowed back to a bf16 bit pattern.
// The matrix is not modified. Only available in device code; the host path
// throws runtime_error with PI_INVALID_DEVICE.
friend uint16_t
operator+(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
          const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float lhsF32 =
      make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx));
  const float rhsF32 = make_fp32(rhs);
  return make_bf16(lhsF32 + rhsF32);
#else
  (void)lhs;
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
543+
544+
// In-place addition: reads the element at `idx`, adds `rhs` in fp32, and
// writes the narrowed bf16 result back to the same slot. Only available in
// device code; the host path throws runtime_error with PI_INVALID_DEVICE.
wi_element &operator+=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float current = make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx));
  const float operand = make_fp32(rhs);
  M.spvm =
      __spirv_VectorInsertDynamic(M.spvm, make_bf16(current + operand), idx);
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
558+
559+
// Subtraction of a raw bf16 value from a matrix element. Both operands are
// widened to fp32, subtracted, and the difference is narrowed back to a bf16
// bit pattern. The matrix is not modified. Only available in device code; the
// host path throws runtime_error with PI_INVALID_DEVICE.
friend uint16_t
operator-(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
          const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float lhsF32 =
      make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx));
  const float rhsF32 = make_fp32(rhs);
  return make_bf16(lhsF32 - rhsF32);
#else
  (void)lhs;
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
573+
574+
// In-place subtraction: reads the element at `idx`, subtracts `rhs` in fp32,
// and writes the narrowed bf16 result back to the same slot. Only available
// in device code; the host path throws runtime_error with PI_INVALID_DEVICE.
wi_element &operator-=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float current = make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx));
  const float operand = make_fp32(rhs);
  M.spvm =
      __spirv_VectorInsertDynamic(M.spvm, make_bf16(current - operand), idx);
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
588+
589+
// Multiplication of a matrix element by a raw bf16 value. Both operands are
// widened to fp32, multiplied, and the product is narrowed back to a bf16 bit
// pattern. The matrix is not modified. Only available in device code; the
// host path throws runtime_error with PI_INVALID_DEVICE.
friend uint16_t
operator*(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
          const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float lhsF32 =
      make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx));
  const float rhsF32 = make_fp32(rhs);
  return make_bf16(lhsF32 * rhsF32);
#else
  (void)lhs;
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
603+
604+
// In-place multiplication: reads the element at `idx`, multiplies it by `rhs`
// in fp32, and writes the narrowed bf16 result back to the same slot. Only
// available in device code; the host path throws runtime_error with
// PI_INVALID_DEVICE.
wi_element &operator*=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float current = make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx));
  const float operand = make_fp32(rhs);
  M.spvm =
      __spirv_VectorInsertDynamic(M.spvm, make_bf16(current * operand), idx);
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
618+
619+
// Division of a matrix element by a raw bf16 value. Both operands are widened
// to fp32, divided, and the quotient is narrowed back to a bf16 bit pattern.
// No check is made on the divisor. The matrix is not modified. Only available
// in device code; the host path throws runtime_error with PI_INVALID_DEVICE.
friend uint16_t
operator/(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
          const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float lhsF32 =
      make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx));
  const float rhsF32 = make_fp32(rhs);
  return make_bf16(lhsF32 / rhsF32);
#else
  (void)lhs;
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
633+
634+
// In-place division: reads the element at `idx`, divides it by `rhs` in fp32,
// and writes the narrowed bf16 result back to the same slot. No check is made
// on the divisor. Only available in device code; the host path throws
// runtime_error with PI_INVALID_DEVICE.
wi_element &operator/=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  const float current = make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx));
  const float operand = make_fp32(rhs);
  M.spvm =
      __spirv_VectorInsertDynamic(M.spvm, make_bf16(current / operand), idx);
  return *this;
#else
  (void)rhs;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
648+
649+
friend bool
650+
operator<(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
651+
const uint16_t &rhs) {
652+
#ifdef __SYCL_DEVICE_ONLY__
653+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) <
654+
make_fp32(rhs);
655+
#else
656+
(void)lhs;
657+
(void)rhs;
658+
throw runtime_error("joint matrix is not supported on host device.",
659+
PI_INVALID_DEVICE);
660+
#endif // __SYCL_DEVICE_ONLY__
661+
}
662+
663+
friend bool
664+
operator<=(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
665+
const uint16_t &rhs) {
666+
#ifdef __SYCL_DEVICE_ONLY__
667+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) <=
668+
make_fp32(rhs);
669+
#else
670+
(void)lhs;
671+
(void)rhs;
672+
throw runtime_error("joint matrix is not supported on host device.",
673+
PI_INVALID_DEVICE);
674+
#endif // __SYCL_DEVICE_ONLY__
675+
}
676+
677+
friend bool
678+
operator>(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
679+
const uint16_t &rhs) {
680+
#ifdef __SYCL_DEVICE_ONLY__
681+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) >
682+
make_fp32(rhs);
683+
#else
684+
(void)lhs;
685+
(void)rhs;
686+
throw runtime_error("joint matrix is not supported on host device.",
687+
PI_INVALID_DEVICE);
688+
#endif // __SYCL_DEVICE_ONLY__
689+
}
690+
691+
friend bool
692+
operator>=(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
693+
const uint16_t &rhs) {
694+
#ifdef __SYCL_DEVICE_ONLY__
695+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) >=
696+
make_fp32(rhs);
697+
#else
698+
(void)lhs;
699+
(void)rhs;
700+
throw runtime_error("joint matrix is not supported on host device.",
701+
PI_INVALID_DEVICE);
702+
#endif // __SYCL_DEVICE_ONLY__
703+
}
704+
705+
friend bool
706+
operator==(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
707+
const uint16_t &rhs) {
708+
#ifdef __SYCL_DEVICE_ONLY__
709+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) ==
710+
make_fp32(rhs);
711+
#else
712+
(void)lhs;
713+
(void)rhs;
714+
throw runtime_error("joint matrix is not supported on host device.",
715+
PI_INVALID_DEVICE);
716+
#endif // __SYCL_DEVICE_ONLY__
717+
}
718+
719+
friend bool
720+
operator!=(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,
721+
const uint16_t &rhs) {
722+
#ifdef __SYCL_DEVICE_ONLY__
723+
return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) !=
724+
make_fp32(rhs);
725+
#else
726+
(void)lhs;
727+
(void)rhs;
728+
throw runtime_error("joint matrix is not supported on host device.",
729+
PI_INVALID_DEVICE);
730+
#endif // __SYCL_DEVICE_ONLY__
731+
}
732+
};
733+
454734
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
455735
typename Group>
456736
class wi_slice {

0 commit comments

Comments
 (0)