@@ -364,7 +364,8 @@ bool Copy::CheckBulkLoad(Target target) const {
364364  if  (!TargetHasBulkCopy (target))
365365    return  false ;
366366  //  2. src and dst must be global and shared
367-   if  (src.scope () != " global" scope () != " shared.dyn" scope () != " shared" 
367+   if  (src.scope () != " global" 
368+       (dst.scope () != " shared.dyn" scope () != " shared" 
368369    return  false ;
369370  //  3. check shape.
370371  //  TODO(lei): validate if we can utilize tma under this shape.
@@ -391,7 +392,8 @@ bool Copy::CheckBulkStore(Target target) const {
391392  if  (!TargetHasBulkCopy (target))
392393    return  false ;
393394  //  2. src and dst must be shared.dyn and local.fragment
394-   if  ((src.scope () != " shared.dyn" scope () != " shared" scope () != " global" 
395+   if  ((src.scope () != " shared.dyn" scope () != " shared" 
396+       dst.scope () != " global" 
395397    return  false ;
396398  //  3. check shape.
397399  //  TODO(lei): validate if we can utilize tma under this shape.
@@ -414,7 +416,8 @@ bool Copy::CheckBulkStore(Target target) const {
414416 * otherwise. 
415417 */  
416418bool  Copy::CheckLDSMCopy (Target target) const  {
417-   return  TargetHasLdmatrix (target) && (src.scope () == " shared.dyn" scope () == " shared" 
419+   return  TargetHasLdmatrix (target) &&
420+          (src.scope () == " shared.dyn" scope () == " shared" 
418421         dst.scope () == " local.fragment" 
419422}
420423
@@ -883,10 +886,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
883886    ICHECK (stride != nullptr  && continuous != nullptr );
884887    //  We also need to check if the shape satisfies the following doc:
885888    //  https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
886-     if  (StructuralEqual ()(
887-                    shared_layout,
888-                    makeQuarterBankSwizzleLayout (*stride, *continuous,
889-                                                 shared_tensor->dtype .bits ()))) {
889+     if  (StructuralEqual ()(shared_layout, makeQuarterBankSwizzleLayout (
890+                                              *stride, *continuous,
891+                                              shared_tensor->dtype .bits ()))) {
890892      desc.swizzle  = static_cast <int >(CU_TENSOR_MAP_SWIZZLE_32B);
891893    } else  if  (StructuralEqual ()(
892894                   shared_layout,
@@ -898,18 +900,18 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
898900                   makeFullBankSwizzleLayout (*stride, *continuous,
899901                                             shared_tensor->dtype .bits ()))) {
900902      desc.swizzle  = static_cast <int >(CU_TENSOR_MAP_SWIZZLE_128B);
901-     } else  if  (StructuralEqual ()(shared_layout, makeGemmABLayoutPadded (
902-       *stride, *continuous,
903-       shared_tensor->dtype .bits ()))) {
904-       LOG (WARNING) << " Bulk copy cannot support a padded layout for src: " 
905-                    << src->name  << " , dst: " name  
903+     } else  if  (StructuralEqual ()(
904+                    shared_layout,
905+                    makeGemmABLayoutPadded (*stride, *continuous,
906+                                           shared_tensor->dtype .bits ()))) {
907+       LOG (WARNING) << " Bulk copy cannot support a padded layout for src: " 
908+                    << src->name  << " , dst: " name 
906909                   << " , fallback to normal copy" 
907910      return  LowerNormalCopy (T, analyzer);
908911    } else  {
909-       LOG (WARNING)
910-           << " Came across unsupported swizzle layout for src: " 
911-           << src->name  << " , dst: " name  
912-           << " , fallback to normal copy" 
912+       LOG (WARNING) << " Came across unsupported swizzle layout for src: " 
913+                    << src->name  << " , dst: " name 
914+                    << " , fallback to normal copy" 
913915      return  LowerNormalCopy (T, analyzer);
914916    }
915917  }
0 commit comments