Skip to content

Commit

Permalink
Add Namer (apache#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and Hzfengsy committed Jul 27, 2022
1 parent 8bc748b commit 3c7f908
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 26 deletions.
36 changes: 22 additions & 14 deletions src/script/builder/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,6 @@ void Builder::ExitWithScope() {
std::vector<Builder>* stack = ThreadLocalBuilderStack();
ICHECK(!stack->empty());
stack->pop_back();
// IRModuleFrame frame = Downcast<IRModuleFrame>(n->frames.back());
// n->frames.pop_back();
// if (!frame->stmts.empty()) {
// ICHECK(frame->global_vars.empty());
// ICHECK(frame->functions.empty());
// n->result = frame->stmts;
// } else {
// Map<GlobalVar, BaseFunc> func_map;
// ICHECK_EQ(frame->functions.size(), frame->global_vars.size());
// int m = frame->functions.size();
// for (int i = 0; i < m; ++i) {
// func_map.Set(frame->global_vars[i], frame->functions[i]);
// }
// }
}

Builder Builder::Current() {
Expand All @@ -70,6 +56,28 @@ Builder Builder::Current() {
return stack->back();
}

Namer::FType& Namer::vtable() {
static FType inst;
return inst;
}

void Namer::Name(ObjectRef node, String name) {
static const FType& f = vtable();
CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
<< node->GetTypeKey();
f(node, name);
}

namespace details {

ObjectRef DefImpl(String name, ObjectRef obj) {
Namer::Name(obj, name);
return obj;
}

} // namespace details

TVM_REGISTER_NODE_TYPE(BuilderNode);

} // namespace builder
Expand Down
20 changes: 20 additions & 0 deletions src/script/builder/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ class Builder : public runtime::ObjectRef {
static Builder Current();
};

template <class TObjectRef>
inline TObjectRef Def(String name, TObjectRef obj);

namespace details {
ObjectRef DefImpl(String name, ObjectRef obj);
}

class Namer {
public:
using FType = NodeFunctor<void(const ObjectRef&, String)>;
static FType& vtable();

static void Name(ObjectRef node, String name);
};

template <class TObjectRef>
inline TObjectRef Def(String name, TObjectRef obj) {
return Downcast<TObjectRef>(details::DefImpl(name, obj));
}

template <typename TFrame>
inline Optional<TFrame> BuilderNode::FindFrame() const {
using TFrameNode = typename TFrame::ContainerType;
Expand Down
15 changes: 15 additions & 0 deletions src/script/builder/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir/module.h>

#include "./builder.h"

namespace tvm {
Expand Down Expand Up @@ -44,7 +46,20 @@ IRModuleFrame::IRModuleFrame() {
data_ = std::move(n);
}

void IRModuleFrameNode::ExitWithScope() {
ICHECK_EQ(functions.size(), global_vars.size());
int n = functions.size();
Map<GlobalVar, BaseFunc> func_map;
for (int i = 0; i < n; ++i) {
func_map.Set(global_vars[i], functions[i]);
}
Builder builder = Builder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = tvm::IRModule(func_map);
}

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);

} // namespace builder
} // namespace script
Expand Down
3 changes: 3 additions & 0 deletions src/script/builder/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class IRModuleFrameNode : public FrameNode {

static constexpr const char* _type_key = "script.builder.IRModuleFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode);

public:
void ExitWithScope() final;
};

class IRModuleFrame : public Frame {
Expand Down
16 changes: 8 additions & 8 deletions src/script/builder/tir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ void TestPOC() {
With<Builder> builder;
{
With<PrimFuncFrame> _{T::PrimFunc_("main")};
Buffer A = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32)));
Buffer B = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32)));
Buffer A = T::Arg("A", T::Buffer_({128, 128, 128}, DataType::Float(32)));
Buffer B = T::Arg("B", T::Buffer_({128, 128, 128}, DataType::Float(32)));
{
With<ForFrame> _{T::Grid({128, 128, 128})};
Var i = _()->vars[0];
Var j = _()->vars[1];
Var k = _()->vars[2];
Var i = Def("i", _()->vars[0]);
Var j = Def("j", _()->vars[1]);
Var k = Def("k", _()->vars[2]);
{
With<BlockFrame> _{T::Block_("block")};
IterVar vi = T::axis::Spatial(Range(0, 128), i);
IterVar vj = T::axis::Spatial(Range(0, 128), j);
IterVar vk = T::axis::Reduce(Range(0, 128), k);
IterVar vi = Def("vi", T::axis::Spatial(Range(0, 128), i));
IterVar vj = Def("vj", T::axis::Spatial(Range(0, 128), j));
IterVar vk = Def("vk", T::axis::Reduce(Range(0, 128), k));
}
LOG(INFO) << "ForFrame:\n" << _()->stmts;
}
Expand Down
6 changes: 4 additions & 2 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,16 @@ PrimFuncFrame PrimFunc_(String name) {
return PrimFuncFrame(n);
}

tvm::tir::Var Arg(tvm::tir::Var var) {
tvm::tir::Var Arg(String name, tvm::tir::Var var) {
Namer::Name(var, name);
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
frame->args.push_back(var);
return var;
}

tvm::tir::Buffer Arg(tvm::tir::Buffer buffer) {
tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
using namespace tvm::tir;
Namer::Name(buffer, name);
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
Var handle(buffer->name + "_handle", DataType::Handle());
frame->args.push_back(handle);
Expand Down
4 changes: 2 additions & 2 deletions src/script/builder/tir/prim_func_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class PrimFuncFrame : public TIRFrame {
};

PrimFuncFrame PrimFunc_(String name);
tvm::tir::Var Arg(tvm::tir::Var var);
tvm::tir::Buffer Arg(tvm::tir::Buffer buffer);
tvm::tir::Var Arg(String name, tvm::tir::Var var);
tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer);

} // namespace tir
} // namespace builder
Expand Down
36 changes: 36 additions & 0 deletions src/script/builder/tir/var.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,42 @@ tvm::tir::Buffer Buffer_(Array<PrimExpr> shape, DataType dtype, String name, Str
return tvm::tir::decl_buffer(shape, dtype, name, storage_scope);
}

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
.set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
using namespace tvm::tir;
BufferNode* buffer = const_cast<BufferNode*>(node.as<BufferNode>());
buffer->name = name;
Namer::Name(buffer->data, name + "_data");
int n = buffer->strides.size();
for (int i = 0; i < n; ++i) {
PrimExpr e = buffer->strides[i];
if (const VarNode* v = e.as<VarNode>()) {
Namer::Name(GetRef<Var>(v), name + "_s" + std::to_string(i));
}
}
});

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
.set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
using namespace tvm::tir;
SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
var->name_hint = name;
});

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
.set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
using namespace tvm::tir;
VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
var->name_hint = name;
});

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
.set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
using namespace tvm::tir;
IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
Namer::Name(var->var, name);
});

} // namespace tir
} // namespace builder
} // namespace script
Expand Down

0 comments on commit 3c7f908

Please sign in to comment.