@@ -363,8 +363,8 @@ bool Copy::CheckBulkLoad(Target target) const {
   // 1. arch must have bulk copy support
   if (!TargetHasBulkCopy(target))
     return false;
-  // 2. src and dst must be shared.dyn and local.fragment
-  if (src.scope() != "global" || dst.scope() != "shared.dyn")
+  // 2. src and dst must be global and shared
+  if (src.scope() != "global" || (dst.scope() != "shared.dyn" && dst.scope() != "shared"))
     return false;
   // 3. check shape.
   // TODO(lei): validate if we can utilize tma under this shape.
@@ -391,7 +391,7 @@ bool Copy::CheckBulkStore(Target target) const {
   if (!TargetHasBulkCopy(target))
     return false;
   // 2. src and dst must be shared.dyn and local.fragment
-  if (src.scope() != "shared.dyn" || dst.scope() != "global")
+  if ((src.scope() != "shared.dyn" && src.scope() != "shared") || dst.scope() != "global")
     return false;
   // 3. check shape.
   // TODO(lei): validate if we can utilize tma under this shape.
@@ -414,7 +414,7 @@ bool Copy::CheckBulkStore(Target target) const {
  * otherwise.
  */
 bool Copy::CheckLDSMCopy(Target target) const {
-  return TargetHasLdmatrix(target) && src.scope() == "shared.dyn" &&
+  return TargetHasLdmatrix(target) && (src.scope() == "shared.dyn" || src.scope() == "shared") &&
          dst.scope() == "local.fragment";
 }

@@ -428,7 +428,7 @@ bool Copy::CheckLDSMCopy(Target target) const {
  */
 bool Copy::CheckSTSMCopy(Target target) const {
   return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
-         dst.scope() == "shared.dyn";
+         (dst.scope() == "shared.dyn" || dst.scope() == "shared");
 }

 /*!
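
All four predicates above now accept the statically allocated "shared" scope wherever they previously required the dynamically allocated "shared.dyn" scope. The disjunction is repeated verbatim in each check; as a minimal sketch (my own refactor, not part of this patch; IsSharedScope is a hypothetical name), it could be factored into a helper:

#include <string>

// Hypothetical helper: both dynamic ("shared.dyn") and static ("shared")
// shared-memory scopes satisfy the shared-side requirement of these copy
// paths. tvm::String arguments convert implicitly to std::string.
static bool IsSharedScope(const std::string &scope) {
  return scope == "shared.dyn" || scope == "shared";
}

With it, CheckSTSMCopy, for example, would reduce to:

  return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
         IsSharedScope(dst.scope());
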
@@ -883,11 +883,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     ICHECK(stride != nullptr && continuous != nullptr);
     // We also need to check if the shape satisfies the following doc:
     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-    if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
-                                             *stride, *continuous,
-                                             shared_tensor->dtype.bits()))) {
-      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
-    } else if (StructuralEqual()(
+    if (StructuralEqual()(
                    shared_layout,
                    makeQuarterBankSwizzleLayout(*stride, *continuous,
                                                 shared_tensor->dtype.bits()))) {
@@ -902,9 +898,18 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                    makeFullBankSwizzleLayout(*stride, *continuous,
                                              shared_tensor->dtype.bits()))) {
       desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
+    } else if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
+                                                    *stride, *continuous,
+                                                    shared_tensor->dtype.bits()))) {
+      LOG(WARNING) << "Bulk copy cannot support a padded layout for src: "
+                   << src->name << ", dst: " << dst->name
+                   << ", fallback to normal copy";
+      return LowerNormalCopy(T, analyzer);
     } else {
       LOG(WARNING)
-          << "Came across unsupported swizzle layout, fallback to normal copy";
+          << "Came across unsupported swizzle layout for src: "
+          << src->name << ", dst: " << dst->name
+          << ", fallback to normal copy";
       return LowerNormalCopy(T, analyzer);
     }
   }
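
With this reordering, the dispatch tries the three bank-swizzle layouts first and treats the padded GemmAB layout as unsupported for bulk (TMA) copy: it now warns and falls back to LowerNormalCopy instead of emitting CU_TENSOR_MAP_SWIZZLE_NONE for a layout the swizzle modes cannot express. A condensed sketch of the resulting mapping, assuming the quarter- and half-bank branches (the latter sits in the context elided between the two hunks) select the 32- and 64-byte modes their names suggest; SharedLayoutKind and SwizzleFor are illustrative names only:

#include <cuda.h>  // CUtensorMapSwizzle constants

enum class SharedLayoutKind { kQuarterBank, kHalfBank, kFullBank, kPadded };

// Returns the TMA swizzle mode for a supported shared-memory layout; sets
// *use_fallback for padded/unrecognized layouts, which the caller must lower
// through the normal (non-TMA) copy path instead.
static int SwizzleFor(SharedLayoutKind kind, bool *use_fallback) {
  *use_fallback = false;
  switch (kind) {
    case SharedLayoutKind::kQuarterBank:
      return static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);   // 32-byte swizzle
    case SharedLayoutKind::kHalfBank:
      return static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);   // 64-byte swizzle
    case SharedLayoutKind::kFullBank:
      return static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);  // 128-byte swizzle
    default:
      *use_fallback = true;  // padded/unknown: no TMA swizzle equivalent
      return static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
  }
}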