-
Couldn't load subscription status.
- Fork 286
[Refactor] Support python reflection for tile operators #783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
313298e
bb52df6
23b63e2
7714cb0
2837b74
89c08aa
a5c2643
de43389
082ad17
9e3473d
362c19e
f2604a3
14d7b3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||||||||||||
| * \file tl/op/atomic_add.h | ||||||||||||||||||||||||||||||||||||||
| * \brief Define atomic add operator. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * \brief Atomic addition operations for concurrent memory updates | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #ifndef TVM_TL_OP_ATOMIC_ADD_H_ | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -10,91 +9,20 @@ | |||||||||||||||||||||||||||||||||||||
| #include "operator.h" | ||||||||||||||||||||||||||||||||||||||
| #include "parallel.h" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Lower this tile operator into a TIR statement for the given lowering context. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param T Lowering context containing mapped buffers and iteration | ||||||||||||||||||||||||||||||||||||||
| * information. | ||||||||||||||||||||||||||||||||||||||
| * @param analyzer Arithmetic analyzer used to simplify and reason about | ||||||||||||||||||||||||||||||||||||||
| * expressions. | ||||||||||||||||||||||||||||||||||||||
| * @return A TIR Stmt that implements the atomic-add tile operation for the | ||||||||||||||||||||||||||||||||||||||
| * provided context. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Infer memory/layout mapping for tensors and buffers used by this operator. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param T Layout inference context providing buffer and shape information. | ||||||||||||||||||||||||||||||||||||||
| * @param level Inference aggressiveness level; higher levels may perform more | ||||||||||||||||||||||||||||||||||||||
| * speculative decisions. | ||||||||||||||||||||||||||||||||||||||
| * @return A LayoutMap describing inferred layouts for the operator's inputs and | ||||||||||||||||||||||||||||||||||||||
| * outputs. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Get the Op registration that identifies this tile operator. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @return A reference to the registered Op representing this operator. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Create a deep copy of this tile operator node wrapped as a TileOperator. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @return A TileOperator handle owning a cloned AtomicAddNode. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Construct a SIMT-style For loop nest (thread/block mapping) appropriate for | ||||||||||||||||||||||||||||||||||||||
| * the operator. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param analyzer Arithmetic analyzer used to simplify loop bounds and | ||||||||||||||||||||||||||||||||||||||
| * predicates. | ||||||||||||||||||||||||||||||||||||||
| * @return A For loop node representing the SIMT-parallel loop structure. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Create iteration variables used by this operator's loop nest. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @return An array of IterVar objects describing the loop iteration axes. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Produce index expressions for either source or destination buffer access | ||||||||||||||||||||||||||||||||||||||
| * based on iteration vars. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param ivs IterVars created by MakeIterVars(). | ||||||||||||||||||||||||||||||||||||||
| * @param src_dst Selects which indices to produce: 0 for source indices, 1 for | ||||||||||||||||||||||||||||||||||||||
| * destination indices. | ||||||||||||||||||||||||||||||||||||||
| * @return An array of PrimExpr index expressions suitable for indexing the | ||||||||||||||||||||||||||||||||||||||
| * selected buffer. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Build a predicate expression that guards out-of-bounds or conditional | ||||||||||||||||||||||||||||||||||||||
| * accesses for src or dst. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param analyzer Arithmetic analyzer used to simplify the predicate. | ||||||||||||||||||||||||||||||||||||||
| * @param ivs IterVars created by MakeIterVars(). | ||||||||||||||||||||||||||||||||||||||
| * @param extents The loop extents corresponding to the itervars. | ||||||||||||||||||||||||||||||||||||||
| * @param src_dst Selects which side the predicate is for: 0 for source, 1 for | ||||||||||||||||||||||||||||||||||||||
| * destination. | ||||||||||||||||||||||||||||||||||||||
| * @return A PrimExpr boolean predicate that evaluates to true for valid | ||||||||||||||||||||||||||||||||||||||
| * iterations. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||
| * Construct an AtomicAdd tile operator from operation arguments and a buffer | ||||||||||||||||||||||||||||||||||||||
| * mapping. | ||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||
| * @param args Operation arguments (e.g., values or indices) specific to the | ||||||||||||||||||||||||||||||||||||||
| * atomic-add semantics. | ||||||||||||||||||||||||||||||||||||||
| * @param vmap Mapping from buffer names to Buffer objects used by this | ||||||||||||||||||||||||||||||||||||||
| * operator. | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| using namespace tir; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| /// Node class for atomic addition operations | ||||||||||||||||||||||||||||||||||||||
| class AtomicAddNode : public TileOperatorNode { | ||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||
| Buffer src, dst; | ||||||||||||||||||||||||||||||||||||||
| Array<Range> src_range, dst_range; | ||||||||||||||||||||||||||||||||||||||
| IntImm coalesced_width; | ||||||||||||||||||||||||||||||||||||||
| Buffer src, dst; ///< Source and destination buffers | ||||||||||||||||||||||||||||||||||||||
| Array<Range> src_range, | ||||||||||||||||||||||||||||||||||||||
| dst_range; ///< Access ranges for source and destination | ||||||||||||||||||||||||||||||||||||||
| IntImm coalesced_width; ///< Width for memory coalescing optimization | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| mutable ParallelOp par_op_; | ||||||||||||||||||||||||||||||||||||||
| mutable ParallelOp par_op_; ///< Associated parallel operation | ||||||||||||||||||||||||||||||||||||||
| static constexpr const char *_type_key = "tl.AtomicAdd"; | ||||||||||||||||||||||||||||||||||||||
| TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -104,18 +32,47 @@ class AtomicAddNode : public TileOperatorNode { | |||||||||||||||||||||||||||||||||||||
| static const Op &Get(); | ||||||||||||||||||||||||||||||||||||||
| TileOperator Clone() const; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| static void RegisterReflection() { | ||||||||||||||||||||||||||||||||||||||
| namespace refl = tvm::ffi::reflection; | ||||||||||||||||||||||||||||||||||||||
| refl::ObjectDef<AtomicAddNode>() | ||||||||||||||||||||||||||||||||||||||
| .def_ro("src", &AtomicAddNode::src) | ||||||||||||||||||||||||||||||||||||||
| .def_ro("dst", &AtomicAddNode::dst) | ||||||||||||||||||||||||||||||||||||||
| .def_ro("src_range", &AtomicAddNode::src_range) | ||||||||||||||||||||||||||||||||||||||
| .def_ro("dst_range", &AtomicAddNode::dst_range) | ||||||||||||||||||||||||||||||||||||||
| .def_ro("coalesced_width", &AtomicAddNode::coalesced_width); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+35
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainEnsure RegisterReflection() runs at static init. Please verify a TU calls AtomicAddNode::RegisterReflection() (and node type registration) during startup; otherwise Python reflection won’t work. 🏁 Script executed: #!/bin/bash
# Expect: at least one call to RegisterReflection() and node/op registrations.
rg -n -C2 -e 'AtomicAddNode::RegisterReflection\(' -e 'TVM_REGISTER_NODE_TYPE\(\s*AtomicAddNode\s*\)' -e 'TIR_REGISTER_TL_OP\(\s*AtomicAdd'Length of output: 663 Register node type for AtomicAddNode 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const { | ||||||||||||||||||||||||||||||||||||||
| return equal(src, other->src) && equal(dst, other->dst) && | ||||||||||||||||||||||||||||||||||||||
| equal(src_range, other->src_range) && | ||||||||||||||||||||||||||||||||||||||
| equal(dst_range, other->dst_range) && | ||||||||||||||||||||||||||||||||||||||
| equal(coalesced_width, other->coalesced_width); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| void SHashReduce(SHashReducer hash_reduce) const { | ||||||||||||||||||||||||||||||||||||||
| hash_reduce(src); | ||||||||||||||||||||||||||||||||||||||
| hash_reduce(dst); | ||||||||||||||||||||||||||||||||||||||
| hash_reduce(src_range); | ||||||||||||||||||||||||||||||||||||||
| hash_reduce(dst_range); | ||||||||||||||||||||||||||||||||||||||
| hash_reduce(coalesced_width); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| static constexpr bool _type_has_method_sequal_reduce = true; | ||||||||||||||||||||||||||||||||||||||
| static constexpr bool _type_has_method_shash_reduce = true; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| protected: | ||||||||||||||||||||||||||||||||||||||
| /// Create SIMT-style parallel loop structure | ||||||||||||||||||||||||||||||||||||||
| For MakeSIMTLoop(arith::Analyzer *analyzer) const; | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+64
to
65
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: scalar path in MakeSIMTLoop does a plain store, not an atomic add. In src/op/atomic_add.cc, the is_scalar branch emits BufferStore(dst, src) instead of an atomic add. This silently changes semantics. Patch suggestion (cc file): @@
- if (is_scalar) {
- return For(Var("i"), 0, 1, ForKind::kSerial,
- BufferStore(dst, BufferLoad(src, {0}), {0}));
- }
+ if (is_scalar) {
+ Array<PrimExpr> new_args;
+ new_args.push_back(StringImm("AtomicAdd"));
+ PrimExpr src_value = BufferLoad(src, {0});
+ if (src->dtype != dst->dtype) {
+ src_value = Cast(dst->dtype, src_value);
+ }
+ PrimExpr dst_value = BufferLoad(dst, {0});
+ Call address_of_value =
+ tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
+ new_args.push_back(address_of_value);
+ new_args.push_back(src_value);
+ Call atomicadd_call =
+ tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
+ return For(Var("i"), 0, 1, ForKind::kSerial, Evaluate(atomicadd_call));
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| /// Generate iteration variables for loop nest | ||||||||||||||||||||||||||||||||||||||
| Array<IterVar> MakeIterVars() const; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // ivs: itervars returned by MakeIterVars() | ||||||||||||||||||||||||||||||||||||||
| // src_dst: 0 for src_indices, 1 for dst_indices | ||||||||||||||||||||||||||||||||||||||
| /// Generate buffer indices from iteration variables | ||||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| /// Create boundary predicate for memory safety | ||||||||||||||||||||||||||||||||||||||
| PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs, | ||||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> extents, int src_dst) const; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| /// Wrapper class for atomic addition operations | ||||||||||||||||||||||||||||||||||||||
| class AtomicAdd : public TileOperator { | ||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||
| TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid std::stoi; handle “sm_90a” and prevent exceptions in GetArchInt
std::stoi throws on non-digit suffixes (e.g., valid CUDA arch strings like “sm_90a”) and changes behavior vs. lenient parsing. Parse the digit span only and use from_chars to avoid exceptions.
Apply this diff:
Add required headers (outside the hunk):