Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 1 addition & 52 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,42 +178,6 @@ class AssertStmtNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};

// TODO(tvm-team): consider consolidating with AttrStmt.
/*!
 * \brief Annotation node of a producer/consumer relation.
 *
 * Wraps a body statement and tags it with the function it relates to,
 * plus a flag telling whether the relation is the producer side.
 */
class ProducerConsumerNode : public StmtNode {
public:
/*! \brief The corresponding tensor. */
FunctionRef func;
/*! \brief Whether the relation is producer (false means consumer). */
bool is_producer;
/*! \brief Body to be executed. */
Stmt body;

/*!
 * \brief Expose the fields to an attribute visitor.
 * \param v The visitor; it receives "func", "is_producer" and "body", in that order.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("is_producer", &is_producer);
v->Visit("body", &body);
}

/*!
 * \brief Structural equality against another node of the same type.
 * \param other The node to compare with.
 * \param equal The reducer used to compare each field.
 * \return true only when func, is_producer and body all compare equal.
 */
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(is_producer, other->is_producer) &&
equal(body, other->body);
}

/*!
 * \brief Structural hash: feeds func, is_producer and body to the reducer, in that order.
 * \param hash_reduce The reducer that accumulates the hash.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(is_producer);
hash_reduce(body);
}

/*!
 * \brief Factory returning a Stmt; the definition is not in this header.
 * NOTE(review): presumably builds a ProducerConsumerNode from the three arguments -- confirm.
 */
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);

/*! \brief Type key string of this node type. */
static constexpr const char* _type_key = "ProducerConsumer";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
};

/*!
* \brief Store value to the buffer.
*
Expand Down Expand Up @@ -385,10 +349,6 @@ class AllocateNode : public StmtNode {
PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
// The following two fields are deprecated
// kept for backward compatibility and will be refactored later.
PrimExpr new_expr;
std::string free_function;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
Expand Down Expand Up @@ -419,9 +379,7 @@ class AllocateNode : public StmtNode {
DataType dtype,
Array<PrimExpr> extents,
PrimExpr condition,
Stmt body,
PrimExpr new_expr = PrimExpr(),
std::string free_function = std::string());
Stmt body);

/*!
* \brief If the buffer size is constant, return the size.
Expand Down Expand Up @@ -589,8 +547,6 @@ class SeqStmt : public Stmt {
*
* - When an argument is nullptr, it will be ignored.
* - When an argument is an array or a SeqStmt, it will be flattened recursively.
* - When an argument is a consumer block in ProducerConsumer, the consumer
* tag will be dropped as such information is not useful in lowering.
* - A normal Stmt will be appended to the end of the sequence.
*
* \note This function can directly return an element
Expand Down Expand Up @@ -618,13 +574,6 @@ class SeqStmt : public Stmt {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else if (auto* op = stmt.as<ProducerConsumerNode>()) {
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
if (!op->is_producer) {
operator()(0, op->body);
} else {
seq_->push_back(stmt);
}
} else {
seq_->push_back(stmt);
}
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(FreeNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
Expand Down Expand Up @@ -158,7 +156,6 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
void VisitStmt_(const RealizeNode* op) override;
void VisitStmt_(const PrefetchNode* op) override;
Expand Down Expand Up @@ -253,7 +250,6 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerConsumerNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
Stmt VisitStmt_(const RealizeNode* op) override;
Stmt VisitStmt_(const PrefetchNode* op) override;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list

Expand Down
22 changes: 0 additions & 22 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,6 @@ def __init__(self, condition, message, body):
_ffi_api.AssertStmt, condition, message, body)


@tvm._ffi.register_object
class ProducerConsumer(Stmt):
    """ProducerConsumer node.

    Parameters
    ----------
    func : Operation
        The Operation.

    is_producer : bool
        Whether the node is a producer.

    body : Stmt
        The body statement.
    """
    def __init__(self, func, is_producer, body):
        # Build the underlying object through the FFI constructor.
        self.__init_handle_by_constructor__(
            _ffi_api.ProducerConsumer, func, is_producer, body)


@tvm._ffi.register_object
class For(Stmt):
"""For node.
Expand Down Expand Up @@ -425,6 +405,4 @@ def stmt_list(stmt):
for x in stmt:
res += stmt_list(x)
return res
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
4 changes: 0 additions & 4 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,6 @@ void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
stream << str << "\n";
}

void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
  // The annotation itself prints nothing; only the wrapped body is printed.
  const auto& wrapped_body = op->body;
  PrintStmt(wrapped_body);
}

void CodeGenHybrid::PrintIndent() {
  // Write indent_ space characters to the output stream.
  const std::string padding(indent_, ' ');
  stream << padding;
}
Expand Down
1 change: 0 additions & 1 deletion src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ class CodeGenHybrid :
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* \brief Print Type representation of type t.
* \param t The type representation.
Expand Down
82 changes: 40 additions & 42 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM {
void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the AMD devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {

int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";

StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the AMD devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
alloca->setAlignment(llvm::Align(info.alignment));
alloca->setAlignment(llvm::Align(info.alignment));
#else
alloca->setAlignment(info.alignment);
alloca->setAlignment(info.alignment);
#endif
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
global->setAlignment(llvm::Align(info.alignment));
#else
global->setAlignment(info.alignment);
global->setAlignment(info.alignment);
#endif
buf = global;
}
buf = global;
}

buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
Expand Down
10 changes: 2 additions & 8 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {

int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation";
Expand All @@ -1296,7 +1293,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
}
info.alignment = alloca->getAlignment();
buf = alloca;
}

buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
Expand Down Expand Up @@ -1359,9 +1356,6 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
MakeValue(op->value);
}

void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
  // The annotation is transparent here: visit the wrapped body and nothing else.
  const auto& wrapped_body = op->body;
  this->VisitStmt(wrapped_body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
1 change: 0 additions & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ class CodeGenLLVM :
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;

protected:
/*! \brief The storage information */
Expand Down
Loading