Commit 3cfefc8

[Refactor] Support python reflection for tile operators (#783)
* Implement Fill operator and related reflection methods in TileLang
  - Added Fill operator implementation in `fill.cc` and `fill.h` for element-wise filling of buffers.
  - Introduced reflection methods for Fill, AtomicAdd, Copy, Conv2DIm2Col, FinalizeReducer, Gemm, and Parallel operators to enhance introspection capabilities.
  - Updated relevant files to register reflection methods and ensure proper initialization in static blocks.
  - Removed outdated comments and unnecessary code in various operator files to improve clarity and maintainability.
  - Added new Python bindings for the Fill operator in `tilelang/ir/fill.py` and updated the module imports accordingly.

* Refactor operator reflection methods and improve code clarity
  - Updated reflection methods for AtomicAdd, Copy, FinalizeReducer, Gemm, and Parallel operators to enhance readability by using `empty()` instead of size checks.
  - Consolidated static initialization blocks for various operators to a single line for improved consistency.
  - Cleaned up whitespace and formatting in multiple files to adhere to coding standards and improve maintainability.
  - Added new Python bindings for operators in the `tilelang/ir` module, ensuring proper registration and organization of imports.

* Refactor GEMM and AtomicAdd operations for improved clarity
  - Updated the `GetArchInt` function in `atomic_add.cc` to use `std::string` and `std::stoi` for better readability and type safety.
  - Removed unnecessary variables and comments in `gemm_sp.cc` and `gemm.cc` to streamline the `ComputeWarpPartition` method.
  - Cleaned up the `layout_reducer.cc` file by removing unused variable declarations, enhancing code clarity.
  - Added import for the `ir` module in `tilelang/__init__.py` to ensure proper organization of module imports.

* Remove deprecated operator files from the tilelang IR module
  - Deleted files for Fill, AtomicAdd, Copy, Gemm, GemmSP, FinalizeReducer, Parallel, Reduce, and Region operators to streamline the codebase.
  - This cleanup enhances maintainability by removing unused code and improving overall organization of the module.

* Refactor imports in tilelang IR module for improved organization
  - Updated import statements in `tilelang/ir.py` to reflect changes in the TVM library structure, enhancing clarity and maintainability of the codebase.

* lint fix

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability
  - Updated the `Gemm` and `GemmSP` classes to utilize a new `GemmWarpPolicy` object for warp partitioning, improving encapsulation and readability.
  - Removed deprecated `ComputeWarpPartition` methods and replaced them with calls to the new policy object, streamlining the code.
  - Cleaned up comments and unnecessary code in `gemm.cc`, `gemm_sp.cc`, and related header files to enhance overall clarity.
  - Introduced a new `GemmWarpPolicyNode` class to manage warp policy attributes and methods, facilitating better organization of related functionalities.
  - Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.

* Refactor Reduce operation to utilize ReduceType class for improved clarity and maintainability
  - Replaced multiple conditional checks for reduce types with a single ReduceType object, simplifying the code structure.
  - Introduced a new ReduceTypeNode class to encapsulate reduce type logic and methods, enhancing organization.
  - Updated MakeInitValue, MakeReduce, and Lower methods to leverage the new ReduceType class, improving readability.
  - Added Python bindings for the ReduceType class in tilelang IR module to ensure proper registration and usability.

* comment

* Refactor operator header files for improved readability
  - Cleaned up formatting and whitespace in `atomic_add.h`, `copy.h`, `fill.h`, `reduce.cc`, and `reduce.h` to enhance code clarity.
  - Consolidated comments and adjusted line breaks for better organization and maintainability across multiple operator definitions.

* Refactor MakeReduce method in ReduceOpNode for clarity
  - Updated the parameter name in the MakeReduce method from `rhs` to `b` and assigned it to `rhs` for improved readability.
  - This change enhances the clarity of the method's purpose and aligns with the overall refactoring efforts in the Reduce operation.

* Update Reduce operation type checks for consistency
  - Changed string comparisons for reduce types in the MakeReduce method from "abs_sum" to "abssum" and "abs_max" to "absmax" for uniformity.
  - This adjustment enhances the clarity and consistency of the reduce type handling in the codebase.

1 parent 141e01f commit 3cfefc8

24 files changed: +757 -1254 lines changed

src/op/atomic_add.cc

Lines changed: 6 additions & 4 deletions
@@ -37,9 +37,9 @@ static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.defined());
-  const char *arch_str = s.value().c_str();
-  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
-    arch_int = atoi(&arch_str[3]);
+  std::string arch = s.value();
+  if (arch.rfind("sm_", 0) == 0) {
+    arch_int = std::stoi(arch.substr(3));
   } else {
     arch_int = 0;
   }
@@ -255,7 +255,7 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
  */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
-  bool is_scalar = loop_vars.size() == 0;
+  bool is_scalar = loop_vars.empty();
   if (is_scalar) {
     return For(Var("i"), 0, 1, ForKind::kSerial,
                BufferStore(dst, BufferLoad(src, {0}), {0}));
@@ -425,5 +425,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
+
 } // namespace tl
 } // namespace tvm
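
The `GetArchInt` change above swaps manual character checks and `atoi` for `std::string::rfind` and `std::stoi`. The following is a minimal stand-alone sketch of that parsing logic, not the code from this commit: `ParseArchInt` is a hypothetical helper that takes a plain `std::string` rather than a TVM `Target`, and the `try`/`catch` fallback is an added assumption. It is included because `std::stoi` throws when no digits follow the prefix, whereas `atoi` silently returns 0.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper (not part of the commit): parses "sm_<digits>" into
// an integer, returning 0 for anything that does not match.
int ParseArchInt(const std::string &arch) {
  // rfind(prefix, 0) == 0 is a pre-C++20 way to test "starts with".
  if (arch.rfind("sm_", 0) != 0) {
    return 0;
  }
  try {
    // std::stoi consumes the leading digits and stops at the first
    // non-digit, so a suffix such as the 'a' in "sm_90a" is ignored.
    return std::stoi(arch.substr(3));
  } catch (const std::invalid_argument &) {
    // No digits after "sm_". Assumption: fall back to 0, as atoi would.
    return 0;
  } catch (const std::out_of_range &) {
    // Digits present but too large for int. Assumption: also fall back to 0.
    return 0;
  }
}
```

Under these assumptions, `ParseArchInt("sm_80")` yields 80, `ParseArchInt("sm_90a")` yields 90, and both `ParseArchInt("sm_")` and a non-matching string such as `"gfx942"` yield 0. Whether the real `GetArchInt` needs such a guard depends on which `arch` strings can reach it, which this diff does not show.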

src/op/atomic_add.h

Lines changed: 40 additions & 83 deletions
@@ -1,7 +1,6 @@
 /*!
  * \file tl/op/atomic_add.h
- * \brief Define atomic add operator.
- *
+ * \brief Atomic addition operations for concurrent memory updates
  */
 
 #ifndef TVM_TL_OP_ATOMIC_ADD_H_
@@ -10,91 +9,20 @@
 #include "operator.h"
 #include "parallel.h"
 
-/**
- * Lower this tile operator into a TIR statement for the given lowering context.
- *
- * @param T Lowering context containing mapped buffers and iteration
- * information.
- * @param analyzer Arithmetic analyzer used to simplify and reason about
- * expressions.
- * @return A TIR Stmt that implements the atomic-add tile operation for the
- * provided context.
- */
-/**
- * Infer memory/layout mapping for tensors and buffers used by this operator.
- *
- * @param T Layout inference context providing buffer and shape information.
- * @param level Inference aggressiveness level; higher levels may perform more
- * speculative decisions.
- * @return A LayoutMap describing inferred layouts for the operator's inputs and
- * outputs.
- */
-/**
- * Get the Op registration that identifies this tile operator.
- *
- * @return A reference to the registered Op representing this operator.
- */
-/**
- * Create a deep copy of this tile operator node wrapped as a TileOperator.
- *
- * @return A TileOperator handle owning a cloned AtomicAddNode.
- */
-/**
- * Construct a SIMT-style For loop nest (thread/block mapping) appropriate for
- * the operator.
- *
- * @param analyzer Arithmetic analyzer used to simplify loop bounds and
- * predicates.
- * @return A For loop node representing the SIMT-parallel loop structure.
- */
-/**
- * Create iteration variables used by this operator's loop nest.
- *
- * @return An array of IterVar objects describing the loop iteration axes.
- */
-/**
- * Produce index expressions for either source or destination buffer access
- * based on iteration vars.
- *
- * @param ivs IterVars created by MakeIterVars().
- * @param src_dst Selects which indices to produce: 0 for source indices, 1 for
- * destination indices.
- * @return An array of PrimExpr index expressions suitable for indexing the
- * selected buffer.
- */
-/**
- * Build a predicate expression that guards out-of-bounds or conditional
- * accesses for src or dst.
- *
- * @param analyzer Arithmetic analyzer used to simplify the predicate.
- * @param ivs IterVars created by MakeIterVars().
- * @param extents The loop extents corresponding to the itervars.
- * @param src_dst Selects which side the predicate is for: 0 for source, 1 for
- * destination.
- * @return A PrimExpr boolean predicate that evaluates to true for valid
- * iterations.
- */
-/**
- * Construct an AtomicAdd tile operator from operation arguments and a buffer
- * mapping.
- *
- * @param args Operation arguments (e.g., values or indices) specific to the
- * atomic-add semantics.
- * @param vmap Mapping from buffer names to Buffer objects used by this
- * operator.
- */
 namespace tvm {
 namespace tl {
 
 using namespace tir;
 
+/// Node class for atomic addition operations
 class AtomicAddNode : public TileOperatorNode {
 public:
-  Buffer src, dst;
-  Array<Range> src_range, dst_range;
-  IntImm coalesced_width;
+  Buffer src, dst; ///< Source and destination buffers
+  Array<Range> src_range,
+      dst_range; ///< Access ranges for source and destination
+  IntImm coalesced_width; ///< Width for memory coalescing optimization
 
-  mutable ParallelOp par_op_;
+  mutable ParallelOp par_op_; ///< Associated parallel operation
   static constexpr const char *_type_key = "tl.AtomicAdd";
   TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
 
@@ -104,18 +32,47 @@ class AtomicAddNode : public TileOperatorNode {
   static const Op &Get();
   TileOperator Clone() const;
 
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<AtomicAddNode>()
+        .def_ro("src", &AtomicAddNode::src)
+        .def_ro("dst", &AtomicAddNode::dst)
+        .def_ro("src_range", &AtomicAddNode::src_range)
+        .def_ro("dst_range", &AtomicAddNode::dst_range)
+        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
+  }
+
+  bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
+    return equal(src, other->src) && equal(dst, other->dst) &&
+           equal(src_range, other->src_range) &&
+           equal(dst_range, other->dst_range) &&
+           equal(coalesced_width, other->coalesced_width);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(src);
+    hash_reduce(dst);
+    hash_reduce(src_range);
+    hash_reduce(dst_range);
+    hash_reduce(coalesced_width);
+  }
+
+  static constexpr bool _type_has_method_sequal_reduce = true;
+  static constexpr bool _type_has_method_shash_reduce = true;
+
 protected:
+  /// Create SIMT-style parallel loop structure
   For MakeSIMTLoop(arith::Analyzer *analyzer) const;
+  /// Generate iteration variables for loop nest
   Array<IterVar> MakeIterVars() const;
-
-  // ivs: itervars returned by MakeIterVars()
-  // src_dst: 0 for src_indices, 1 for dst_indices
+  /// Generate buffer indices from iteration variables
   Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
-
+  /// Create boundary predicate for memory safety
   PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                          Array<PrimExpr> extents, int src_dst) const;
 };
 
+/// Wrapper class for atomic addition operations
 class AtomicAdd : public TileOperator {
 public:
   TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
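
`RegisterReflection` in the header above chains `def_ro(name, member_pointer)` calls to expose fields read-only by name. As a rough illustration of how a chained, name-keyed field registry of this shape can work, here is a toy sketch. It is not the `tvm::ffi::reflection` API: `MiniObjectDef`, `ToyAtomicAdd`, `num_dims`, and `MakeToyDef` are all hypothetical names invented for this example, and fields are limited to types accepted by `std::to_string`.

```cpp
#include <functional>
#include <map>
#include <string>

// Toy stand-in for a reflection builder (assumption: NOT the TVM FFI API).
// It records, per field name, a getter that renders the field as a string.
template <typename T> class MiniObjectDef {
public:
  // Register a read-only field; returns *this so calls can be chained.
  template <typename Field>
  MiniObjectDef &def_ro(const std::string &name, Field T::*member) {
    getters_[name] = [member](const T &obj) {
      return std::to_string(obj.*member);
    };
    return *this;
  }

  // Look a field up by name; throws std::out_of_range if it is unknown.
  std::string Get(const T &obj, const std::string &name) const {
    return getters_.at(name)(obj);
  }

private:
  std::map<std::string, std::function<std::string(const T &)>> getters_;
};

// Hypothetical node type with two plain integer fields.
struct ToyAtomicAdd {
  int coalesced_width;
  int num_dims;
};

// Build the field table in the same chained style as RegisterReflection().
MiniObjectDef<ToyAtomicAdd> MakeToyDef() {
  MiniObjectDef<ToyAtomicAdd> def;
  def.def_ro("coalesced_width", &ToyAtomicAdd::coalesced_width)
      .def_ro("num_dims", &ToyAtomicAdd::num_dims);
  return def;
}
```

With this toy, `MakeToyDef().Get(ToyAtomicAdd{4, 2}, "coalesced_width")` returns the string `"4"`. The real registry stores typed accessors in a global table so that Python can query them; the sketch only shows the builder pattern and name-based lookup.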

src/op/copy.cc

Lines changed: 6 additions & 2 deletions
@@ -297,7 +297,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
  */
 For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
-  bool is_scalar = loop_vars.size() == 0;
+  bool is_scalar = loop_vars.empty();
   if (is_scalar) {
     return For(Var("i"), 0, 1, ForKind::kSerial,
                BufferStore(dst, BufferLoad(src, {0}), {0}));
@@ -1197,7 +1197,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     int swizzle;
     int max_dim;
   };
-  static const SwizzleCheck swizzle_checks[] = {
+  static const std::vector<SwizzleCheck> swizzle_checks = {
       {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B), 32},
       {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B), 64},
       {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B), 128},
@@ -1559,5 +1559,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TVM_FFI_STATIC_INIT_BLOCK({
+  CopyNode::RegisterReflection();
+  Conv2DIm2ColOpNode::RegisterReflection();
+});
 } // namespace tl
 } // namespace tvm
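
Both `.cc` files end by calling `RegisterReflection()` inside `TVM_FFI_STATIC_INIT_BLOCK`, so registration happens before `main()` runs. The macro's actual expansion is not shown in this diff; the sketch below illustrates one common way to run code at static-initialization time, using a file-scope object whose constructor invokes a callable. `StaticInitRunner`, `RegisteredNames`, and `toy_init` are hypothetical names used only here.

```cpp
#include <string>
#include <vector>

// Hypothetical registry. A function-local static is constructed on first
// use, which avoids depending on initialization order across globals.
std::vector<std::string> &RegisteredNames() {
  static std::vector<std::string> names;
  return names;
}

// Runs the given callable from its constructor.
struct StaticInitRunner {
  template <typename F> explicit StaticInitRunner(F f) { f(); }
};

// File-scope object: its constructor, and therefore the lambda, runs during
// static initialization, before main() is entered.
static StaticInitRunner toy_init([] {
  RegisteredNames().push_back("CopyNode");
  RegisteredNames().push_back("Conv2DIm2ColOpNode");
});
```

By the time `main()` starts, `RegisteredNames()` in this sketch already holds both entries in insertion order. Grouping several registrations in one block, as the `copy.cc` hunk does, keeps them in a single initializer per translation unit.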
