37
37
#include < tvm/relay/attrs/annotation.h>
38
38
#include < tvm/relay/expr_functor.h>
39
39
#include < tvm/relay/pattern_functor.h>
40
+ #include < tvm/target/se_scope.h>
40
41
#include < tvm/tir/function.h>
41
42
42
43
#include " ../ir/attr_functor.h"
@@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
120
121
return PrintPattern (Downcast<Pattern>(node), meta);
121
122
} else if (node.as <IRModuleNode>()) {
122
123
return PrintMod (Downcast<IRModule>(node));
123
- } else if (!show_meta_data_ && node.as <BaseAttrsNode>()) {
124
- // Show attributes in readable form.
125
- return PrintAttrs (Downcast<Attrs>(node));
126
124
} else {
127
125
// default module.
128
126
std::ostringstream os;
@@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
444
442
for (Var param : fn->params ) {
445
443
params.push_back (AllocVar (param));
446
444
}
447
- for (const Doc& d : PrintFuncAttrs (fn->attrs )) {
445
+ for (const Doc& d : PrintDictAttrs (fn->attrs )) {
448
446
params.push_back (d);
449
447
}
450
448
doc << Doc::Concat (params) << " ) " ;
@@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
684
682
Doc doc;
685
683
doc << " Tensor[(" ;
686
684
std::vector<Doc> shapes;
687
- for (ObjectRef shape : node->shape ) {
688
- shapes.push_back (PrintAttr (shape));
685
+ for (const PrimExpr& prim_expr : node->shape ) {
686
+ // Though not bound within an attribute the attribute visitor will handle the PrimExprs we
687
+ // care about.
688
+ shapes.push_back (PrintAttributeValue (prim_expr));
689
689
}
690
690
doc << Doc::Concat (shapes);
691
691
return doc << " ), " << PrintDType (node->dtype ) << " ]" ;
@@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
766
766
// Overload of Attr printing functions
767
767
// ------------------------------------
768
768
769
- Doc RelayTextPrinter::PrintAttr (const ObjectRef& value, bool meta) {
770
- if (value.defined ()) {
771
- Doc printed_attr;
772
- if (value.as <tvm::tir::AnyNode>()) {
773
- printed_attr << " ?" ;
774
- } else if (auto str_obj = value.as <tvm::StringObj>()) {
775
- printed_attr << Doc::StrLiteral (GetRef<String>(str_obj));
776
- } else if (meta) {
777
- printed_attr = meta_->GetMetaNode (Downcast<ObjectRef>(value));
778
- } else {
779
- printed_attr = VisitAttr (value);
780
- }
781
- return printed_attr;
782
- } else {
783
- return Doc::Text (" None" );
784
- }
785
- }
786
-
787
769
Doc RelayTextPrinter::VisitAttrDefault_ (const Object* op) {
788
- return PrintAttr (GetRef<ObjectRef>(op), /* meta=*/ true );
770
+ // Since we don't have any overload for a specific attribute type we'll need to force
771
+ // the meta[...] representation to avoid infinite regress.
772
+ return PrintAttributeValue (GetRef<ObjectRef>(op), /* force_meta=*/ true );
789
773
}
790
774
791
775
Doc RelayTextPrinter::VisitAttr_ (const ArrayNode* op) {
792
776
Doc doc;
793
777
doc << " [" ;
794
778
std::vector<Doc> arr_vals;
795
- for (auto val : *op) {
796
- arr_vals.push_back (PrintAttr (val));
779
+ for (const auto & val : *op) {
780
+ arr_vals.push_back (PrintAttributeValue (val));
797
781
}
798
782
doc << Doc::Concat (arr_vals);
799
783
doc << " ]" ;
@@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
831
815
doc << key << " =" << *value << " f" ;
832
816
docs->push_back (doc);
833
817
}
818
+
834
819
void Visit (const char * key, int64_t * value) final { PrintKV (key, *value); }
835
820
void Visit (const char * key, uint64_t * value) final { PrintKV (key, *value); }
836
821
void Visit (const char * key, int * value) final { PrintKV (key, *value); }
@@ -844,58 +829,134 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
844
829
LOG (FATAL) << " do not allow NDarray as argument" ;
845
830
}
846
831
void Visit (const char * key, runtime::ObjectRef* obj) final {
847
- PrintKV (key, parent_->PrintAttr (*obj));
832
+ PrintKV (key, parent_->PrintAttributeValue (*obj));
848
833
}
849
834
850
835
private:
851
836
std::vector<Doc>* docs;
852
837
RelayTextPrinter* parent_;
853
838
};
854
839
855
- Doc RelayTextPrinter::PrintAttrs (const Attrs& attrs) {
856
- std::vector<Doc> docs;
857
- AttrPrinter printer (&docs, this );
858
- const_cast <BaseAttrsNode*>(attrs.operator ->())->VisitNonDefaultAttrs (&printer);
859
- Doc doc;
860
- doc << " {" << Doc::Concat (docs) << " }" ;
861
-
862
- return doc;
840
+ void RelayTextPrinter::AppendGenericAttrs (std::vector<Doc>* docs, const Attrs& attrs,
841
+ bool include_type_key) {
842
+ if (!attrs.defined ()) {
843
+ return ;
844
+ }
845
+ AttrPrinter printer (docs, this );
846
+ // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
847
+ // case we are read-only.
848
+ const_cast <BaseAttrsNode*>(attrs.get ())->VisitNonDefaultAttrs (&printer);
849
+ if (include_type_key) {
850
+ std::string s = attrs->GetTypeKey ();
851
+ printer.Visit (" attrs_type_key" , &s);
852
+ }
863
853
}
864
854
865
855
std::vector<Doc> RelayTextPrinter::PrintCallAttrs (const Attrs& attrs, const Expr& op) {
866
856
std::vector<Doc> docs;
867
- if (!attrs.defined ()) return docs;
857
+ if (!attrs.defined ()) {
858
+ return docs;
859
+ }
868
860
const auto * op_node = op.as <OpNode>();
869
861
if (show_meta_data_ && op_node && (attrs->type_index () != op_node->attrs_type_index )) {
870
- // fallback
862
+ // The parser can only understand calls with attributes if they match the operator's
863
+ // declared attribute type. If that's not the case fall back to the meta[...] representation.
864
+ docs.push_back (meta_->GetMetaNode (attrs));
865
+ } else {
866
+ AppendGenericAttrs (&docs, attrs, /* include_type_key=*/ !op_node);
867
+ }
868
+ return docs;
869
+ }
870
+
871
+ std::vector<Doc> RelayTextPrinter::PrintDictAttrs (const DictAttrs& dict_attrs) {
872
+ if (!dict_attrs.defined ()) {
873
+ return {};
874
+ }
875
+ return PrintDictAttrs (dict_attrs->dict );
876
+ }
877
+
878
+ std::vector<Doc> RelayTextPrinter::PrintDictAttrs (const Map<String, ObjectRef>& dict_attrs) {
879
+ std::vector<Doc> docs;
880
+ if (!dict_attrs.defined ()) {
881
+ return docs;
882
+ }
883
+ for (const auto & k : dict_attrs) {
871
884
Doc doc;
872
- doc << meta_-> GetMetaNode (attrs );
885
+ doc << k. first << " = " << PrintAttributeValue (k. second );
873
886
docs.push_back (doc);
874
- return docs;
875
- } else {
876
- // Show attributes in readable form.
877
- AttrPrinter printer (&docs, this );
878
- const_cast <BaseAttrsNode*>(attrs.operator ->())->VisitNonDefaultAttrs (&printer);
879
- if (!op_node) {
880
- // print call attr type key to restore expr for relay parser
881
- std::string s = std::string (attrs->GetTypeKey ());
882
- printer.Visit (" attrs_type_key" , &s);
887
+ }
888
+ return docs;
889
+ }
890
+
891
+ Doc RelayTextPrinter::PrintAttributeValue (const ObjectRef& value, bool force_meta) {
892
+ if (value.defined ()) {
893
+ Doc printed_attr;
894
+ if (value.as <tvm::tir::AnyNode>()) {
895
+ printed_attr << " ?" ;
896
+ } else if (auto str_obj = value.as <tvm::StringObj>()) {
897
+ printed_attr << Doc::StrLiteral (GetRef<String>(str_obj));
898
+ } else if (force_meta) {
899
+ printed_attr = meta_->GetMetaNode (Downcast<ObjectRef>(value));
900
+ } else if (const auto * se_scope_node = value.as <SEScopeNode>()) {
901
+ if (show_meta_data_) {
902
+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(se_scope_node));
903
+ } else {
904
+ // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
905
+ // debugging.
906
+ std::ostringstream os;
907
+ os << GetRef<SEScope>(se_scope_node);
908
+ return Doc::Text (os.str ());
909
+ }
910
+ } else if (const auto * base_attr_node = value.as <BaseAttrsNode>()) {
911
+ if (show_meta_data_) {
912
+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(base_attr_node));
913
+ } else {
914
+ // Special case: The non-meta form for attributes are much easier to work with while
915
+ // debugging.
916
+ printed_attr = PrintAttrsAsAttributeValue (GetRef<Attrs>(base_attr_node));
917
+ }
918
+ } else if (const auto * base_map_node = value.as <MapNode>()) {
919
+ if (show_meta_data_) {
920
+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(base_map_node));
921
+ } else {
922
+ // Special case: Show maps fields as key=value pairs to help debugging.
923
+ printed_attr << PrintMapAsAttributeValue (GetRef<Map<ObjectRef, ObjectRef>>(base_map_node));
924
+ }
925
+ } else if (const auto * global_var_node = value.as <GlobalVarNode>()) {
926
+ if (show_meta_data_) {
927
+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(global_var_node));
928
+ } else {
929
+ printed_attr << " '" << global_var_node->name_hint << " '" ;
930
+ }
931
+ } else {
932
+ printed_attr = VisitAttr (value);
883
933
}
884
- return docs;
934
+ return printed_attr;
935
+ } else {
936
+ return Doc::Text (" None" );
885
937
}
886
938
}
887
939
888
- std::vector< Doc> RelayTextPrinter::PrintFuncAttrs (const Attrs& attrs) {
940
+ Doc RelayTextPrinter::PrintAttrsAsAttributeValue (const Attrs& attrs) {
889
941
std::vector<Doc> docs;
890
- if (!attrs.defined ()) return docs;
891
- const auto * dict_attrs = attrs.as <DictAttrsNode>();
892
- ICHECK (dict_attrs);
893
- for (const auto & k : dict_attrs->dict ) {
942
+ AppendGenericAttrs (&docs, attrs, /* include_type_key=*/ false );
943
+ Doc doc;
944
+ doc << " {" << Doc::Concat (docs) << " }" ;
945
+ return doc;
946
+ }
947
+
948
+ Doc RelayTextPrinter::PrintMapAsAttributeValue (const Map<ObjectRef, ObjectRef>& map) {
949
+ std::vector<Doc> docs;
950
+ for (const auto & k : map) {
894
951
Doc doc;
895
- doc << k.first << " =" << Print (k.second );
952
+ doc << PrintAttributeValue (k.first );
953
+ doc << " =" ;
954
+ doc << PrintAttributeValue (k.second );
896
955
docs.push_back (doc);
897
956
}
898
- return docs;
957
+ Doc doc;
958
+ doc << " {" << Doc::Concat (docs) << " }" ;
959
+ return doc;
899
960
}
900
961
901
962
Doc RelayTextPrinter::PrintSpan (const Span& span) {
0 commit comments