Skip to content

Commit 675f348

Browse files
authored
Better host handling in CompilationConfig & debug printing (#9460)
(This is a bit of a grab bag in preparation for #9326 which I'm trying to minimize) While switching the device planner to use SEScopes I had a lot of trouble with Target's not matching up. - If no explicit host target is given but the given TargetMap has targets with hosts, try to use those to establish the host_target. - Make sure both the 'legacy' TargetMap representation and the newer representation agree to pointer equality on their targets. - Make sure the Interpreter uses the target from CompilationConfig since it's been normalized. To debug the above: - When in pretty printing with show_meta_data_ false give as much detail on SEScopes, Targets and call attributes as possible. That needed some rework in the relay_text_printer.cc. - Ditto for critical 'target' attribute on PrimFuncs. - Also added a Target::ToDebugString so I could see the host fields along with everything else since a lot of problems were caused by a mismatch of 'the same' Target with and without a host. (Tried using that for the ReprPrinter but broken unit tests.) Note that the codebase assumes Targets are compared by ObjectPtrEquality, yet CheckAndUpdateHostConsistency (I count 65 call sites) changes the targets. Ultimately CompilationConfig or it's ultimate replacement should ensure we munge targets only once at the 'main' entry points.
1 parent 6549f47 commit 675f348

File tree

12 files changed

+340
-134
lines changed

12 files changed

+340
-134
lines changed

include/tvm/target/target.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ class TargetNode : public Object {
6666
/*! \return The Optional<Target> typed target host of the TargetNode */
6767
TVM_DLL Optional<Target> GetHost() const;
6868

69+
/*!
70+
* \brief Returns a human readable representation of \p Target which includes all fields,
71+
* especially the host. Useful for diagnostic messages and debugging.
72+
*
73+
* TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently
74+
* code depends on str() and << being the same.
75+
*/
76+
String ToDebugString() const;
77+
6978
void VisitAttrs(AttrVisitor* v) {
7079
v->Visit("kind", &kind);
7180
v->Visit("tag", &tag);

src/parser/parser.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
19551955
TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
19561956
return CreateModulePass(
19571957
[](const IRModule& mod, const PassContext& ctx) {
1958-
auto text = AsText(mod, true);
1958+
String text = AsText(mod, /*show_meta_data=*/true);
1959+
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
19591960
return ParseModule("GeneratedSource", text);
19601961
},
19611962
0, "AnnotateSpans", {});

src/printer/relay_text_printer.cc

Lines changed: 117 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include <tvm/relay/attrs/annotation.h>
3838
#include <tvm/relay/expr_functor.h>
3939
#include <tvm/relay/pattern_functor.h>
40+
#include <tvm/target/se_scope.h>
4041
#include <tvm/tir/function.h>
4142

4243
#include "../ir/attr_functor.h"
@@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
120121
return PrintPattern(Downcast<Pattern>(node), meta);
121122
} else if (node.as<IRModuleNode>()) {
122123
return PrintMod(Downcast<IRModule>(node));
123-
} else if (!show_meta_data_ && node.as<BaseAttrsNode>()) {
124-
// Show attributes in readable form.
125-
return PrintAttrs(Downcast<Attrs>(node));
126124
} else {
127125
// default module.
128126
std::ostringstream os;
@@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
444442
for (Var param : fn->params) {
445443
params.push_back(AllocVar(param));
446444
}
447-
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
445+
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
448446
params.push_back(d);
449447
}
450448
doc << Doc::Concat(params) << ") ";
@@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
684682
Doc doc;
685683
doc << "Tensor[(";
686684
std::vector<Doc> shapes;
687-
for (ObjectRef shape : node->shape) {
688-
shapes.push_back(PrintAttr(shape));
685+
for (const PrimExpr& prim_expr : node->shape) {
686+
// Though not bound within an attribute the attribute visitor will handle the PrimExprs we
687+
// care about.
688+
shapes.push_back(PrintAttributeValue(prim_expr));
689689
}
690690
doc << Doc::Concat(shapes);
691691
return doc << "), " << PrintDType(node->dtype) << "]";
@@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
766766
// Overload of Attr printing functions
767767
//------------------------------------
768768

769-
Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
770-
if (value.defined()) {
771-
Doc printed_attr;
772-
if (value.as<tvm::tir::AnyNode>()) {
773-
printed_attr << "?";
774-
} else if (auto str_obj = value.as<tvm::StringObj>()) {
775-
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
776-
} else if (meta) {
777-
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
778-
} else {
779-
printed_attr = VisitAttr(value);
780-
}
781-
return printed_attr;
782-
} else {
783-
return Doc::Text("None");
784-
}
785-
}
786-
787769
Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
788-
return PrintAttr(GetRef<ObjectRef>(op), /*meta=*/true);
770+
// Since we don't have any overload for a specific attribute type we'll need to force
771+
// the meta[...] representation to avoid infinite regress.
772+
return PrintAttributeValue(GetRef<ObjectRef>(op), /*force_meta=*/true);
789773
}
790774

791775
Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
792776
Doc doc;
793777
doc << "[";
794778
std::vector<Doc> arr_vals;
795-
for (auto val : *op) {
796-
arr_vals.push_back(PrintAttr(val));
779+
for (const auto& val : *op) {
780+
arr_vals.push_back(PrintAttributeValue(val));
797781
}
798782
doc << Doc::Concat(arr_vals);
799783
doc << "]";
@@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
831815
doc << key << "=" << *value << "f";
832816
docs->push_back(doc);
833817
}
818+
834819
void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
835820
void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
836821
void Visit(const char* key, int* value) final { PrintKV(key, *value); }
@@ -844,58 +829,134 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
844829
LOG(FATAL) << "do not allow NDarray as argument";
845830
}
846831
void Visit(const char* key, runtime::ObjectRef* obj) final {
847-
PrintKV(key, parent_->PrintAttr(*obj));
832+
PrintKV(key, parent_->PrintAttributeValue(*obj));
848833
}
849834

850835
private:
851836
std::vector<Doc>* docs;
852837
RelayTextPrinter* parent_;
853838
};
854839

855-
Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) {
856-
std::vector<Doc> docs;
857-
AttrPrinter printer(&docs, this);
858-
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
859-
Doc doc;
860-
doc << "{" << Doc::Concat(docs) << "}";
861-
862-
return doc;
840+
void RelayTextPrinter::AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs,
841+
bool include_type_key) {
842+
if (!attrs.defined()) {
843+
return;
844+
}
845+
AttrPrinter printer(docs, this);
846+
// Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
847+
// case we are read-only.
848+
const_cast<BaseAttrsNode*>(attrs.get())->VisitNonDefaultAttrs(&printer);
849+
if (include_type_key) {
850+
std::string s = attrs->GetTypeKey();
851+
printer.Visit("attrs_type_key", &s);
852+
}
863853
}
864854

865855
std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
866856
std::vector<Doc> docs;
867-
if (!attrs.defined()) return docs;
857+
if (!attrs.defined()) {
858+
return docs;
859+
}
868860
const auto* op_node = op.as<OpNode>();
869861
if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) {
870-
// fallback
862+
// The parser can only understand calls with attributes if they match the operator's
863+
// declared attribute type. If that's not the case fall back to the meta[...] representation.
864+
docs.push_back(meta_->GetMetaNode(attrs));
865+
} else {
866+
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node);
867+
}
868+
return docs;
869+
}
870+
871+
std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) {
872+
if (!dict_attrs.defined()) {
873+
return {};
874+
}
875+
return PrintDictAttrs(dict_attrs->dict);
876+
}
877+
878+
std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs) {
879+
std::vector<Doc> docs;
880+
if (!dict_attrs.defined()) {
881+
return docs;
882+
}
883+
for (const auto& k : dict_attrs) {
871884
Doc doc;
872-
doc << meta_->GetMetaNode(attrs);
885+
doc << k.first << "=" << PrintAttributeValue(k.second);
873886
docs.push_back(doc);
874-
return docs;
875-
} else {
876-
// Show attributes in readable form.
877-
AttrPrinter printer(&docs, this);
878-
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
879-
if (!op_node) {
880-
// print call attr type key to restore expr for relay parser
881-
std::string s = std::string(attrs->GetTypeKey());
882-
printer.Visit("attrs_type_key", &s);
887+
}
888+
return docs;
889+
}
890+
891+
Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) {
892+
if (value.defined()) {
893+
Doc printed_attr;
894+
if (value.as<tvm::tir::AnyNode>()) {
895+
printed_attr << "?";
896+
} else if (auto str_obj = value.as<tvm::StringObj>()) {
897+
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
898+
} else if (force_meta) {
899+
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
900+
} else if (const auto* se_scope_node = value.as<SEScopeNode>()) {
901+
if (show_meta_data_) {
902+
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(se_scope_node));
903+
} else {
904+
// Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
905+
// debugging.
906+
std::ostringstream os;
907+
os << GetRef<SEScope>(se_scope_node);
908+
return Doc::Text(os.str());
909+
}
910+
} else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) {
911+
if (show_meta_data_) {
912+
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_attr_node));
913+
} else {
914+
// Special case: The non-meta form for attributes are much easier to work with while
915+
// debugging.
916+
printed_attr = PrintAttrsAsAttributeValue(GetRef<Attrs>(base_attr_node));
917+
}
918+
} else if (const auto* base_map_node = value.as<MapNode>()) {
919+
if (show_meta_data_) {
920+
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_map_node));
921+
} else {
922+
// Special case: Show maps fields as key=value pairs to help debugging.
923+
printed_attr << PrintMapAsAttributeValue(GetRef<Map<ObjectRef, ObjectRef>>(base_map_node));
924+
}
925+
} else if (const auto* global_var_node = value.as<GlobalVarNode>()) {
926+
if (show_meta_data_) {
927+
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(global_var_node));
928+
} else {
929+
printed_attr << "'" << global_var_node->name_hint << "'";
930+
}
931+
} else {
932+
printed_attr = VisitAttr(value);
883933
}
884-
return docs;
934+
return printed_attr;
935+
} else {
936+
return Doc::Text("None");
885937
}
886938
}
887939

888-
std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
940+
Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) {
889941
std::vector<Doc> docs;
890-
if (!attrs.defined()) return docs;
891-
const auto* dict_attrs = attrs.as<DictAttrsNode>();
892-
ICHECK(dict_attrs);
893-
for (const auto& k : dict_attrs->dict) {
942+
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false);
943+
Doc doc;
944+
doc << "{" << Doc::Concat(docs) << "}";
945+
return doc;
946+
}
947+
948+
Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map) {
949+
std::vector<Doc> docs;
950+
for (const auto& k : map) {
894951
Doc doc;
895-
doc << k.first << "=" << Print(k.second);
952+
doc << PrintAttributeValue(k.first);
953+
doc << "=";
954+
doc << PrintAttributeValue(k.second);
896955
docs.push_back(doc);
897956
}
898-
return docs;
957+
Doc doc;
958+
doc << "{" << Doc::Concat(docs) << "}";
959+
return doc;
899960
}
900961

901962
Doc RelayTextPrinter::PrintSpan(const Span& span) {

src/printer/text_printer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
5858
os << "def @" << kv.first->name_hint;
5959
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
6060
} else if (kv.second.as<tir::PrimFuncNode>()) {
61+
doc << "@" << kv.first->name_hint << " = ";
6162
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
6263
}
6364
doc << Doc::NewLine();

src/printer/text_printer.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,42 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
7777
// numbers to be reused and prevents hoisted vars from escaping too far
7878
Doc PrintScope(const ObjectRef& node);
7979
Doc PrintFinal(const ObjectRef& node);
80-
Doc PrintAttrs(const Attrs& attrs);
80+
81+
/*!
82+
* \brief Returns \p attrs printed using the generic attribute visitor, as a sequence
83+
* of key=value entries, if any.
84+
*/
85+
void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key);
86+
87+
/*!
88+
* \brief Returns \p attrs printed as a sequence of key=value entries, if any.
89+
* This is used for call attributes.
90+
*/
8191
std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
82-
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
92+
93+
/*!
94+
* \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any.
95+
* This is used for function definition attributes.
96+
*/
97+
std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs);
98+
std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs);
99+
100+
/*!
101+
* \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta
102+
* is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag.
103+
*/
104+
Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false);
105+
106+
/*!
107+
* \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces.
108+
*/
109+
Doc PrintAttrsAsAttributeValue(const Attrs& attrs);
110+
111+
/*!
112+
* \brief Returns \p map printed as a self-contained value, ie wrapped in braces.
113+
*/
114+
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
115+
83116
Doc PrintSpan(const Span& span);
84117

85118
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
@@ -162,7 +195,6 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
162195
//------------------------------------
163196
// Overload of Attr printing functions
164197
//------------------------------------
165-
Doc PrintAttr(const ObjectRef& value, bool meta = false);
166198
Doc VisitAttrDefault_(const Object* op) final;
167199
Doc VisitAttr_(const ArrayNode* op) final;
168200
Doc VisitAttr_(const tir::IntImmNode* op) final;

src/printer/tir_text_printer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/ir/type.h>
2828
#include <tvm/ir/type_functor.h>
2929
#include <tvm/node/serialization.h>
30+
#include <tvm/target/target.h>
3031
#include <tvm/tir/expr.h>
3132
#include <tvm/tir/function.h>
3233
#include <tvm/tir/op.h>
@@ -71,6 +72,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
7172
return PrintString(node.as<StringObj>());
7273
} else if (node->IsInstance<BufferRegionNode>()) {
7374
return PrintBufferRegion(node.as<BufferRegionNode>());
75+
} else if (node->IsInstance<TargetNode>()) {
76+
return Doc::Text(node.as<TargetNode>()->ToDebugString());
7477
} else {
7578
return this->meta_->GetMetaNode(node);
7679
}

0 commit comments

Comments
 (0)