Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Checks: >
-clang-analyzer-optin.cplusplus.UninitializedObject,
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,

WarningsAsErrors: '*'

Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 1fc757 to eddefb
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
if(USE_CUDA)
tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc
src/target/ptx.cc
src/target/codegen_cuda.cc
src/target/rt_mod_cuda.cc
)
Expand Down
1 change: 0 additions & 1 deletion src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();

auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob);
if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {
Expand Down
49 changes: 23 additions & 26 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,6 @@ namespace tl {

using namespace tir;

/**
* @brief Compute the prime factorization of an integer.
*
* Returns the prime factors of x in non-decreasing order by repeatedly dividing
* out the smallest possible factor.
*
* @param x Integer to factorize. If x <= 1, an empty vector is returned.
* @return std::vector<int> Prime factors of x (with multiplicity), in
* non-decreasing order.
*/
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}

/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
Expand Down Expand Up @@ -268,14 +244,20 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
int best_m = 1;
int best_n = 1;
float best_balance = std::numeric_limits<float>::max();

// Try all possible combinations that satisfy the constraints
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
int n = num_warps / m;

// Calculate how balanced this partition is
float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
// m_per_warp and n_per_warp must be greater than 1
if (m_per_warp < 1 || n_per_warp < 1)
continue;
// m * n must equal num_warps
if (m * n != num_warps)
continue;

float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

if (balance < best_balance) {
Expand All @@ -290,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}

// Store the computed values in the object's member variables
this->m_warp = m_warp;
this->n_warp = n_warp;
Expand Down Expand Up @@ -632,5 +613,21 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tl.GemmWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

TVM_FFI_STATIC_INIT_BLOCK({
GemmNode::RegisterReflection();
GemmWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
[](GemmWarpPolicy policy, int M, int N, int block_size,
Target target, bool is_wgmma) {
policy->ComputeWarpPartition(M, N, block_size, target,
is_wgmma);
return;
});
});

} // namespace tl
} // namespace tvm
279 changes: 279 additions & 0 deletions src/op/gemm_py.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
/*!
* \file tl/op/gemm_py.cc
* \brief Implementation of General Matrix Multiplication (GEMM) operators
*/

#include "gemm_py.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tvm/ffi/string.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
*
* This constructor deserializes operator parameters from `args` and resolves
* buffer references via `vmap`, populating an internal GemmPyNode with:
* - device pointers for A, B, C and their corresponding Buffer objects,
* - transpose flags for A and B,
* - matrix dimensions M, N, K,
* - warp allocation policy and clear_accum flag,
* - strides and memory offsets for A and B,
* - optional kPack (must be 1 or 2) and optional wg_wait.
*
* The populated GemmPyNode is stored into the wrapper's internal `data_`.
*
* @param args Positional serialized arguments produced by the TL frontend:
* expected layout is:
* [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();

node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
Comment on lines +50 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bounds-check args and validate BufferMap lookups.

Constructor indexes args[0..13] unconditionally and uses vmap[...] without presence checks. This risks OOB access or null buffers.

Apply:

 GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
+  // Required: Aptr, Bptr, Cptr, trans_A, trans_B, M, N, K, policy, clear_accum,
+  //           stride_A, stride_B, offset_A, offset_B
+  ICHECK_GE(args.size(), 14) << "gemm_py expects at least 14 positional args";
+
+  // Defaults for optional fields
+  node->kPack = 1;
+  node->wg_wait = 0;

   node->Aptr = args[0];
   node->Bptr = args[1];
   node->Cptr = args[2];
-  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
-  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
-  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
-  node->trans_A = args[3].as<Bool>().value();
-  node->trans_B = args[4].as<Bool>().value();
+  auto Avar = GetVarFromAccessPtr(node->Aptr);
+  auto Bvar = GetVarFromAccessPtr(node->Bptr);
+  auto Cvar = GetVarFromAccessPtr(node->Cptr);
+  ICHECK(vmap.count(Avar)) << "Aptr not found in BufferMap";
+  ICHECK(vmap.count(Bvar)) << "Bptr not found in BufferMap";
+  ICHECK(vmap.count(Cvar)) << "Cptr not found in BufferMap";
+  node->A = vmap.at(Avar);
+  node->B = vmap.at(Bvar);
+  node->C = vmap.at(Cvar);
+  node->trans_A = args[3].as<IntImm>().value()->value != 0;
+  node->trans_B = args[4].as<IntImm>().value()->value != 0;
   node->M = args[5].as<IntImm>().value()->value;
   node->N = args[6].as<IntImm>().value()->value;
   node->K = args[7].as<IntImm>().value()->value;
   node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
-  node->clear_accum = args[9].as<Bool>().value();
+  node->clear_accum = args[9].as<IntImm>().value()->value != 0;
   node->stride_A = args[10].as<IntImm>().value()->value;
   node->stride_B = args[11].as<IntImm>().value()->value;
   node->offset_A = args[12].as<IntImm>().value()->value;
   node->offset_B = args[13].as<IntImm>().value()->value;
   if (args.size() > 14) {
     node->kPack = args[14].as<IntImm>().value()->value;
     if (node->kPack != 1 && node->kPack != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 15) {
     node->wg_wait = args[15].as<IntImm>().value()->value;
   }
   data_ = std::move(node);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
// Required: Aptr, Bptr, Cptr, trans_A, trans_B, M, N, K, policy, clear_accum,
// stride_A, stride_B, offset_A, offset_B
ICHECK_GE(args.size(), 14) << "gemm_py expects at least 14 positional args";
// Defaults for optional fields
node->kPack = 1;
node->wg_wait = 0;
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
auto Avar = GetVarFromAccessPtr(node->Aptr);
auto Bvar = GetVarFromAccessPtr(node->Bptr);
auto Cvar = GetVarFromAccessPtr(node->Cptr);
ICHECK(vmap.count(Avar)) << "Aptr not found in BufferMap";
ICHECK(vmap.count(Bvar)) << "Bptr not found in BufferMap";
ICHECK(vmap.count(Cvar)) << "Cptr not found in BufferMap";
node->A = vmap.at(Avar);
node->B = vmap.at(Bvar);
node->C = vmap.at(Cvar);
node->trans_A = args[3].as<IntImm>().value()->value != 0;
node->trans_B = args[4].as<IntImm>().value()->value != 0;
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<IntImm>().value()->value != 0;
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 50 to 80, the constructor assumes args has at
least 14 entries and that vmap contains entries for GetVarFromAccessPtr results,
which can cause OOB or missing-buffer crashes; add explicit checks before each
indexed access (ICHECK or if-checks verifying args.size() >= required_index+1)
and validate that .as<...>() returns non-null (or use ICHECK(node) semantics)
before dereferencing value(); for vmap lookups, compute the key once, check
vmap.count(key) (or vmap.find(key) != vmap.end()) and handle missing buffers
with a clear ICHECK(false) or throw with a descriptive message; keep existing
optional-field handling but gate accesses to args[14] and args[15] behind
args.size() checks; ensure all messages indicate which arg or buffer is missing
to aid debugging.


/**
* @brief Create a copy of this GemmPyNode as a TileOperator.
*
* Constructs a new GemmPyNode by copying the current node state and returns it
* wrapped in a Gemm TileOperator.
*
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmPyNode::Clone() const {
auto op = make_object<GemmPyNode>(*this);
return GemmPy(op);
}

GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}

/**
* @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
*
* Evaluates device-memory placement, data-type combinations, transpose flags,
* and K divisibility constraints required for the Hopper WGMMA code path.
*
* The check returns true only when:
* - B resides in shared memory ("shared" or "shared.dyn"); and
* - (C, A, B) dtypes match one of the supported combinations below and K
* satisfies the required alignment; and
* - for combinations that require specific orientations, A is not transposed
* and B is transposed.
*
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmPyNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}

if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else {
return false;
}
}

/**
* @brief Parse and return the numeric GPU architecture from a Target's "arch"
* attribute.
*
* Examines the target's "arch" string and, if it matches the pattern
* "sm_<num>", returns <num> as an int. If the attribute is present but does not
* match that pattern, returns 0.
*
* Preconditions: the target must have an "arch" attribute (this is checked via
* ICHECK).
*
* @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
* the arch string does not match "sm_<num>".
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}
Comment on lines +208 to +219
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function GetArchInt is duplicated from src/target/utils.cc. Since gemm_py.cc includes ../target/utils.h, this local static version is not needed. It also appears to be dead code as it's not called within this file. Please remove this duplicated function to improve maintainability.


Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = Downcast<PrimFunc>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
{
BlockNode *n = block.CopyOnWrite();
n->name_hint = global_symbol.value();
}
return BlockRealize(block_realize->iter_values, block_realize->predicate,
block);
}
// warp with block realize node
return BlockRealize(
/*iter_values=*/Array<PrimExpr>(),
/*predicate=*/const_true(),
/*block=*/
Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_py";
}
}

LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;

if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
} else {
LOG(FATAL) << "No infer layout function found for gemm_py";
}

completed_ = true;
return results;
}

TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
Loading
Loading