@@ -56,6 +56,7 @@ class StringImmNode : public PrimExprNode {
56
56
void VisitAttrs (AttrVisitor* v) {
57
57
v->Visit (" dtype" , &dtype);
58
58
v->Visit (" value" , &value);
59
+ v->Visit (" span" , &span);
59
60
}
60
61
61
62
bool SEqualReduce (const StringImmNode* other, SEqualReducer equal) const {
@@ -90,6 +91,7 @@ class CastNode : public PrimExprNode {
90
91
void VisitAttrs (AttrVisitor* v) {
91
92
v->Visit (" dtype" , &dtype);
92
93
v->Visit (" value" , &value);
94
+ v->Visit (" span" , &span);
93
95
}
94
96
95
97
bool SEqualReduce (const CastNode* other, SEqualReducer equal) const {
@@ -131,6 +133,7 @@ class BinaryOpNode : public PrimExprNode {
131
133
v->Visit (" dtype" , &(this ->dtype ));
132
134
v->Visit (" a" , &a);
133
135
v->Visit (" b" , &b);
136
+ v->Visit (" span" , &span);
134
137
}
135
138
136
139
bool SEqualReduce (const T* other, SEqualReducer equal) const {
@@ -312,6 +315,7 @@ class CmpOpNode : public PrimExprNode {
312
315
v->Visit (" dtype" , &(this ->dtype ));
313
316
v->Visit (" a" , &a);
314
317
v->Visit (" b" , &b);
318
+ v->Visit (" span" , &span);
315
319
}
316
320
317
321
bool SEqualReduce (const T* other, SEqualReducer equal) const {
@@ -435,6 +439,7 @@ class AndNode : public PrimExprNode {
435
439
v->Visit (" dtype" , &(this ->dtype ));
436
440
v->Visit (" a" , &a);
437
441
v->Visit (" b" , &b);
442
+ v->Visit (" span" , &span);
438
443
}
439
444
440
445
bool SEqualReduce (const AndNode* other, SEqualReducer equal) const {
@@ -473,6 +478,7 @@ class OrNode : public PrimExprNode {
473
478
v->Visit (" dtype" , &dtype);
474
479
v->Visit (" a" , &a);
475
480
v->Visit (" b" , &b);
481
+ v->Visit (" span" , &span);
476
482
}
477
483
478
484
bool SEqualReduce (const OrNode* other, SEqualReducer equal) const {
@@ -508,6 +514,7 @@ class NotNode : public PrimExprNode {
508
514
void VisitAttrs (AttrVisitor* v) {
509
515
v->Visit (" dtype" , &dtype);
510
516
v->Visit (" a" , &a);
517
+ v->Visit (" span" , &span);
511
518
}
512
519
513
520
bool SEqualReduce (const NotNode* other, SEqualReducer equal) const {
@@ -554,6 +561,7 @@ class SelectNode : public PrimExprNode {
554
561
v->Visit (" condition" , &condition);
555
562
v->Visit (" true_value" , &true_value);
556
563
v->Visit (" false_value" , &false_value);
564
+ v->Visit (" span" , &span);
557
565
}
558
566
559
567
bool SEqualReduce (const SelectNode* other, SEqualReducer equal) const {
@@ -604,6 +612,7 @@ class BufferLoadNode : public PrimExprNode {
604
612
v->Visit (" dtype" , &(this ->dtype ));
605
613
v->Visit (" buffer" , &buffer);
606
614
v->Visit (" indices" , &indices);
615
+ v->Visit (" span" , &span);
607
616
}
608
617
609
618
bool SEqualReduce (const BufferLoadNode* other, SEqualReducer equal) const {
@@ -651,6 +660,7 @@ class ProducerLoadNode : public PrimExprNode {
651
660
v->Visit (" dtype" , &(this ->dtype ));
652
661
v->Visit (" producer" , &producer);
653
662
v->Visit (" indices" , &indices);
663
+ v->Visit (" span" , &span);
654
664
}
655
665
656
666
bool SEqualReduce (const ProducerLoadNode* other, SEqualReducer equal) const {
@@ -708,6 +718,7 @@ class LoadNode : public PrimExprNode {
708
718
v->Visit (" buffer_var" , &buffer_var);
709
719
v->Visit (" index" , &index);
710
720
v->Visit (" predicate" , &predicate);
721
+ v->Visit (" span" , &span);
711
722
}
712
723
713
724
bool SEqualReduce (const LoadNode* other, SEqualReducer equal) const {
@@ -760,6 +771,7 @@ class RampNode : public PrimExprNode {
760
771
v->Visit (" base" , &base);
761
772
v->Visit (" stride" , &stride);
762
773
v->Visit (" lanes" , &lanes);
774
+ v->Visit (" span" , &span);
763
775
}
764
776
765
777
bool SEqualReduce (const RampNode* other, SEqualReducer equal) const {
@@ -800,6 +812,7 @@ class BroadcastNode : public PrimExprNode {
800
812
v->Visit (" dtype" , &dtype);
801
813
v->Visit (" value" , &value);
802
814
v->Visit (" lanes" , &lanes);
815
+ v->Visit (" span" , &span);
803
816
}
804
817
805
818
bool SEqualReduce (const BroadcastNode* other, SEqualReducer equal) const {
@@ -843,6 +856,7 @@ class LetNode : public PrimExprNode {
843
856
v->Visit (" var" , &var);
844
857
v->Visit (" value" , &value);
845
858
v->Visit (" body" , &body);
859
+ v->Visit (" span" , &span);
846
860
}
847
861
848
862
bool SEqualReduce (const LetNode* other, SEqualReducer equal) const {
@@ -890,6 +904,7 @@ class CallNode : public PrimExprNode {
890
904
v->Visit (" dtype" , &dtype);
891
905
v->Visit (" op" , &op);
892
906
v->Visit (" args" , &args);
907
+ v->Visit (" span" , &span);
893
908
}
894
909
895
910
bool SEqualReduce (const CallNode* other, SEqualReducer equal) const {
@@ -931,6 +946,7 @@ class ShuffleNode : public PrimExprNode {
931
946
void VisitAttrs (AttrVisitor* v) {
932
947
v->Visit (" vectors" , &vectors);
933
948
v->Visit (" indices" , &indices);
949
+ v->Visit (" span" , &span);
934
950
}
935
951
936
952
bool SEqualReduce (const ShuffleNode* other, SEqualReducer equal) const {
@@ -993,6 +1009,7 @@ class CommReducerNode : public Object {
993
1009
v->Visit (" rhs" , &rhs);
994
1010
v->Visit (" result" , &result);
995
1011
v->Visit (" identity_element" , &identity_element);
1012
+ v->Visit (" span" , &span);
996
1013
}
997
1014
998
1015
bool SEqualReduce (const CommReducerNode* other, SEqualReducer equal) const {
@@ -1052,6 +1069,7 @@ class ReduceNode : public PrimExprNode {
1052
1069
v->Visit (" axis" , &axis);
1053
1070
v->Visit (" condition" , &condition);
1054
1071
v->Visit (" value_index" , &value_index);
1072
+ v->Visit (" span" , &span);
1055
1073
}
1056
1074
1057
1075
bool SEqualReduce (const ReduceNode* other, SEqualReducer equal) const {
@@ -1091,7 +1109,10 @@ class Reduce : public PrimExpr {
1091
1109
/* ! \brief Any shape. */
1092
1110
class AnyNode : public PrimExprNode {
1093
1111
public:
1094
- void VisitAttrs (AttrVisitor* v) { v->Visit (" dtype" , &dtype); }
1112
+ void VisitAttrs (AttrVisitor* v) {
1113
+ v->Visit (" dtype" , &dtype);
1114
+ v->Visit (" span" , &span);
1115
+ }
1095
1116
1096
1117
bool SEqualReduce (const AnyNode* other, SEqualReducer equal) const {
1097
1118
return equal (dtype, other->dtype );
0 commit comments