@@ -143,10 +143,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
143143 v->Visit (" indptr" , &indptr);
144144 }
145145
146- bool SEqualReduce (const DenseVariableAxisNode* other,
147- SEqualReducer equal) const {
148- return equal (name, other->name ) && equal (length, other->length ) &&
149- equal (indptr, other->indptr );
146+ bool SEqualReduce (const DenseVariableAxisNode* other, SEqualReducer equal) const {
147+ return equal (name, other->name ) && equal (length, other->length ) && equal (indptr, other->indptr );
150148 }
151149
152150 void SHashReduce (SHashReducer hash_reduce) const {
@@ -165,11 +163,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
165163 */
166164class DenseVariableAxis : public DenseAxis {
167165 public:
168- TVM_DLL explicit DenseVariableAxis (String name, PrimExpr length,
169- Buffer indptr);
166+ TVM_DLL explicit DenseVariableAxis (String name, PrimExpr length, Buffer indptr);
170167
171- TVM_DEFINE_OBJECT_REF_METHODS (DenseVariableAxis, DenseAxis,
172- DenseVariableAxisNode);
168+ TVM_DEFINE_OBJECT_REF_METHODS (DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
173169};
174170
175171/* !
@@ -206,8 +202,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
206202 v->Visit (" num_cols" , &num_cols);
207203 }
208204
209- bool SEqualReduce (const SparseFixedAxisNode* other,
210- SEqualReducer equal) const {
205+ bool SEqualReduce (const SparseFixedAxisNode* other, SEqualReducer equal) const {
211206 return equal (name, other->name ) && equal (length, other->length ) &&
212207 equal (indices, other->indices ) && equal (num_cols, other->num_cols );
213208 }
@@ -229,11 +224,9 @@ class SparseFixedAxisNode : public SparseAxisNode {
229224 */
230225class SparseFixedAxis : public SparseAxis {
231226 public:
232- TVM_DLL explicit SparseFixedAxis (String name, PrimExpr length, Buffer indices,
233- PrimExpr num_cols);
227+ TVM_DLL explicit SparseFixedAxis (String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
234228
235- TVM_DEFINE_OBJECT_REF_METHODS (SparseFixedAxis, SparseAxis,
236- SparseFixedAxisNode);
229+ TVM_DEFINE_OBJECT_REF_METHODS (SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
237230};
238231
239232/* !
@@ -251,8 +244,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
251244 v->Visit (" indices" , &indices);
252245 }
253246
254- bool SEqualReduce (const SparseVariableAxisNode* other,
255- SEqualReducer equal) const {
247+ bool SEqualReduce (const SparseVariableAxisNode* other, SEqualReducer equal) const {
256248 return equal (name, other->name ) && equal (length, other->length ) &&
257249 equal (indptr, other->indptr ) && equal (indices, other->indices );
258250 }
@@ -274,11 +266,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
274266 */
275267class SparseVariableAxis : public SparseAxis {
276268 public:
277- TVM_DLL explicit SparseVariableAxis (String name, PrimExpr length,
278- Buffer indptr, Buffer indices);
269+ TVM_DLL explicit SparseVariableAxis (String name, PrimExpr length, Buffer indptr, Buffer indices);
279270
280- TVM_DEFINE_OBJECT_REF_METHODS (SparseVariableAxis, SparseAxis,
281- SparseVariableAxisNode);
271+ TVM_DEFINE_OBJECT_REF_METHODS (SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
282272};
283273
284274/* !
@@ -287,12 +277,9 @@ class SparseVariableAxis : public SparseAxis {
287277class AxisTreeNode : public Object {
288278 public:
289279 // unordered map that stores the parent relationship between axes.
290- std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
291- parent;
280+ std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual> parent;
292281 // unordered map that stores the children relationship between axes.
293- std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
294- ObjectPtrEqual>
295- children;
282+ std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;
296283
297284 void VisitAttrs (AttrVisitor* v) {}
298285
@@ -306,8 +293,7 @@ class AxisTreeNode : public Object {
306293 */
307294class AxisTree : public ObjectRef {
308295 public:
309- TVM_DLL AxisTree (Array<String> axis_names,
310- Array<Optional<String>> axis_parent_names);
296+ TVM_DLL AxisTree (Array<String> axis_names, Array<Optional<String>> axis_parent_names);
311297
312298 TVM_DEFINE_OBJECT_REF_METHODS (AxisTree, ObjectRef, AxisTreeNode);
313299};
@@ -333,8 +319,7 @@ class SparseBufferNode : public Object {
333319 }
334320
335321 bool SEqualReduce (const SparseBufferNode* other, SEqualReducer equal) const {
336- return equal (axes, other->axes ) && equal (data, other->data ) &&
337- equal (name, other->name );
322+ return equal (axes, other->axes ) && equal (data, other->data ) && equal (name, other->name );
338323 }
339324
340325 void SHashReduce (SHashReducer hash_reduce) const {
@@ -386,8 +371,8 @@ class SpIterVarNode : public Object {
386371
387372 bool SEqualReduce (const SpIterVarNode* other, SEqualReducer equal) const {
388373 return equal (var, other->var ) && equal (max_extent, other->max_extent ) &&
389- equal (axis, other->axis ) &&
390- equal (is_reduction, other-> is_reduction ) && equal ( kind, other->kind );
374+ equal (axis, other->axis ) && equal (is_reduction, other-> is_reduction ) &&
375+ equal (kind, other->kind );
391376 }
392377
393378 void SHashReduce (SHashReducer hash_reduce) const {
@@ -406,8 +391,8 @@ class SpIterVarNode : public Object {
406391
407392class SpIterVar : public ObjectRef {
408393 public:
409- TVM_DLL explicit SpIterVar (String name , PrimExpr max_extent, SpIterKind kind,
410- bool is_reduction, Optional<Axis> axis = NullOpt);
394+ TVM_DLL explicit SpIterVar (Var var , PrimExpr max_extent, SpIterKind kind, bool is_reduction ,
395+ Optional<Axis> axis = NullOpt);
411396
412397 /* !
413398 * \return the corresponding var in the IterVar.
0 commit comments