@@ -497,19 +497,19 @@ __forceinline__ __device__ void copy(
497497
498498// //////////////////////////////////////////////////////////////////////////////////////////////////
499499
500- template <bool Is_even_MN=true , bool Clear_OOB_MN=false ,
500+ template <bool Is_even_MN=true , bool Clear_OOB_MN=true ,
501501 typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
502502 typename Engine2, typename Layout2>
503503__forceinline__ __device__ void copy_ZOH (
504504 TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
505505 Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
506506 const int max_M=0 , const int max_N=0
507507) {
508- CUTE_STATIC_ASSERT_V (rank (S) == Int<3 >{}); // (MMA, M, N )
509- CUTE_STATIC_ASSERT_V (rank (D) == Int<3 >{}); // (MMA, M, N )
508+ CUTE_STATIC_ASSERT_V (rank (S) == Int<3 >{}); // (MMA, MMA_M, MMA_N )
509+ CUTE_STATIC_ASSERT_V (rank (D) == Int<3 >{}); // (MMA, MMA_M, MMA_N )
510510 CUTE_STATIC_ASSERT_V (size<0 >(S) == size<0 >(D)); // MMA
511- CUTE_STATIC_ASSERT_V (size<1 >(S) == size<1 >(D)); // M
512- CUTE_STATIC_ASSERT_V (size<2 >(S) == size<2 >(D)); // N
511+ CUTE_STATIC_ASSERT_V (size<1 >(S) == size<1 >(D)); // MMA_M
512+ CUTE_STATIC_ASSERT_V (size<2 >(S) == size<2 >(D)); // MMA_N
513513
514514 #pragma unroll
515515 for (int m = 0 ; m < size<1 >(S); ++m) {
0 commit comments