- 
                Notifications
    You must be signed in to change notification settings 
- Fork 291
          [TileOp] Introduce an experimental Python-defined T.gemm_v2
          #793
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            10 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      1b5dde9
              
                Refactor GEMM and GEMM-SP operations to enhance clarity and maintaina…
              
              
                LeiWang1999 52800a5
              
                Refactor GEMM and frontend legalize operations for improved clarity a…
              
              
                LeiWang1999 e299b41
              
                Enhance CUDA code generation and testing for GEMM operations
              
              
                LeiWang1999 1ab46ef
              
                Refactor GEMM layout and testing for improved clarity and functionality
              
              
                LeiWang1999 e36740d
              
                Refactor GEMM layout and Python integration for improved functionality
              
              
                LeiWang1999 a3f2564
              
                Refactor GEMM layout and testing for improved clarity and functionality
              
              
                LeiWang1999 ded566e
              
                tfloat32 support.
              
              
                LeiWang1999 3da08a1
              
                lint fix
              
              
                LeiWang1999 aa62efb
              
                lint fix
              
              
                LeiWang1999 b5f327c
              
                Refactor shared memory allocation in GEMM tests
              
              
                LeiWang1999 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
  Submodule tvm
    updated
      
        from 1fc757 to eddefb
      
    
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,279 @@ | ||
| /*! | ||
| * \file tl/op/gemm_py.cc | ||
| * \brief Implementation of General Matrix Multiplication (GEMM) operators | ||
| */ | ||
|  | ||
| #include "gemm_py.h" | ||
|  | ||
| #include "builtin.h" | ||
| #include <tvm/tir/builtin.h> | ||
| #include <tvm/tir/op.h> | ||
| #include <tvm/tir/op_attr_types.h> | ||
| #include <tvm/tir/transform.h> | ||
|  | ||
| #include "../target/utils.h" | ||
| #include "tvm/ffi/string.h" | ||
|  | ||
| namespace tvm { | ||
| namespace tl { | ||
|  | ||
| using namespace tir; | ||
|  | ||
| /** | ||
| * @brief Construct a Gemm operator from serialized TL arguments and a buffer | ||
| * map. | ||
| * | ||
| * This constructor deserializes operator parameters from `args` and resolves | ||
| * buffer references via `vmap`, populating an internal GemmPyNode with: | ||
| * - device pointers for A, B, C and their corresponding Buffer objects, | ||
| * - transpose flags for A and B, | ||
| * - matrix dimensions M, N, K, | ||
| * - warp allocation policy and clear_accum flag, | ||
| * - strides and memory offsets for A and B, | ||
| * - optional kPack (must be 1 or 2) and optional wg_wait. | ||
| * | ||
| * The populated GemmPyNode is stored into the wrapper's internal `data_`. | ||
| * | ||
| * @param args Positional serialized arguments produced by the TL frontend: | ||
| * expected layout is: | ||
| * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), | ||
| * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), | ||
| * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), | ||
| * (optional) kPack (Int), (optional) wg_wait (Int)] | ||
| * @param vmap Mapping from access pointer vars to Buffer objects used to | ||
| * resolve the Buffer corresponding to each pointer argument. | ||
| * | ||
| * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor | ||
| * fails with an ICHECK (runtime assertion). No other validation is | ||
| * performed here. | ||
| */ | ||
| GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { | ||
| ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>(); | ||
|  | ||
| node->Aptr = args[0]; | ||
| node->Bptr = args[1]; | ||
| node->Cptr = args[2]; | ||
| node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; | ||
| node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; | ||
| node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; | ||
| node->trans_A = args[3].as<Bool>().value(); | ||
| node->trans_B = args[4].as<Bool>().value(); | ||
| node->M = args[5].as<IntImm>().value()->value; | ||
| node->N = args[6].as<IntImm>().value()->value; | ||
| node->K = args[7].as<IntImm>().value()->value; | ||
| node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value); | ||
| node->clear_accum = args[9].as<Bool>().value(); | ||
| node->stride_A = args[10].as<IntImm>().value()->value; | ||
| node->stride_B = args[11].as<IntImm>().value()->value; | ||
| node->offset_A = args[12].as<IntImm>().value()->value; | ||
| node->offset_B = args[13].as<IntImm>().value()->value; | ||
| if (args.size() > 14) { | ||
| node->kPack = args[14].as<IntImm>().value()->value; | ||
| if (node->kPack != 1 && node->kPack != 2) { | ||
| ICHECK(false) << "kPack must be 1 or 2"; | ||
| } | ||
| } | ||
| if (args.size() > 15) { | ||
| node->wg_wait = args[15].as<IntImm>().value()->value; | ||
| } | ||
| data_ = std::move(node); | ||
| } | ||
|  | ||
| /** | ||
| * @brief Create a copy of this GemmPyNode as a TileOperator. | ||
| * | ||
| * Constructs a new GemmPyNode by copying the current node state and returns it | ||
| * wrapped in a Gemm TileOperator. | ||
| * | ||
| * @return TileOperator A Gemm operator that owns a copy of this node. | ||
| */ | ||
| TileOperator GemmPyNode::Clone() const { | ||
| auto op = make_object<GemmPyNode>(*this); | ||
| return GemmPy(op); | ||
| } | ||
|  | ||
| GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, | ||
| Target target) const { | ||
| int warp_size = TargetGetWarpSize(target); | ||
| int num_warps = block_size / warp_size; | ||
| bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && | ||
| (num_warps % 4 == 0) && CheckWGMMA(); | ||
| if (allow_wgmma) { | ||
| return GemmInst::kWGMMA; | ||
| } else if (TargetIsCDNA(target)) { | ||
| return GemmInst::kMFMA; | ||
| } else if (TargetIsCuda(target)) { | ||
| return GemmInst::kMMA; | ||
| } else { | ||
| ICHECK(0) << "Unsupported target for gemm: " << target->str(); | ||
| } | ||
| } | ||
|  | ||
| /** | ||
| * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. | ||
| * | ||
| * Evaluates device-memory placement, data-type combinations, transpose flags, | ||
| * and K divisibility constraints required for the Hopper WGMMA code path. | ||
| * | ||
| * The check returns true only when: | ||
| * - B resides in shared memory ("shared" or "shared.dyn"); and | ||
| * - (C, A, B) dtypes match one of the supported combinations below and K | ||
| * satisfies the required alignment; and | ||
| * - for combinations that require specific orientations, A is not transposed | ||
| * and B is transposed. | ||
| * | ||
| * Supported combinations and constraints: | ||
| * - C=float16: | ||
| * - A=float16, B=float16: K % 16 == 0 | ||
| * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % | ||
| * 32 == 0 | ||
| * - C=float32: | ||
| * - A=float16, B=float16: K % 16 == 0 | ||
| * - A=bfloat16, B=bfloat16: K % 16 == 0 | ||
| * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 | ||
| * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 | ||
| * - C=int32: | ||
| * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) | ||
| * and K % 32 == 0 | ||
| * | ||
| * @return true if WGMMA is supported for the current buffers, dtypes, and | ||
| * transpose/shape constraints; false otherwise. | ||
| */ | ||
| bool GemmPyNode::CheckWGMMA() const { | ||
| if (B.scope() != "shared.dyn" && B.scope() != "shared") { | ||
| return false; | ||
| } | ||
|  | ||
| if (C->dtype == DataType::Float(16)) { | ||
| if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) | ||
| return K % 16 == 0; | ||
| else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else | ||
| return false; | ||
| } else if (C->dtype == DataType::Float(32)) { | ||
| if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) | ||
| return K % 16 == 0; | ||
| else if (A->dtype == DataType::BFloat(16) && | ||
| B->dtype == DataType::BFloat(16)) | ||
| return K % 16 == 0; | ||
| else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32)) | ||
| return (!trans_A) && trans_B && K % 8 == 0; | ||
| else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else | ||
| return false; | ||
| } else if (C->dtype == DataType::Int(32)) { | ||
| if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8)) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8)) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8)) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8)) | ||
| return (!trans_A) && trans_B && K % 32 == 0; | ||
| else | ||
| return false; | ||
| } else { | ||
| return false; | ||
| } | ||
| } | ||
|  | ||
| /** | ||
| * @brief Parse and return the numeric GPU architecture from a Target's "arch" | ||
| * attribute. | ||
| * | ||
| * Examines the target's "arch" string and, if it matches the pattern | ||
| * "sm_<num>", returns <num> as an int. If the attribute is present but does not | ||
| * match that pattern, returns 0. | ||
| * | ||
| * Preconditions: the target must have an "arch" attribute (this is checked via | ||
| * ICHECK). | ||
| * | ||
| * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if | ||
| * the arch string does not match "sm_<num>". | ||
| */ | ||
| static int GetArchInt(Target target) { | ||
| int arch_int = 0; | ||
| auto s = target->GetAttr<String>("arch"); | ||
| ICHECK(s.defined()); | ||
| std::string arch = s.value(); | ||
| if (arch.rfind("sm_", 0) == 0) { | ||
| arch_int = std::stoi(arch.substr(3)); | ||
| } else { | ||
| arch_int = 0; | ||
| } | ||
| return arch_int; | ||
| } | ||
| 
      Comment on lines
    
      +208
     to 
      +219
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. | ||
|  | ||
/**
 * @brief Lower this GEMM operation by delegating to a Python-registered hook.
 *
 * Computes the warp partition for the chosen instruction family, then calls
 * the global FFI function "tl.gemm_py.lower" (registered from Python) to
 * obtain a PrimFunc implementing the GEMM. The PrimFunc's body is re-wrapped
 * in a BlockRealize whose block is named after the PrimFunc's
 * "global_symbol" attribute.
 *
 * @param T Lowering arguments (target, thread bounds, thread var, ...).
 * @param analyzer Arithmetic analyzer (unused here; part of the interface).
 * @return Stmt A BlockRealize wrapping the lowered GEMM body.
 *
 * @note Fails via ICHECK if the PrimFunc has no attrs or no "global_symbol"
 *       attribute, and via LOG(FATAL) if "tl.gemm_py.lower" is not
 *       registered.
 */
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  // Warp partition is computed for its side effects on `policy`; the
  // returned pair is not used directly here.
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
    auto prim_func = Downcast<PrimFunc>(
        (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
    ICHECK(prim_func->attrs.defined());
    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
    ICHECK(global_symbol.defined());
    if (prim_func->body.as<BlockRealizeNode>()) {
      // Body already is a BlockRealize: rename its block in place
      // (copy-on-write) and rebuild the realize around it.
      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
      auto block = block_realize->block;
      {
        BlockNode *n = block.CopyOnWrite();
        n->name_hint = global_symbol.value();
      }
      return BlockRealize(block_realize->iter_values, block_realize->predicate,
                          block);
    }
    // Otherwise wrap the raw body in a fresh BlockRealize node.
    return BlockRealize(
        /*iter_values=*/Array<PrimExpr>(),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
              /*name_hint=*/global_symbol.value(), prim_func->body));
  } else {
    LOG(FATAL) << "No lower function found for gemm_py";
  }
}
|  | ||
| LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, | ||
| InferLevel level) const { | ||
| if (completed_) | ||
| return {}; | ||
| LayoutMap results; | ||
|  | ||
| if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { | ||
| results = Downcast<LayoutMap>( | ||
| (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds)); | ||
| } else { | ||
| LOG(FATAL) << "No infer layout function found for gemm_py"; | ||
| } | ||
|  | ||
| completed_ = true; | ||
| return results; | ||
| } | ||
|  | ||
// Register the `gemm_py` TL operator: 5 inputs, opaque call effect (the
// compiler must not reorder or eliminate calls to it).
TIR_REGISTER_TL_OP(GemmPy, gemm_py)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Register reflection metadata for GemmPyNode at static-initialization time.
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
      
      Oops, something went wrong.
        
    
  
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bounds-check args and validate BufferMap lookups.
Constructor indexes args[0..13] unconditionally and uses vmap[...] without presence checks. This risks OOB access or null buffers.
Apply:
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>(); + // Required: Aptr, Bptr, Cptr, trans_A, trans_B, M, N, K, policy, clear_accum, + // stride_A, stride_B, offset_A, offset_B + ICHECK_GE(args.size(), 14) << "gemm_py expects at least 14 positional args"; + + // Defaults for optional fields + node->kPack = 1; + node->wg_wait = 0; node->Aptr = args[0]; node->Bptr = args[1]; node->Cptr = args[2]; - node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; - node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; - node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; - node->trans_A = args[3].as<Bool>().value(); - node->trans_B = args[4].as<Bool>().value(); + auto Avar = GetVarFromAccessPtr(node->Aptr); + auto Bvar = GetVarFromAccessPtr(node->Bptr); + auto Cvar = GetVarFromAccessPtr(node->Cptr); + ICHECK(vmap.count(Avar)) << "Aptr not found in BufferMap"; + ICHECK(vmap.count(Bvar)) << "Bptr not found in BufferMap"; + ICHECK(vmap.count(Cvar)) << "Cptr not found in BufferMap"; + node->A = vmap.at(Avar); + node->B = vmap.at(Bvar); + node->C = vmap.at(Cvar); + node->trans_A = args[3].as<IntImm>().value()->value != 0; + node->trans_B = args[4].as<IntImm>().value()->value != 0; node->M = args[5].as<IntImm>().value()->value; node->N = args[6].as<IntImm>().value()->value; node->K = args[7].as<IntImm>().value()->value; node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value); - node->clear_accum = args[9].as<Bool>().value(); + node->clear_accum = args[9].as<IntImm>().value()->value != 0; node->stride_A = args[10].as<IntImm>().value()->value; node->stride_B = args[11].as<IntImm>().value()->value; node->offset_A = args[12].as<IntImm>().value()->value; node->offset_B = args[13].as<IntImm>().value()->value; if (args.size() > 14) { node->kPack = args[14].as<IntImm>().value()->value; if (node->kPack != 1 && node->kPack != 2) { ICHECK(false) << "kPack must be 1 or 2"; } } if (args.size() > 15) { node->wg_wait = 
args[15].as<IntImm>().value()->value; } data_ = std::move(node); }📝 Committable suggestion
🤖 Prompt for AI Agents