Skip to content

Commit b3b2c6f

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Introduce SpIterVar (#6)
* [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr
1 parent 8b64e12 commit b3b2c6f

File tree

3 files changed

+119
-1
lines changed

3 files changed

+119
-1
lines changed

include/tvm/tir/sparse.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,64 @@ class SparseBuffer : public ObjectRef {
355355
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
356356
};
357357

358+
enum class SpIterKind : int {
359+
kDenseFixed = 0,
360+
kDenseVariable = 1,
361+
kSparseFixed = 2,
362+
kSparseVariable = 3
363+
};
364+
365+
/*!
366+
* \brief Iterator variables in SparseTIR
367+
*/
368+
class SpIterVarNode : public Object {
369+
public:
370+
Var var;
371+
PrimExpr max_extent;
372+
SpIterKind kind;
373+
Optional<Axis> axis;
374+
375+
void VisitAttrs(AttrVisitor* v) {
376+
v->Visit("var", &var);
377+
v->Visit("max_extent", &max_extent);
378+
v->Visit("axis", &axis);
379+
v->Visit("kind", &kind);
380+
}
381+
382+
bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
383+
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
384+
equal(axis, other->axis) && equal(kind, other->kind);
385+
}
386+
387+
void SHashReduce(SHashReducer hash_reduce) const {
388+
hash_reduce(var);
389+
hash_reduce(max_extent);
390+
hash_reduce(axis);
391+
hash_reduce(kind);
392+
}
393+
394+
static constexpr const char* _type_key = "tir.sparse.SpIterVar";
395+
static constexpr const bool _type_has_method_sequal_reduce = true;
396+
static constexpr const bool _type_has_method_shash_reduce = true;
397+
TVM_DECLARE_FINAL_OBJECT_INFO(SpIterVarNode, Object);
398+
};
399+
400+
class SpIterVar : public ObjectRef {
401+
public:
402+
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
403+
Optional<Axis> axis = NullOpt);
404+
405+
/*!
406+
* \return the corresponding var in the IterVar.
407+
*/
408+
inline operator PrimExpr() const;
409+
410+
TVM_DEFINE_OBJECT_REF_METHODS(SpIterVar, ObjectRef, SpIterVarNode);
411+
};
412+
413+
// inline implementations
414+
inline SpIterVar::operator PrimExpr() const { return (*this)->var; }
415+
358416
} // namespace tir
359417
} // namespace tvm
360418

python/tvm/tir/sparse.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tvm._ffi
2121
from tvm.ir import PrimExpr
2222
from tvm.runtime import Object, const
23+
from tvm.tir import Var
2324

2425
from . import _ffi_api
2526
from .buffer import Buffer
@@ -166,7 +167,7 @@ def __init__(self, axis_parent_map) -> None:
166167

167168

168169
@tvm._ffi.register_object("tir.sparse.SparseBuffer")
169-
class SparseBuffer:
170+
class SparseBuffer(Object):
170171
"""SparseBuffer node
171172
172173
Parameters
@@ -197,3 +198,39 @@ def __init__(self, tree, axes, data, name, dtype=None):
197198
self.__init_handle_by_constructor__(
198199
_ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore
199200
)
201+
202+
203+
@tvm._ffi.register_object("tir.sparse.SpIterVar")
204+
class SpIterVar(Object):
205+
"""IterVar in SparseTIR
206+
207+
Parameters
208+
----------
209+
var : Var
210+
The var of the SpIterVar
211+
212+
max_extent : PrimExpr
213+
The maximum extent of the SpIterVar
214+
215+
kind : int
216+
The kind of the SpIterVar
217+
218+
axis : Optional[Axis]
219+
The axis over which the SpIterVar iterates. Required to be defined
220+
when `kind` is not `DenseFixed`
221+
"""
222+
var: Var
223+
max_extent: PrimExpr
224+
kind: int
225+
axis: Optional[Axis]
226+
227+
DenseFixed = 0
228+
DenseVariable = 1
229+
SparseFixed = 2
230+
SparseVariable = 3
231+
232+
def __init__(self, var, max_extent, kind, axis=None):
233+
self.__init_handle_by_constructor__(
234+
_ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore
235+
)
236+

src/tir/ir/sparse.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,5 +163,28 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
163163
return SparseBuffer(tree, axes, data, name, dtype);
164164
});
165165

166+
// SpIterVar
167+
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
168+
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
169+
170+
if (kind != SpIterKind::kDenseFixed) {
171+
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
172+
"specify the axis over which the SpIterVar iterates";
173+
}
174+
175+
node->var = Var(std::move(name));
176+
node->max_extent = std::move(max_extent);
177+
node->kind = kind;
178+
node->axis = std::move(axis);
179+
data_ = std::move(node);
180+
}
181+
182+
TVM_REGISTER_NODE_TYPE(SpIterVarNode);
183+
184+
TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
185+
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
186+
return SpIterVar(name, max_extent, kind, axis);
187+
});
188+
166189
} // namespace tir
167190
} // namespace tvm

0 commit comments

Comments
 (0)