Skip to content

Commit 43bba09

Browse files
Changed selection of algorithm hyperparameters for complex types
Made the hyperparameter scale-down logic more lenient.
1 parent 74df338 commit 43bba09

File tree

1 file changed

+6
-4
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+6
-4
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1490,8 +1490,8 @@ template <typename resT> struct GemmBatchFunctorThreadNM_vecm_HyperParameters
1490 1490
template <typename T>
1491 1491
struct GemmBatchFunctorThreadNM_vecm_HyperParameters<std::complex<T>>
1492 1492
{
1493-
static constexpr std::uint32_t wi_delta_n = 4;
1494-
static constexpr std::uint32_t wi_delta_m_vecs = 4;
1493+
static constexpr std::uint32_t wi_delta_n = 2;
1494+
static constexpr std::uint32_t wi_delta_m_vecs = 2;
1495 1495
static constexpr std::uint32_t m_vec_size = 1;
1496 1496
};
1497 1497

@@ -1527,7 +1527,7 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
1527 1527
? 64
1528 1528
: 32 * static_cast<std::uint32_t>(slm_max_rows / 32);
1529 1529

1530-
if (!wi_delta_k) {
1530+
for (std::uint32_t it = 0; !wi_delta_k && (it < 4); ++it) {
1531 1531
wg_delta_m /= 2;
1532 1532

1533 1533
const size_t slm_max_rows =
@@ -1539,7 +1539,9 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
1539 1539
? 64
1540 1540
: ((slm_max_rows >= 32)
1541 1541
? 32
1542-
: 16 * static_cast<std::uint32_t>(slm_max_rows / 16));
1542+
: (slm_max_rows >= 16 ? 16
1543+
: 8 * static_cast<std::uint32_t>(
1544+
slm_max_rows / 8)));
1543 1545
}
1544 1546

1545 1547
if (!wi_delta_k) {

0 commit comments

Comments
 (0)