@@ -1480,19 +1480,74 @@ class GemmBatchFunctorThreadNM_vecm
1480
1480
}
1481
1481
};
1482
1482
1483
// Value type bundling the three tiling hyper-parameters consumed by the
// GemmBatchFunctorThreadNM_vecm kernel launcher:
//   - wi_delta_n      (default 2)
//   - wi_delta_m_vecs (default 4)
//   - m_vec_size      (default 1)
//
// All members are usable in constant expressions, so an instance can be
// created as a `constexpr` object and its getters used to initialize
// `constexpr` locals or template arguments.
//
// NOTE(review): the names suggest "rows of the result handled per work-item",
// "number of vectors along m per work-item" and "vector width along m";
// confirm against the kernel functor that consumes these values.
struct GemmBatchFunctorThreadNM_vecm_HyperParameters
{
private:
    std::uint32_t wi_delta_n = 2;
    std::uint32_t wi_delta_m_vecs = 4;
    std::uint32_t m_vec_size = 1;

public:
    // Default construction uses the in-class initializers above. The
    // constructor is defaulted (rather than merely declared) so that it has a
    // definition and default-constructed objects can be used in constant
    // expressions.
    constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters() = default;

    // Construct with explicit values for every hyper-parameter.
    constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters(
        std::uint32_t wi_delta_n_,
        std::uint32_t wi_delta_m_vecs_,
        std::uint32_t m_vec_size_)
        : wi_delta_n(wi_delta_n_), wi_delta_m_vecs(wi_delta_m_vecs_),
          m_vec_size(m_vec_size_)
    {
    }

    // Returns the stored wi_delta_n value.
    constexpr std::uint32_t get_wi_delta_n() const
    {
        return wi_delta_n;
    }

    // Returns the stored wi_delta_m_vecs value.
    constexpr std::uint32_t get_wi_delta_m_vecs() const
    {
        return wi_delta_m_vecs;
    }

    // Returns the stored m_vec_size value.
    constexpr std::uint32_t get_m_vec_size() const
    {
        return m_vec_size;
    }
};
1489
1514
1490
- template <typename T >
1491
- struct GemmBatchFunctorThreadNM_vecm_HyperParameters <std::complex<T>>
1515
+ template <typename resT >
1516
+ struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector
1492
1517
{
1493
- static constexpr std::uint32_t wi_delta_n = 2 ;
1494
- static constexpr std::uint32_t wi_delta_m_vecs = 2 ;
1495
- static constexpr std::uint32_t m_vec_size = 1 ;
1518
+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector () {}
1519
+
1520
+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters get () const
1521
+ {
1522
+ if constexpr (sizeof (resT) == 1 ) {
1523
+ // 1 * 8 * 2 * 4 == 64
1524
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (8 , 2 , 4 );
1525
+ }
1526
+ else if constexpr (sizeof (resT) == 2 ) {
1527
+ // 2 * 4 * 2 * 4 == 64
1528
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (4 , 2 , 4 );
1529
+ }
1530
+ else if constexpr (sizeof (resT) == 4 ) {
1531
+ // 4 * 4 * 1 * 4 == 64
1532
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (4 , 1 , 4 );
1533
+ }
1534
+ else if constexpr (sizeof (resT) == 8 ) {
1535
+ // 8 * 2 * 1 * 4 == 64
1536
+ if constexpr (std::is_same_v<resT, std::complex<float >>) {
1537
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 4 , 1 );
1538
+ }
1539
+ else {
1540
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 1 , 4 );
1541
+ }
1542
+ }
1543
+ else if constexpr (std::is_same_v<resT, std::complex<double >>) {
1544
+ // 16 * 2 * 2 * 1 == 64
1545
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 2 , 1 );
1546
+ }
1547
+ else {
1548
+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 2 , 1 );
1549
+ }
1550
+ }
1496
1551
};
1497
1552
1498
1553
template <typename T1,
@@ -1572,11 +1627,14 @@ sycl::event _gemm_batch_new_nm_impl(sycl::queue &exec_q,
1572
1627
const ResIndexerT &res_indexer,
1573
1628
std::vector<sycl::event> const &depends)
1574
1629
{
1575
- using parametersT = GemmBatchFunctorThreadNM_vecm_HyperParameters<resTy>;
1630
+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector<resTy>
1631
+ selector{};
1632
+ constexpr auto hyper_params = selector.get ();
1576
1633
1577
- constexpr std::uint32_t wi_delta_n = parametersT::wi_delta_n;
1578
- constexpr std::uint32_t wi_delta_m_vecs = parametersT::wi_delta_m_vecs;
1579
- constexpr std::uint32_t m_vec_size = parametersT::m_vec_size;
1634
+ constexpr std::uint32_t wi_delta_n = hyper_params.get_wi_delta_n ();
1635
+ constexpr std::uint32_t wi_delta_m_vecs =
1636
+ hyper_params.get_wi_delta_m_vecs ();
1637
+ constexpr std::uint32_t m_vec_size = hyper_params.get_m_vec_size ();
1580
1638
1581
1639
constexpr std::uint32_t wi_total_delta_m = wi_delta_m_vecs * m_vec_size;
1582
1640
@@ -3078,7 +3136,7 @@ gemm_batch_new_nm_impl(sycl::queue &exec_q,
3078
3136
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3079
3137
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3080
3138
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3081
- exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, m, k , batch_indexer,
3139
+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m , batch_indexer,
3082
3140
lhs_indexer, rhs_indexer, res_indexer, depends);
3083
3141
3084
3142
return gemm_ev;
@@ -3643,41 +3701,67 @@ sycl::event gemm_new_nm_impl(sycl::queue &exec_q,
3643
3701
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3644
3702
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3645
3703
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3646
- exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k ,
3704
+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m ,
3647
3705
batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3648
3706
3649
3707
return gemm_ev;
3650
3708
}
3651
3709
3652
3710
template <typename lhsTy, typename rhsTy, typename resTy>
3653
3711
sycl::event
3654
- gemm_new_nm_contig_impl (sycl::queue &exec_q,
3655
- const lhsTy *lhs_tp,
3656
- const rhsTy *rhs_tp,
3657
- resTy *res_tp,
3658
- size_t n,
3659
- size_t k,
3660
- size_t m,
3661
- std::vector<sycl::event> const &depends = {})
3712
+ gemm_batch_new_nm_contig_impl (sycl::queue &exec_q,
3713
+ const lhsTy *lhs_tp,
3714
+ const rhsTy *rhs_tp,
3715
+ resTy *res_tp,
3716
+ const size_t batch_nelems,
3717
+ const size_t n,
3718
+ const size_t k,
3719
+ const size_t m,
3720
+ std::vector<sycl::event> const &depends = {})
3662
3721
{
3663
3722
using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
3664
3723
constexpr OuterInnerDimsIndexerT lhs_indexer{};
3665
3724
constexpr OuterInnerDimsIndexerT rhs_indexer{};
3666
3725
constexpr OuterInnerDimsIndexerT res_indexer{};
3667
3726
3668
- using BatchDimsIndexerT =
3669
- dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3670
- constexpr BatchDimsIndexerT batch_indexer{};
3671
-
3672
3727
constexpr size_t single_batch_nelems = 1 ;
3673
3728
3674
- sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3675
- lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3676
- OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3677
- exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k,
3678
- batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3729
+ if (batch_nelems == single_batch_nelems) {
3730
+ using BatchDimsIndexerT =
3731
+ dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3732
+ constexpr BatchDimsIndexerT batch_indexer{};
3679
3733
3680
- return gemm_ev;
3734
+ sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3735
+ lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3736
+ OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3737
+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
3738
+ batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3739
+
3740
+ return gemm_ev;
3741
+ }
3742
+ else {
3743
+ using dpctl::tensor::offset_utils::Strided1DIndexer;
3744
+ using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer;
3745
+ using BatchDimsIndexerT =
3746
+ ThreeOffsets_CombinedIndexer<Strided1DIndexer, Strided1DIndexer,
3747
+ Strided1DIndexer>;
3748
+
3749
+ using dpctl::tensor::offset_utils::Strided1DIndexer;
3750
+
3751
+ const ssize_t ss_batch_nelems = static_cast <ssize_t >(batch_nelems);
3752
+ const BatchDimsIndexerT batch_indexer (
3753
+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(n * k)},
3754
+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(k * m)},
3755
+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(n * m)});
3756
+
3757
+ sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3758
+ lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3759
+ OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3760
+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
3761
+ batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3762
+
3763
+ return gemm_ev;
3764
+ }
3681
3765
}
3682
3766
3683
3767
template <typename lhsTy, typename rhsTy, typename resTy>
@@ -3705,8 +3789,8 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
3705
3789
const size_t max_nm = std::max (n, m);
3706
3790
3707
3791
if (min_nm > 0 && (max_nm >= ((64 * 1024 ) / min_nm))) {
3708
- return gemm_new_nm_contig_impl <lhsTy, rhsTy, resTy>(
3709
- exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
3792
+ return gemm_batch_new_nm_contig_impl <lhsTy, rhsTy, resTy>(
3793
+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
3710
3794
}
3711
3795
3712
3796
if (k == 0 ) {
@@ -4518,8 +4602,10 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
4518
4602
const size_t max_nm = std::max (n, m);
4519
4603
4520
4604
if (min_nm > 0 && (max_nm >= ((64 * 1024 ) / min_nm))) {
4521
- return gemm_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4522
- exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
4605
+ constexpr size_t single_batch_nelems = 1 ;
4606
+ return gemm_batch_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4607
+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
4608
+ depends);
4523
4609
}
4524
4610
4525
4611
if (k == 0 ) {
0 commit comments