Skip to content

Commit e8138f7

Browse files
authored
[TIR] Remove ProducerConsumer and AllocateNode::new_expr (#5333)
* [TIR] Remove ProducerConsumer and AllocateNode::new_expr This PR removes two legacy IR parts in TIR that are deprecated. ProducerConsumer node only serves as a hint markup and may no longer be informative after extensive transformations in the pass. If necessary, we can add related info via AttrStmt. The new_expr field in the AllocateNode is deprecated since it can just be replaced by a LetStmt. - Remove dependencies of passes on ProducerConsumer. - Remove ProducerConsumer from the IR. - Remove the deprecated fields (new_expr, free_function) from AllocateNode. * Fix additional testcases
1 parent f143881 commit e8138f7

40 files changed

+196
-459
lines changed

include/tvm/tir/stmt.h

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -178,42 +178,6 @@ class AssertStmtNode : public StmtNode {
178178
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
179179
};
180180

181-
// TODO(tvm-team): consider consolidate with AttrStmt.
182-
/*! \brief annotation node of producer/consumer relation. */
183-
class ProducerConsumerNode : public StmtNode {
184-
public:
185-
/*! \brief The corresponding tensor. */
186-
FunctionRef func;
187-
/*! \brief Whether the relation is producer. */
188-
bool is_producer;
189-
/*! \brief Body to be executed. */
190-
Stmt body;
191-
192-
void VisitAttrs(AttrVisitor* v) {
193-
v->Visit("func", &func);
194-
v->Visit("is_producer", &is_producer);
195-
v->Visit("body", &body);
196-
}
197-
198-
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
199-
return
200-
equal(func, other->func) &&
201-
equal(is_producer, other->is_producer) &&
202-
equal(body, other->body);
203-
}
204-
205-
void SHashReduce(SHashReducer hash_reduce) const {
206-
hash_reduce(func);
207-
hash_reduce(is_producer);
208-
hash_reduce(body);
209-
}
210-
211-
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
212-
213-
static constexpr const char* _type_key = "ProducerConsumer";
214-
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
215-
};
216-
217181
/*!
218182
* \brief Store value to the buffer.
219183
*
@@ -385,10 +349,6 @@ class AllocateNode : public StmtNode {
385349
PrimExpr condition;
386350
/*! \brief The body to be executed. */
387351
Stmt body;
388-
// The following two fields are deprecated
389-
// kept for backward compatibility and will be refactored later.
390-
PrimExpr new_expr;
391-
std::string free_function;
392352

393353
void VisitAttrs(AttrVisitor* v) {
394354
v->Visit("buffer_var", &buffer_var);
@@ -419,9 +379,7 @@ class AllocateNode : public StmtNode {
419379
DataType dtype,
420380
Array<PrimExpr> extents,
421381
PrimExpr condition,
422-
Stmt body,
423-
PrimExpr new_expr = PrimExpr(),
424-
std::string free_function = std::string());
382+
Stmt body);
425383

426384
/*!
427385
* \brief If the buffer size is constant, return the size.
@@ -589,8 +547,6 @@ class SeqStmt : public Stmt {
589547
*
590548
* - When an argument is nullptr, it will be ignored.
591549
* - When an argument is an array or a SeqStmt, it will be flattened recursively.
592-
* - When an argument is a consumer block in ProducerConsumer, the consumer
593-
* tag will be dropped as such information is not useful in lowering.
594550
* - A normal Stmt will be appended to the end of the sequence.
595551
*
596552
* \note This function can directly return an element
@@ -618,13 +574,6 @@ class SeqStmt : public Stmt {
618574
if (!stmt.defined()) return;
619575
if (auto* op = stmt.as<SeqStmtNode>()) {
620576
operator()(0, op->seq);
621-
} else if (auto* op = stmt.as<ProducerConsumerNode>()) {
622-
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
623-
if (!op->is_producer) {
624-
operator()(0, op->body);
625-
} else {
626-
seq_->push_back(stmt);
627-
}
628577
} else {
629578
seq_->push_back(stmt);
630579
}

include/tvm/tir/stmt_functor.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
9494
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9595
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9696
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97-
virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9897
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9998
virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
10099
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
117116
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
118117
IR_STMT_FUNCTOR_DISPATCH(FreeNode);
119118
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
120-
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
121119
IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
122120
IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
123121
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
@@ -158,7 +156,6 @@ class TVM_DLL StmtVisitor :
158156
void VisitStmt_(const BufferStoreNode* op) override;
159157
void VisitStmt_(const FreeNode* op) override;
160158
void VisitStmt_(const AssertStmtNode* op) override;
161-
void VisitStmt_(const ProducerConsumerNode* op) override;
162159
void VisitStmt_(const ProvideNode* op) override;
163160
void VisitStmt_(const RealizeNode* op) override;
164161
void VisitStmt_(const PrefetchNode* op) override;
@@ -253,7 +250,6 @@ class TVM_DLL StmtMutator :
253250
Stmt VisitStmt_(const BufferStoreNode* op) override;
254251
Stmt VisitStmt_(const FreeNode* op) override;
255252
Stmt VisitStmt_(const AssertStmtNode* op) override;
256-
Stmt VisitStmt_(const ProducerConsumerNode* op) override;
257253
Stmt VisitStmt_(const ProvideNode* op) override;
258254
Stmt VisitStmt_(const RealizeNode* op) override;
259255
Stmt VisitStmt_(const PrefetchNode* op) override;

python/tvm/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
2828
from .expr import IterVar, Any
2929

30-
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
30+
from .stmt import Stmt, LetStmt, AssertStmt, For
3131
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
3232
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
3333

python/tvm/tir/stmt.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -76,26 +76,6 @@ def __init__(self, condition, message, body):
7676
_ffi_api.AssertStmt, condition, message, body)
7777

7878

79-
@tvm._ffi.register_object
80-
class ProducerConsumer(Stmt):
81-
"""ProducerConsumer node.
82-
83-
Parameters
84-
----------
85-
func : Operation
86-
The Operation.
87-
88-
is_producer : bool
89-
Whether if the node is producer.
90-
91-
body : Stmt
92-
The body statement.
93-
"""
94-
def __init__(self, func, is_producer, body):
95-
self.__init_handle_by_constructor__(
96-
_ffi_api.ProducerConsumer, func, is_producer, body)
97-
98-
9979
@tvm._ffi.register_object
10080
class For(Stmt):
10181
"""For node.
@@ -425,6 +405,4 @@ def stmt_list(stmt):
425405
for x in stmt:
426406
res += stmt_list(x)
427407
return res
428-
if isinstance(stmt, ProducerConsumer):
429-
return stmt_list(stmt.body)
430408
return [stmt]

src/contrib/hybrid/codegen_hybrid.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,6 @@ void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
399399
stream << str << "\n";
400400
}
401401

402-
void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
403-
PrintStmt(op->body);
404-
}
405-
406402
void CodeGenHybrid::PrintIndent() {
407403
stream << std::string(indent_, ' ');
408404
}

src/contrib/hybrid/codegen_hybrid.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ class CodeGenHybrid :
131131
void VisitStmt_(const AssertStmtNode* op) override;
132132
void VisitStmt_(const EvaluateNode* op) override;
133133
void VisitStmt_(const SeqStmtNode* op) override;
134-
void VisitStmt_(const ProducerConsumerNode* op) override;
135134
/*!
136135
* \brief Print Type represetnation of type t.
137136
* \param t The type representation.

src/target/llvm/codegen_amdgpu.cc

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM {
7171
void VisitStmt_(const AllocateNode* op) final {
7272
CHECK(!is_zero(op->condition));
7373
llvm::Value* buf = nullptr;
74-
if (op->new_expr.defined()) {
75-
CHECK_EQ(op->free_function, "nop");
76-
buf = MakeValue(op->new_expr);
77-
} else {
78-
int32_t constant_size = op->constant_allocation_size();
79-
CHECK_GT(constant_size, 0)
80-
<< "Can only handle constant size stack allocation in GPU";
81-
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
82-
if (constant_size % 4 == 0 && info.alignment == 0) {
83-
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
84-
}
85-
// maximum necessary alignment in the AMD devices
86-
if (info.alignment > 16) {
87-
info.alignment = 16;
88-
}
89-
if (info.scope.rank == runtime::StorageRank::kLocal) {
90-
// const int local_address_space = 5;
91-
// TODO(tqchen): for higher version of LLVM, local address space can be set.
92-
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
93-
return builder_->CreateAlloca(
94-
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
95-
});
96-
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
74+
75+
int32_t constant_size = op->constant_allocation_size();
76+
CHECK_GT(constant_size, 0)
77+
<< "Can only handle constant size stack allocation in GPU";
78+
79+
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
80+
if (constant_size % 4 == 0 && info.alignment == 0) {
81+
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
82+
}
83+
// maximum necessary alignment in the AMD devices
84+
if (info.alignment > 16) {
85+
info.alignment = 16;
86+
}
87+
if (info.scope.rank == runtime::StorageRank::kLocal) {
88+
// const int local_address_space = 5;
89+
// TODO(tqchen): for higher version of LLVM, local address space can be set.
90+
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
91+
return builder_->CreateAlloca(
92+
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
93+
});
94+
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
9795
#if TVM_LLVM_VERSION >= 100
98-
alloca->setAlignment(llvm::Align(info.alignment));
96+
alloca->setAlignment(llvm::Align(info.alignment));
9997
#else
100-
alloca->setAlignment(info.alignment);
98+
alloca->setAlignment(info.alignment);
10199
#endif
102-
}
103-
buf = alloca;
104-
} else {
105-
CHECK(info.scope.rank == runtime::StorageRank::kShared)
106-
<< "Can only allocate shared or local memory inside kernel";
107-
// Shared memory: address space == 3
108-
const unsigned shared_address_space = 3;
109-
llvm::Type* type = llvm::ArrayType::get(
110-
DTypeToLLVMType(op->dtype), constant_size);
111-
// Allocate shared memory in global, address_space = 3
112-
llvm::GlobalVariable *global = new llvm::GlobalVariable(
113-
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
114-
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
100+
}
101+
buf = alloca;
102+
} else {
103+
CHECK(info.scope.rank == runtime::StorageRank::kShared)
104+
<< "Can only allocate shared or local memory inside kernel";
105+
// Shared memory: address space == 3
106+
const unsigned shared_address_space = 3;
107+
llvm::Type* type = llvm::ArrayType::get(
108+
DTypeToLLVMType(op->dtype), constant_size);
109+
// Allocate shared memory in global, address_space = 3
110+
llvm::GlobalVariable *global = new llvm::GlobalVariable(
111+
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
112+
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
115113
#if TVM_LLVM_VERSION >= 100
116-
global->setAlignment(llvm::Align(info.alignment));
114+
global->setAlignment(llvm::Align(info.alignment));
117115
#else
118-
global->setAlignment(info.alignment);
116+
global->setAlignment(info.alignment);
119117
#endif
120-
buf = global;
121-
}
118+
buf = global;
122119
}
120+
123121
buf = builder_->CreatePointerCast(
124122
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
125123
buf->getType()->getPointerAddressSpace()));

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,10 +1268,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
12681268
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
12691269
CHECK(!is_zero(op->condition));
12701270
llvm::Value* buf = nullptr;
1271-
if (op->new_expr.defined()) {
1272-
CHECK_EQ(op->free_function, "nop");
1273-
buf = MakeValue(op->new_expr);
1274-
} else {
1271+
12751272
int32_t constant_size = op->constant_allocation_size();
12761273
CHECK_GT(constant_size, 0)
12771274
<< "Can only handle constant size stack allocation";
@@ -1296,7 +1293,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
12961293
}
12971294
info.alignment = alloca->getAlignment();
12981295
buf = alloca;
1299-
}
1296+
13001297
buf = builder_->CreatePointerCast(
13011298
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
13021299
buf->getType()->getPointerAddressSpace()));
@@ -1359,9 +1356,6 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
13591356
MakeValue(op->value);
13601357
}
13611358

1362-
void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
1363-
this->VisitStmt(op->body);
1364-
}
13651359
} // namespace codegen
13661360
} // namespace tvm
13671361
#endif // TVM_LLVM_VERSION

src/target/llvm/codegen_llvm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ class CodeGenLLVM :
150150
void VisitStmt_(const LetStmtNode* op) override;
151151
void VisitStmt_(const SeqStmtNode* op) override;
152152
void VisitStmt_(const EvaluateNode* op) override;
153-
void VisitStmt_(const ProducerConsumerNode* op) override;
154153

155154
protected:
156155
/*! \brief The storage information */

0 commit comments

Comments
 (0)