Skip to content

Commit

Permalink
[REFACTOR] Establish printer in the source folder (#4752)
Browse files Browse the repository at this point in the history
* [REFACTOR] Establish printer in the source folder.

As we move towards the unified IR, we will eventually want to build a unified
printers for both relay and TIR.

This PR isolate the printer component into a separate folder in src as a first step.

- Refactored the Doc DSL using Object, clean up APIs.
- Isolate out the meta data into a header.
- move printer into relay_text_printer, add comments about further TODos.

* Rename NodePrinter -> ReprPrinter to distinguish it from other printers
  • Loading branch information
tqchen authored Jan 21, 2020
1 parent f8f75ca commit e4d817d
Show file tree
Hide file tree
Showing 51 changed files with 901 additions and 740 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
src/api/*.cc
)

Expand Down
4 changes: 2 additions & 2 deletions apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _GetContext(debugger):
def PrettyPrint(debugger, command, result, internal_dict):
ctx = _GetContext(debugger)
rc = ctx.EvaluateExpression(
"tvm::relay::PrettyPrint({command})".format(command=command)
"tvm::PrettyPrint({command})".format(command=command)
)
result.AppendMessage(str(rc))

Expand Down Expand Up @@ -175,7 +175,7 @@ def _EvalExpressionAsString(logger, ctx, expr):

def _EvalAsNodeRef(logger, ctx, value):
return _EvalExpressionAsString(
logger, ctx, "tvm::relay::PrettyPrint({name})".format(name=value.name)
logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name)
)


Expand Down
28 changes: 28 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,5 +308,33 @@ class IRModule : public ObjectRef {
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};

/*!
* \brief Pretty print a node for debug purposes.
*
* \param node The node to be printed.
* \return The text reperesentation.
* \note This function does not show version or meta-data.
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the text format.
*
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
*
* \note We support a limited set of IR nodes that are part of
* relay IR and
*
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
14 changes: 7 additions & 7 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of NodePrinter.
* // interface of ReprPrinter.
*
* class NodePrinter {
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
Expand All @@ -152,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>

#include <string>
#include <vector>
Expand Down
16 changes: 8 additions & 8 deletions include/tvm/node/printer.h → include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_

#include <tvm/node/functor.h>
#include <iostream>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit NodePrinter(std::ostream& stream) // NOLINT(*)
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};

Expand All @@ -60,9 +60,9 @@ namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
ReprPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
#endif // TVM_NODE_REPR_PRINTER_H_
16 changes: 2 additions & 14 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
#include "./base.h"
Expand All @@ -40,6 +41,7 @@ using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
using tvm::PrettyPrint;

/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
Expand Down Expand Up @@ -539,20 +541,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};

/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
Expand Down
4 changes: 2 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
Expand Down
4 changes: 2 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -813,8 +813,8 @@ IntSet EvalSet(Range r,

TVM_REGISTER_NODE_TYPE(IntervalSetNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
Expand Down
8 changes: 4 additions & 4 deletions src/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor")
return Constructor(name_hint, inputs, belong_to);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
Expand All @@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData")
return TypeData(header, type_vars, constructors);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
Expand Down
4 changes: 2 additions & 2 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/env_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
Expand Down
2 changes: 1 addition & 1 deletion src/ir/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
CHECK_NE(it->second.size(), 0);
Expand Down
28 changes: 14 additions & 14 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm")

TVM_REGISTER_NODE_TYPE(IntImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
Expand All @@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm")

TVM_REGISTER_NODE_TYPE(FloatImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
Expand Down Expand Up @@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
Expand All @@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
return GlobalVar(name);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});

// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
Expand All @@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ']';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand All @@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '}';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand Down
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
Expand Down
Loading

0 comments on commit e4d817d

Please sign in to comment.