Skip to content

Commit a885034

Browse files
Fixed issue identified during PR review
There was a lapse in logic in the handling of batches for contiguous inputs in the new_nm implementation, for types that fall into the tree_contig impl. The hyperparameter selection was refined to address an observed slowdown for "c8" type inputs. The hyperparameters must be chosen to keep the size of the registers needed to store the private_C matrix the same. This is now accomplished using a constexpr selector helper class. A few typos discovered during debugging, which resulted in unreferenced errors, were also fixed (n, m, k arguments were passed instead of the expected n, k, m).
1 parent 43bba09 commit a885034

File tree

1 file changed

+123
-37
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+123
-37
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 123 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,19 +1480,74 @@ class GemmBatchFunctorThreadNM_vecm
14801480
}
14811481
};
14821482

1483-
template <typename resT> struct GemmBatchFunctorThreadNM_vecm_HyperParameters
1483+
struct GemmBatchFunctorThreadNM_vecm_HyperParameters
14841484
{
1485-
static constexpr std::uint32_t wi_delta_n = 4;
1486-
static constexpr std::uint32_t wi_delta_m_vecs = 1;
1487-
static constexpr std::uint32_t m_vec_size = 4;
1485+
private:
1486+
std::uint32_t wi_delta_n = 2;
1487+
std::uint32_t wi_delta_m_vecs = 4;
1488+
std::uint32_t m_vec_size = 1;
1489+
1490+
public:
1491+
constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters();
1492+
constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters(
1493+
std::uint32_t wi_delta_n_,
1494+
std::uint32_t wi_delta_m_vecs_,
1495+
std::uint32_t m_vec_size_)
1496+
: wi_delta_n(wi_delta_n_), wi_delta_m_vecs(wi_delta_m_vecs_),
1497+
m_vec_size(m_vec_size_)
1498+
{
1499+
}
1500+
1501+
constexpr std::uint32_t get_wi_delta_n() const
1502+
{
1503+
return wi_delta_n;
1504+
}
1505+
constexpr std::uint32_t get_wi_delta_m_vecs() const
1506+
{
1507+
return wi_delta_m_vecs;
1508+
}
1509+
constexpr std::uint32_t get_m_vec_size() const
1510+
{
1511+
return m_vec_size;
1512+
}
14881513
};
14891514

1490-
template <typename T>
1491-
struct GemmBatchFunctorThreadNM_vecm_HyperParameters<std::complex<T>>
1515+
template <typename resT>
1516+
struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector
14921517
{
1493-
static constexpr std::uint32_t wi_delta_n = 2;
1494-
static constexpr std::uint32_t wi_delta_m_vecs = 2;
1495-
static constexpr std::uint32_t m_vec_size = 1;
1518+
constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector() {}
1519+
1520+
constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters get() const
1521+
{
1522+
if constexpr (sizeof(resT) == 1) {
1523+
// 1 * 8 * 2 * 4 == 64
1524+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(8, 2, 4);
1525+
}
1526+
else if constexpr (sizeof(resT) == 2) {
1527+
// 2 * 4 * 2 * 4 == 64
1528+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 2, 4);
1529+
}
1530+
else if constexpr (sizeof(resT) == 4) {
1531+
// 4 * 4 * 1 * 4 == 64
1532+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 1, 4);
1533+
}
1534+
else if constexpr (sizeof(resT) == 8) {
1535+
// 8 * 2 * 1 * 4 == 64
1536+
if constexpr (std::is_same_v<resT, std::complex<float>>) {
1537+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 4, 1);
1538+
}
1539+
else {
1540+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 1, 4);
1541+
}
1542+
}
1543+
else if constexpr (std::is_same_v<resT, std::complex<double>>) {
1544+
// 16 * 2 * 2 * 1 == 64
1545+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1);
1546+
}
1547+
else {
1548+
return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1);
1549+
}
1550+
}
14961551
};
14971552

14981553
template <typename T1,
@@ -1572,11 +1627,14 @@ sycl::event _gemm_batch_new_nm_impl(sycl::queue &exec_q,
15721627
const ResIndexerT &res_indexer,
15731628
std::vector<sycl::event> const &depends)
15741629
{
1575-
using parametersT = GemmBatchFunctorThreadNM_vecm_HyperParameters<resTy>;
1630+
constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector<resTy>
1631+
selector{};
1632+
constexpr auto hyper_params = selector.get();
15761633

1577-
constexpr std::uint32_t wi_delta_n = parametersT::wi_delta_n;
1578-
constexpr std::uint32_t wi_delta_m_vecs = parametersT::wi_delta_m_vecs;
1579-
constexpr std::uint32_t m_vec_size = parametersT::m_vec_size;
1634+
constexpr std::uint32_t wi_delta_n = hyper_params.get_wi_delta_n();
1635+
constexpr std::uint32_t wi_delta_m_vecs =
1636+
hyper_params.get_wi_delta_m_vecs();
1637+
constexpr std::uint32_t m_vec_size = hyper_params.get_m_vec_size();
15801638

15811639
constexpr std::uint32_t wi_total_delta_m = wi_delta_m_vecs * m_vec_size;
15821640

@@ -3078,7 +3136,7 @@ gemm_batch_new_nm_impl(sycl::queue &exec_q,
30783136
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
30793137
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
30803138
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3081-
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, m, k, batch_indexer,
3139+
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_indexer,
30823140
lhs_indexer, rhs_indexer, res_indexer, depends);
30833141

30843142
return gemm_ev;
@@ -3643,41 +3701,67 @@ sycl::event gemm_new_nm_impl(sycl::queue &exec_q,
36433701
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
36443702
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
36453703
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3646-
exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k,
3704+
exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
36473705
batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
36483706

36493707
return gemm_ev;
36503708
}
36513709

36523710
template <typename lhsTy, typename rhsTy, typename resTy>
36533711
sycl::event
3654-
gemm_new_nm_contig_impl(sycl::queue &exec_q,
3655-
const lhsTy *lhs_tp,
3656-
const rhsTy *rhs_tp,
3657-
resTy *res_tp,
3658-
size_t n,
3659-
size_t k,
3660-
size_t m,
3661-
std::vector<sycl::event> const &depends = {})
3712+
gemm_batch_new_nm_contig_impl(sycl::queue &exec_q,
3713+
const lhsTy *lhs_tp,
3714+
const rhsTy *rhs_tp,
3715+
resTy *res_tp,
3716+
const size_t batch_nelems,
3717+
const size_t n,
3718+
const size_t k,
3719+
const size_t m,
3720+
std::vector<sycl::event> const &depends = {})
36623721
{
36633722
using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
36643723
constexpr OuterInnerDimsIndexerT lhs_indexer{};
36653724
constexpr OuterInnerDimsIndexerT rhs_indexer{};
36663725
constexpr OuterInnerDimsIndexerT res_indexer{};
36673726

3668-
using BatchDimsIndexerT =
3669-
dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3670-
constexpr BatchDimsIndexerT batch_indexer{};
3671-
36723727
constexpr size_t single_batch_nelems = 1;
36733728

3674-
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3675-
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3676-
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3677-
exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k,
3678-
batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3729+
if (batch_nelems == single_batch_nelems) {
3730+
using BatchDimsIndexerT =
3731+
dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3732+
constexpr BatchDimsIndexerT batch_indexer{};
36793733

3680-
return gemm_ev;
3734+
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3735+
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3736+
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3737+
exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
3738+
batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3739+
3740+
return gemm_ev;
3741+
}
3742+
else {
3743+
using dpctl::tensor::offset_utils::Strided1DIndexer;
3744+
using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer;
3745+
using BatchDimsIndexerT =
3746+
ThreeOffsets_CombinedIndexer<Strided1DIndexer, Strided1DIndexer,
3747+
Strided1DIndexer>;
3748+
3749+
using dpctl::tensor::offset_utils::Strided1DIndexer;
3750+
3751+
const ssize_t ss_batch_nelems = static_cast<ssize_t>(batch_nelems);
3752+
const BatchDimsIndexerT batch_indexer(
3753+
Strided1DIndexer{0, ss_batch_nelems, static_cast<ssize_t>(n * k)},
3754+
Strided1DIndexer{0, ss_batch_nelems, static_cast<ssize_t>(k * m)},
3755+
Strided1DIndexer{0, ss_batch_nelems, static_cast<ssize_t>(n * m)});
3756+
3757+
sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3758+
lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3759+
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3760+
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
3761+
batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3762+
3763+
return gemm_ev;
3764+
}
36813765
}
36823766

36833767
template <typename lhsTy, typename rhsTy, typename resTy>
@@ -3705,8 +3789,8 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
37053789
const size_t max_nm = std::max(n, m);
37063790

37073791
if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) {
3708-
return gemm_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
3709-
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
3792+
return gemm_batch_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
3793+
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
37103794
}
37113795

37123796
if (k == 0) {
@@ -4518,8 +4602,10 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
45184602
const size_t max_nm = std::max(n, m);
45194603

45204604
if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) {
4521-
return gemm_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4522-
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
4605+
constexpr size_t single_batch_nelems = 1;
4606+
return gemm_batch_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4607+
exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
4608+
depends);
45234609
}
45244610

45254611
if (k == 0) {

0 commit comments

Comments
 (0)