@@ -40,7 +40,7 @@ class GemmWarpPolicyNode : public Object {
4040 .def_ro (" n_warp" , &GemmWarpPolicyNode::n_warp);
4141 }
4242
43- std::pair<int , int > ComputeWarpPartition (int M, int N, int block_size,
43+ std::pair<int , int > computeWarpPartition (int M, int N, int block_size,
4444 Target target,
4545 GemmInst gemm_inst) const ;
4646
@@ -84,47 +84,47 @@ class GemmWarpPolicy : public ObjectRef {
8484
8585class GemmNode : public TileOperatorNode {
8686public:
87- bool CheckWGMMA () const ;
88- tir::Buffer A, B, C ;
89- // pointer to the A, B, C
90- PrimExpr Aptr, Bptr, Cptr ;
91- bool trans_A, trans_B ;
92- int M, N, K ;
93- int stride_A, stride_B ;
94- int offset_A, offset_B ;
95- PrimExpr clear_accum = const_false();
87+ bool checkWgmma () const ;
88+ tir::Buffer a_, b_, c_ ;
89+ // BufferRegion for A, B and C
90+ BufferRegion aRegion_, bRegion_, cRegion_ ;
91+ bool transA_, transB_ ;
92+ int m_, n_, k_ ;
93+ int strideA_, strideB_ ;
94+ int offsetA_, offsetB_ ;
95+ PrimExpr clearAccum_ = const_false();
9696 // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
9797 // only will be enabled under cdna mfma instructions
98- int kPack = 1 ;
99- int wg_wait = 0 ;
100- PrimExpr mbarptr ;
101- std::optional<tir::Buffer> mbar ; // mbar is optional, only used for TCGEN5MMA
102- Array<PrimExpr> C_coords ;
103- mutable GemmWarpPolicy policy ;
98+ int kPack_ = 1 ;
99+ int wgWait_ = 0 ;
100+ PrimExpr mbarPtr_ ;
101+ std::optional<tir::Buffer> mbar_ ; // mbar is optional, only used for TCGEN5MMA
102+ Array<PrimExpr> cCoords_ ;
103+ mutable GemmWarpPolicy policy_ ;
104104 TVM_FFI_DECLARE_OBJECT_INFO_FINAL (" tl.Gemm" , GemmNode, TileOperatorNode);
105105
106106 static void RegisterReflection () {
107107 namespace refl = tvm::ffi::reflection;
108108 refl::ObjectDef<GemmNode>()
109- .def_ro (" A " , &GemmNode::A )
110- .def_ro (" B " , &GemmNode::B )
111- .def_ro (" C " , &GemmNode::C )
112- .def_ro (" Aptr " , &GemmNode::Aptr )
113- .def_ro (" Bptr " , &GemmNode::Bptr )
114- .def_ro (" Cptr " , &GemmNode::Cptr )
115- .def_ro (" trans_A " , &GemmNode::trans_A )
116- .def_ro (" trans_B " , &GemmNode::trans_B )
117- .def_ro (" M " , &GemmNode::M )
118- .def_ro (" N " , &GemmNode::N )
119- .def_ro (" K " , &GemmNode::K )
120- .def_ro (" stride_A " , &GemmNode::stride_A )
121- .def_ro (" stride_B " , &GemmNode::stride_B )
122- .def_ro (" offset_A " , &GemmNode::offset_A )
123- .def_ro (" offset_B " , &GemmNode::offset_B )
124- .def_ro (" clear_accum " , &GemmNode::clear_accum )
125- .def_ro (" kPack" , &GemmNode::kPack )
126- .def_ro (" wg_wait " , &GemmNode::wg_wait )
127- .def_ro (" policy" , &GemmNode::policy );
109+ .def_ro (" a " , &GemmNode::a_ )
110+ .def_ro (" b " , &GemmNode::b_ )
111+ .def_ro (" c " , &GemmNode::c_ )
112+ .def_ro (" aRegion " , &GemmNode::aRegion_ )
113+ .def_ro (" bRegion " , &GemmNode::bRegion_ )
114+ .def_ro (" cRegion " , &GemmNode::cRegion_ )
115+ .def_ro (" transA " , &GemmNode::transA_ )
116+ .def_ro (" transB " , &GemmNode::transB_ )
117+ .def_ro (" m " , &GemmNode::m_ )
118+ .def_ro (" n " , &GemmNode::n_ )
119+ .def_ro (" k " , &GemmNode::k_ )
120+ .def_ro (" strideA " , &GemmNode::strideA_ )
121+ .def_ro (" strideB " , &GemmNode::strideB_ )
122+ .def_ro (" offsetA " , &GemmNode::offsetA_ )
123+ .def_ro (" offsetB " , &GemmNode::offsetB_ )
124+ .def_ro (" clearAccum " , &GemmNode::clearAccum_ )
125+ .def_ro (" kPack" , &GemmNode::kPack_ )
126+ .def_ro (" wgWait " , &GemmNode::wgWait_ )
127+ .def_ro (" policy" , &GemmNode::policy_ );
128128 }
129129
130130 Stmt Lower (const LowerArgs &T, arith::Analyzer *analyzer) const override ;
@@ -134,9 +134,9 @@ class GemmNode : public TileOperatorNode {
134134 TileOperator Clone () const ;
135135
136136private:
137- GemmInst GetGemmInst (int block_size, Target target) const ;
138- bool AllowTCGEN5MMA (Target target) const ;
139- bool AllowWGMMA (int block_size, Target target) const ;
137+ GemmInst getGemmInst (int block_size, Target target) const ;
138+ bool allowTcgen5Mma (Target target) const ;
139+ bool allowWgmma (int block_size, Target target) const ;
140140
141141 mutable bool completed_ = false ;
142142};
0 commit comments