Skip to content

Commit fdfc7eb

Browse files
authored
[TVMSCRIPT] Attach span information to tir nodes in tvmscript (#6910)
1 parent 0d46cf7 commit fdfc7eb

40 files changed

+1346
-640
lines changed

include/tvm/ir/expr.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ class IntImmNode : public PrimExprNode {
238238
void VisitAttrs(AttrVisitor* v) {
239239
v->Visit("dtype", &dtype);
240240
v->Visit("value", &value);
241+
v->Visit("span", &span);
241242
}
242243

243244
bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
@@ -283,6 +284,7 @@ class FloatImmNode : public PrimExprNode {
283284
void VisitAttrs(AttrVisitor* v) {
284285
v->Visit("dtype", &dtype);
285286
v->Visit("value", &value);
287+
v->Visit("span", &span);
286288
}
287289

288290
bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
@@ -415,13 +417,17 @@ class RangeNode : public Object {
415417
PrimExpr min;
416418
/*! \brief the extend of range */
417419
PrimExpr extent;
420+
/*! \brief the location of this range in the source */
421+
mutable Span span;
418422
/*! \brief constructor */
419423
RangeNode() {}
420-
RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
424+
RangeNode(PrimExpr min, PrimExpr extent, Span span = Span())
425+
: min(min), extent(extent), span(span) {}
421426

422427
void VisitAttrs(AttrVisitor* v) {
423428
v->Visit("min", &min);
424429
v->Visit("extent", &extent);
430+
v->Visit("span", &span);
425431
}
426432

427433
bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
@@ -446,8 +452,9 @@ class Range : public ObjectRef {
446452
* \brief constructor by begin and end
447453
* \param begin The begin of the range.
448454
* \param end The end of the range.
455+
* \param span The location of the Range in the source.
449456
*/
450-
TVM_DLL Range(PrimExpr begin, PrimExpr end);
457+
TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
451458
/*!
452459
* \brief construct a new range with min and extent
453460
* The corresponding constructor is removed,
@@ -456,8 +463,9 @@ class Range : public ObjectRef {
456463
*
457464
* \param min The minimum range.
458465
* \param extent The extent of the range.
466+
* \param span The location of the Range in the source.
459467
*/
460-
static Range FromMinExtent(PrimExpr min, PrimExpr extent);
468+
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span());
461469
// declare range.
462470
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
463471
};

include/tvm/runtime/packed_func.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,9 @@ class TVMArgValue : public TVMPODValue_ {
491491
} else if (type_code_ == kTVMStr) {
492492
return std::string(value_.v_str);
493493
} else {
494-
ICHECK(IsObjectRef<tvm::runtime::String>());
494+
ICHECK(IsObjectRef<tvm::runtime::String>())
495+
<< "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
496+
<< " to a string.";
495497
return AsObjectRef<tvm::runtime::String>().operator std::string();
496498
}
497499
}

include/tvm/tir/buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class BufferNode : public Object {
9696
v->Visit("data_alignment", &data_alignment);
9797
v->Visit("offset_factor", &offset_factor);
9898
v->Visit("buffer_type", &buffer_type);
99+
v->Visit("span", &span);
99100
}
100101

101102
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {

include/tvm/tir/expr.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class StringImmNode : public PrimExprNode {
5656
void VisitAttrs(AttrVisitor* v) {
5757
v->Visit("dtype", &dtype);
5858
v->Visit("value", &value);
59+
v->Visit("span", &span);
5960
}
6061

6162
bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
@@ -90,6 +91,7 @@ class CastNode : public PrimExprNode {
9091
void VisitAttrs(AttrVisitor* v) {
9192
v->Visit("dtype", &dtype);
9293
v->Visit("value", &value);
94+
v->Visit("span", &span);
9395
}
9496

9597
bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
@@ -131,6 +133,7 @@ class BinaryOpNode : public PrimExprNode {
131133
v->Visit("dtype", &(this->dtype));
132134
v->Visit("a", &a);
133135
v->Visit("b", &b);
136+
v->Visit("span", &span);
134137
}
135138

136139
bool SEqualReduce(const T* other, SEqualReducer equal) const {
@@ -312,6 +315,7 @@ class CmpOpNode : public PrimExprNode {
312315
v->Visit("dtype", &(this->dtype));
313316
v->Visit("a", &a);
314317
v->Visit("b", &b);
318+
v->Visit("span", &span);
315319
}
316320

317321
bool SEqualReduce(const T* other, SEqualReducer equal) const {
@@ -435,6 +439,7 @@ class AndNode : public PrimExprNode {
435439
v->Visit("dtype", &(this->dtype));
436440
v->Visit("a", &a);
437441
v->Visit("b", &b);
442+
v->Visit("span", &span);
438443
}
439444

440445
bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
@@ -473,6 +478,7 @@ class OrNode : public PrimExprNode {
473478
v->Visit("dtype", &dtype);
474479
v->Visit("a", &a);
475480
v->Visit("b", &b);
481+
v->Visit("span", &span);
476482
}
477483

478484
bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
@@ -508,6 +514,7 @@ class NotNode : public PrimExprNode {
508514
void VisitAttrs(AttrVisitor* v) {
509515
v->Visit("dtype", &dtype);
510516
v->Visit("a", &a);
517+
v->Visit("span", &span);
511518
}
512519

513520
bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
@@ -554,6 +561,7 @@ class SelectNode : public PrimExprNode {
554561
v->Visit("condition", &condition);
555562
v->Visit("true_value", &true_value);
556563
v->Visit("false_value", &false_value);
564+
v->Visit("span", &span);
557565
}
558566

559567
bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
@@ -604,6 +612,7 @@ class BufferLoadNode : public PrimExprNode {
604612
v->Visit("dtype", &(this->dtype));
605613
v->Visit("buffer", &buffer);
606614
v->Visit("indices", &indices);
615+
v->Visit("span", &span);
607616
}
608617

609618
bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
@@ -651,6 +660,7 @@ class ProducerLoadNode : public PrimExprNode {
651660
v->Visit("dtype", &(this->dtype));
652661
v->Visit("producer", &producer);
653662
v->Visit("indices", &indices);
663+
v->Visit("span", &span);
654664
}
655665

656666
bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
@@ -708,6 +718,7 @@ class LoadNode : public PrimExprNode {
708718
v->Visit("buffer_var", &buffer_var);
709719
v->Visit("index", &index);
710720
v->Visit("predicate", &predicate);
721+
v->Visit("span", &span);
711722
}
712723

713724
bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
@@ -760,6 +771,7 @@ class RampNode : public PrimExprNode {
760771
v->Visit("base", &base);
761772
v->Visit("stride", &stride);
762773
v->Visit("lanes", &lanes);
774+
v->Visit("span", &span);
763775
}
764776

765777
bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
@@ -800,6 +812,7 @@ class BroadcastNode : public PrimExprNode {
800812
v->Visit("dtype", &dtype);
801813
v->Visit("value", &value);
802814
v->Visit("lanes", &lanes);
815+
v->Visit("span", &span);
803816
}
804817

805818
bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
@@ -843,6 +856,7 @@ class LetNode : public PrimExprNode {
843856
v->Visit("var", &var);
844857
v->Visit("value", &value);
845858
v->Visit("body", &body);
859+
v->Visit("span", &span);
846860
}
847861

848862
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
@@ -890,6 +904,7 @@ class CallNode : public PrimExprNode {
890904
v->Visit("dtype", &dtype);
891905
v->Visit("op", &op);
892906
v->Visit("args", &args);
907+
v->Visit("span", &span);
893908
}
894909

895910
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
@@ -931,6 +946,7 @@ class ShuffleNode : public PrimExprNode {
931946
void VisitAttrs(AttrVisitor* v) {
932947
v->Visit("vectors", &vectors);
933948
v->Visit("indices", &indices);
949+
v->Visit("span", &span);
934950
}
935951

936952
bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
@@ -993,6 +1009,7 @@ class CommReducerNode : public Object {
9931009
v->Visit("rhs", &rhs);
9941010
v->Visit("result", &result);
9951011
v->Visit("identity_element", &identity_element);
1012+
v->Visit("span", &span);
9961013
}
9971014

9981015
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
@@ -1052,6 +1069,7 @@ class ReduceNode : public PrimExprNode {
10521069
v->Visit("axis", &axis);
10531070
v->Visit("condition", &condition);
10541071
v->Visit("value_index", &value_index);
1072+
v->Visit("span", &span);
10551073
}
10561074

10571075
bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
@@ -1091,7 +1109,10 @@ class Reduce : public PrimExpr {
10911109
/*! \brief Any shape. */
10921110
class AnyNode : public PrimExprNode {
10931111
public:
1094-
void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
1112+
void VisitAttrs(AttrVisitor* v) {
1113+
v->Visit("dtype", &dtype);
1114+
v->Visit("span", &span);
1115+
}
10951116

10961117
bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
10971118
return equal(dtype, other->dtype);

0 commit comments

Comments
 (0)