File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
dpctl/tensor/libtensor/include/kernels/linalg_functions Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -1490,8 +1490,8 @@ template <typename resT> struct GemmBatchFunctorThreadNM_vecm_HyperParameters
1490
1490
template <typename T>
1491
1491
struct GemmBatchFunctorThreadNM_vecm_HyperParameters <std::complex<T>>
1492
1492
{
1493
- static constexpr std::uint32_t wi_delta_n = 4 ;
1494
- static constexpr std::uint32_t wi_delta_m_vecs = 4 ;
1493
+ static constexpr std::uint32_t wi_delta_n = 2 ;
1494
+ static constexpr std::uint32_t wi_delta_m_vecs = 2 ;
1495
1495
static constexpr std::uint32_t m_vec_size = 1 ;
1496
1496
};
1497
1497
@@ -1527,7 +1527,7 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
1527
1527
? 64
1528
1528
: 32 * static_cast <std::uint32_t >(slm_max_rows / 32 );
1529
1529
1530
- if ( !wi_delta_k) {
1530
+ for (std:: uint32_t it = 0 ; !wi_delta_k && (it < 4 ); ++it ) {
1531
1531
wg_delta_m /= 2 ;
1532
1532
1533
1533
const size_t slm_max_rows =
@@ -1539,7 +1539,9 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
1539
1539
? 64
1540
1540
: ((slm_max_rows >= 32 )
1541
1541
? 32
1542
- : 16 * static_cast <std::uint32_t >(slm_max_rows / 16 ));
1542
+ : (slm_max_rows >= 16 ? 16
1543
+ : 8 * static_cast <std::uint32_t >(
1544
+ slm_max_rows / 8 )));
1543
1545
}
1544
1546
1545
1547
if (!wi_delta_k) {
You can’t perform that action at this time.
0 commit comments