@@ -451,6 +451,286 @@ class wi_element {
451
451
}
452
452
};
453
453
454
+ // Note that similarly to the other matrix functions, uint16_t is used here to
455
+ // represent bf16 type. Since the AMX and DPAS implementations don't support
456
+ // uint16_t, this interpretation is possible. This design choice was made before
457
+ // the introduction of SYCL experimental bfloat16 type. Our plan is to move
458
+ // towards using the SYCL bfloat16. But since it is still experimental, we will
459
+ // probably keep both uint16 interpretation and SYCL bfloat16.
460
+ template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
461
+ class wi_element <uint16_t , NumRows, NumCols, Layout, Group> {
462
+ joint_matrix<uint16_t , NumRows, NumCols, Layout, Group> &M;
463
+ std::size_t idx;
464
+
465
+ public:
466
+ wi_element (joint_matrix<uint16_t , NumRows, NumCols, Layout, Group> &Mat,
467
+ std::size_t i)
468
+ : M(Mat), idx(i) {}
469
+ operator uint16_t () {
470
+ #ifdef __SYCL_DEVICE_ONLY__
471
+ return __spirv_VectorExtractDynamic (M.spvm , idx);
472
+ #else
473
+ throw runtime_error (" joint matrix is not supported on host device." ,
474
+ PI_INVALID_DEVICE);
475
+ #endif // __SYCL_DEVICE_ONLY__
476
+ }
477
+
478
+ explicit operator bool () {
479
+ #ifdef __SYCL_DEVICE_ONLY__
480
+ return __spirv_VectorExtractDynamic (M.spvm , idx) !=
481
+ static_cast <uint16_t >(0 );
482
+ #else
483
+ throw runtime_error (" joint matrix is not supported on host device." ,
484
+ PI_INVALID_DEVICE);
485
+ #endif // __SYCL_DEVICE_ONLY__
486
+ }
487
+
488
+ wi_element &operator =(const uint16_t &rhs) {
489
+ #ifdef __SYCL_DEVICE_ONLY__
490
+ M.spvm = __spirv_VectorInsertDynamic (M.spvm , rhs, idx);
491
+ return *this ;
492
+ #else
493
+ (void )rhs;
494
+ throw runtime_error (" joint matrix is not supported on host device." ,
495
+ PI_INVALID_DEVICE);
496
+ #endif // __SYCL_DEVICE_ONLY__
497
+ }
498
+
499
+ wi_element &
500
+ operator =(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &rhs) {
501
+ #ifdef __SYCL_DEVICE_ONLY__
502
+ M.spvm = __spirv_VectorInsertDynamic (
503
+ M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
504
+ return *this ;
505
+ #else
506
+ (void )rhs;
507
+ throw runtime_error (" joint matrix is not supported on host device." ,
508
+ PI_INVALID_DEVICE);
509
+ #endif // __SYCL_DEVICE_ONLY__
510
+ }
511
+
512
+ // We use here the following functions for conversion (bf16=>fp32 and
513
+ // fp32=>bf16). This is a workaround until we are able to use
514
+ // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are
515
+ // supported in the CPU backend
516
+ static float make_fp32 (uint16_t x) {
517
+ unsigned int y = x;
518
+ y = y << 16 ;
519
+ float *res = reinterpret_cast <float *>(&y);
520
+ return *res;
521
+ }
522
+
523
+ static uint16_t make_bf16 (float x) {
524
+ int *res = reinterpret_cast <int *>(&x);
525
+ *res = *res >> 16 ;
526
+ return (uint16_t )*res;
527
+ }
528
+
529
+ friend uint16_t
530
+ operator +(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
531
+ const uint16_t &rhs) {
532
+ #ifdef __SYCL_DEVICE_ONLY__
533
+ return make_bf16 (
534
+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) +
535
+ make_fp32 (rhs));
536
+ #else
537
+ (void )lhs;
538
+ (void )rhs;
539
+ throw runtime_error (" joint matrix is not supported on host device." ,
540
+ PI_INVALID_DEVICE);
541
+ #endif // __SYCL_DEVICE_ONLY__
542
+ }
543
+
544
+ wi_element &operator +=(const uint16_t &rhs) {
545
+ #ifdef __SYCL_DEVICE_ONLY__
546
+ M.spvm = __spirv_VectorInsertDynamic (
547
+ M.spvm ,
548
+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) +
549
+ make_fp32 (rhs)),
550
+ idx);
551
+ return *this ;
552
+ #else
553
+ (void )rhs;
554
+ throw runtime_error (" joint matrix is not supported on host device." ,
555
+ PI_INVALID_DEVICE);
556
+ #endif // __SYCL_DEVICE_ONLY__
557
+ }
558
+
559
+ friend uint16_t
560
+ operator -(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
561
+ const uint16_t &rhs) {
562
+ #ifdef __SYCL_DEVICE_ONLY__
563
+ return make_bf16 (
564
+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) -
565
+ make_fp32 (rhs));
566
+ #else
567
+ (void )lhs;
568
+ (void )rhs;
569
+ throw runtime_error (" joint matrix is not supported on host device." ,
570
+ PI_INVALID_DEVICE);
571
+ #endif // __SYCL_DEVICE_ONLY__
572
+ }
573
+
574
+ wi_element &operator -=(const uint16_t &rhs) {
575
+ #ifdef __SYCL_DEVICE_ONLY__
576
+ M.spvm = __spirv_VectorInsertDynamic (
577
+ M.spvm ,
578
+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) -
579
+ make_fp32 (rhs)),
580
+ idx);
581
+ return *this ;
582
+ #else
583
+ (void )rhs;
584
+ throw runtime_error (" joint matrix is not supported on host device." ,
585
+ PI_INVALID_DEVICE);
586
+ #endif // __SYCL_DEVICE_ONLY__
587
+ }
588
+
589
+ friend uint16_t
590
+ operator *(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
591
+ const uint16_t &rhs) {
592
+ #ifdef __SYCL_DEVICE_ONLY__
593
+ return make_bf16 (
594
+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) *
595
+ make_fp32 (rhs));
596
+ #else
597
+ (void )lhs;
598
+ (void )rhs;
599
+ throw runtime_error (" joint matrix is not supported on host device." ,
600
+ PI_INVALID_DEVICE);
601
+ #endif // __SYCL_DEVICE_ONLY__
602
+ }
603
+
604
+ wi_element &operator *=(const uint16_t &rhs) {
605
+ #ifdef __SYCL_DEVICE_ONLY__
606
+ M.spvm = __spirv_VectorInsertDynamic (
607
+ M.spvm ,
608
+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) *
609
+ make_fp32 (rhs)),
610
+ idx);
611
+ return *this ;
612
+ #else
613
+ (void )rhs;
614
+ throw runtime_error (" joint matrix is not supported on host device." ,
615
+ PI_INVALID_DEVICE);
616
+ #endif // __SYCL_DEVICE_ONLY__
617
+ }
618
+
619
+ friend uint16_t
620
+ operator /(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
621
+ const uint16_t &rhs) {
622
+ #ifdef __SYCL_DEVICE_ONLY__
623
+ return make_bf16 (
624
+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) /
625
+ make_fp32 (rhs));
626
+ #else
627
+ (void )lhs;
628
+ (void )rhs;
629
+ throw runtime_error (" joint matrix is not supported on host device." ,
630
+ PI_INVALID_DEVICE);
631
+ #endif // __SYCL_DEVICE_ONLY__
632
+ }
633
+
634
+ wi_element &operator /=(const uint16_t &rhs) {
635
+ #ifdef __SYCL_DEVICE_ONLY__
636
+ M.spvm = __spirv_VectorInsertDynamic (
637
+ M.spvm ,
638
+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) /
639
+ make_fp32 (rhs)),
640
+ idx);
641
+ return *this ;
642
+ #else
643
+ (void )rhs;
644
+ throw runtime_error (" joint matrix is not supported on host device." ,
645
+ PI_INVALID_DEVICE);
646
+ #endif // __SYCL_DEVICE_ONLY__
647
+ }
648
+
649
+ friend bool
650
+ operator <(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
651
+ const uint16_t &rhs) {
652
+ #ifdef __SYCL_DEVICE_ONLY__
653
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) <
654
+ make_fp32 (rhs);
655
+ #else
656
+ (void )lhs;
657
+ (void )rhs;
658
+ throw runtime_error (" joint matrix is not supported on host device." ,
659
+ PI_INVALID_DEVICE);
660
+ #endif // __SYCL_DEVICE_ONLY__
661
+ }
662
+
663
+ friend bool
664
+ operator <=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
665
+ const uint16_t &rhs) {
666
+ #ifdef __SYCL_DEVICE_ONLY__
667
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) <=
668
+ make_fp32 (rhs);
669
+ #else
670
+ (void )lhs;
671
+ (void )rhs;
672
+ throw runtime_error (" joint matrix is not supported on host device." ,
673
+ PI_INVALID_DEVICE);
674
+ #endif // __SYCL_DEVICE_ONLY__
675
+ }
676
+
677
+ friend bool
678
+ operator >(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
679
+ const uint16_t &rhs) {
680
+ #ifdef __SYCL_DEVICE_ONLY__
681
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) >
682
+ make_fp32 (rhs);
683
+ #else
684
+ (void )lhs;
685
+ (void )rhs;
686
+ throw runtime_error (" joint matrix is not supported on host device." ,
687
+ PI_INVALID_DEVICE);
688
+ #endif // __SYCL_DEVICE_ONLY__
689
+ }
690
+
691
+ friend bool
692
+ operator >=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
693
+ const uint16_t &rhs) {
694
+ #ifdef __SYCL_DEVICE_ONLY__
695
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) >=
696
+ make_fp32 (rhs);
697
+ #else
698
+ (void )lhs;
699
+ (void )rhs;
700
+ throw runtime_error (" joint matrix is not supported on host device." ,
701
+ PI_INVALID_DEVICE);
702
+ #endif // __SYCL_DEVICE_ONLY__
703
+ }
704
+
705
+ friend bool
706
+ operator ==(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
707
+ const uint16_t &rhs) {
708
+ #ifdef __SYCL_DEVICE_ONLY__
709
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) ==
710
+ make_fp32 (rhs);
711
+ #else
712
+ (void )lhs;
713
+ (void )rhs;
714
+ throw runtime_error (" joint matrix is not supported on host device." ,
715
+ PI_INVALID_DEVICE);
716
+ #endif // __SYCL_DEVICE_ONLY__
717
+ }
718
+
719
+ friend bool
720
+ operator !=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
721
+ const uint16_t &rhs) {
722
+ #ifdef __SYCL_DEVICE_ONLY__
723
+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) !=
724
+ make_fp32 (rhs);
725
+ #else
726
+ (void )lhs;
727
+ (void )rhs;
728
+ throw runtime_error (" joint matrix is not supported on host device." ,
729
+ PI_INVALID_DEVICE);
730
+ #endif // __SYCL_DEVICE_ONLY__
731
+ }
732
+ };
733
+
454
734
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
455
735
typename Group>
456
736
class wi_slice {
0 commit comments