Skip to content

Commit d4b4550

Browse files
committed
[SparseTIR] Constructors and Python Interface for Axis and SparseBuffer (#2)
* add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface
1 parent 03f3eb6 commit d4b4550

File tree

3 files changed

+380
-10
lines changed

3 files changed

+380
-10
lines changed

include/tvm/tir/sparse.h

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ class AxisNode : public Object {
4444
/* length of current axis. For sparse axis, length refers to the upperbound of
4545
* the current axis. */
4646
PrimExpr length;
47+
4748
static constexpr const char* _type_key = "tir.sparse.Axis";
49+
static constexpr const bool _type_has_method_sequal_reduce = true;
50+
static constexpr const bool _type_has_method_shash_reduce = true;
4851
TVM_DECLARE_BASE_OBJECT_INFO(AxisNode, Object);
4952
};
5053

@@ -98,6 +101,20 @@ class DenseAxis : public Axis {
98101
*/
99102
class DenseFixedAxisNode : public DenseAxisNode {
100103
public:
104+
void VisitAttrs(AttrVisitor* v) {
105+
v->Visit("name", &name);
106+
v->Visit("length", &length);
107+
}
108+
109+
bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const {
110+
return equal(name, other->name) && equal(length, other->length);
111+
}
112+
113+
void SHashReduce(SHashReducer hash_reduce) const {
114+
hash_reduce(name);
115+
hash_reduce(length);
116+
}
117+
101118
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
102119
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
103120
};
@@ -108,12 +125,31 @@ class DenseFixedAxisNode : public DenseAxisNode {
108125
*/
109126
class DenseFixedAxis : public DenseAxis {
110127
public:
128+
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length);
129+
111130
TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
112131
};
113132

114133
class DenseVariableAxisNode : public DenseAxisNode {
115134
public:
116135
Buffer indptr;
136+
137+
void VisitAttrs(AttrVisitor* v) {
138+
v->Visit("name", &name);
139+
v->Visit("length", &length);
140+
v->Visit("indptr", &indptr);
141+
}
142+
143+
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
144+
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
145+
}
146+
147+
void SHashReduce(SHashReducer hash_reduce) const {
148+
hash_reduce(name);
149+
hash_reduce(length);
150+
hash_reduce(indptr);
151+
}
152+
117153
static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
118154
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
119155
};
@@ -124,8 +160,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
124160
*/
125161
class DenseVariableAxis : public DenseAxis {
126162
public:
127-
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
128-
DenseVariableAxisNode);
163+
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
164+
165+
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
129166
};
130167

131168
/*!
@@ -154,6 +191,26 @@ class SparseFixedAxisNode : public SparseAxisNode {
154191
Buffer indices;
155192
/* fixed number of columns of current sparse axis. */
156193
PrimExpr num_cols;
194+
195+
void VisitAttrs(AttrVisitor* v) {
196+
v->Visit("name", &name);
197+
v->Visit("length", &length);
198+
v->Visit("indptr", &indices);
199+
v->Visit("num_cols", &num_cols);
200+
}
201+
202+
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
203+
return equal(name, other->name) && equal(length, other->length) &&
204+
equal(indices, other->indices) && equal(num_cols, other->num_cols);
205+
}
206+
207+
void SHashReduce(SHashReducer hash_reduce) const {
208+
hash_reduce(name);
209+
hash_reduce(length);
210+
hash_reduce(indices);
211+
hash_reduce(num_cols);
212+
}
213+
157214
static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
158215
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
159216
};
@@ -164,17 +221,39 @@ class SparseFixedAxisNode : public SparseAxisNode {
164221
*/
165222
class SparseFixedAxis : public SparseAxis {
166223
public:
167-
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
168-
SparseFixedAxisNode);
224+
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
225+
226+
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
169227
};
170228

171229
/*!
172230
* \brief Sparse axis with variable number of non-zero columns per row.
173231
*/
174232
class SparseVariableAxisNode : public SparseAxisNode {
175233
public:
176-
Buffer indptr, indices;
177-
static constexpr const char* _type_key = "tir.sparse.SparseVariabledAxis";
234+
Buffer indptr;
235+
Buffer indices;
236+
237+
void VisitAttrs(AttrVisitor* v) {
238+
v->Visit("name", &name);
239+
v->Visit("length", &length);
240+
v->Visit("indptr", &indptr);
241+
v->Visit("indices", &indices);
242+
}
243+
244+
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
245+
return equal(name, other->name) && equal(length, other->length) &&
246+
equal(indptr, other->indptr) && equal(indices, other->indices);
247+
}
248+
249+
void SHashReduce(SHashReducer hash_reduce) const {
250+
hash_reduce(name);
251+
hash_reduce(length);
252+
hash_reduce(indptr);
253+
hash_reduce(indices);
254+
}
255+
256+
static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
178257
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
179258
};
180259

@@ -184,8 +263,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
184263
*/
185264
class SparseVariableAxis : public SparseAxis {
186265
public:
187-
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
188-
SparseVariableAxisNode);
266+
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
267+
268+
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
189269
};
190270

191271
/*!
@@ -223,6 +303,26 @@ class SparseBufferNode : public Object {
223303
int ndim;
224304
/* Buffer corresponding to flattened value */
225305
Buffer data;
306+
307+
void VisitAttrs(AttrVisitor* v) {
308+
v->Visit("name", &root);
309+
v->Visit("length", &axes);
310+
v->Visit("indptr", &ndim);
311+
v->Visit("num_cols", &data);
312+
}
313+
314+
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
315+
return equal(root, other->root) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
316+
equal(data, other->data);
317+
}
318+
319+
void SHashReduce(SHashReducer hash_reduce) const {
320+
hash_reduce(root);
321+
hash_reduce(axes);
322+
hash_reduce(ndim);
323+
hash_reduce(data);
324+
}
325+
226326
static constexpr const char* _type_key = "tir.sparse.SparseBufferNode";
227327
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
228328
};
@@ -233,11 +333,13 @@ class SparseBufferNode : public Object {
233333
*/
234334
class SparseBuffer : public ObjectRef {
235335
public:
336+
TVM_DLL explicit SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data);
337+
236338
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
237339
};
238340

239341
} // namespace sparse
240342
} // namespace tir
241343
} // namespace tvm
242344

243-
#endif // TVM_TIR_BUFFER_H_
345+
#endif // TVM_TIR_SPARSE_H_

python/tvm/tir/sparse.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""SparseTIR axes and SparseBuffer
18+
"""
19+
from typing import List
20+
import tvm._ffi
21+
from tvm.ir import PrimExpr
22+
from tvm.runtime import Object, const
23+
24+
from . import _ffi_api
25+
from .buffer import Buffer
26+
27+
28+
class Axis(Object):
29+
"""Base class of all the sparse axes."""
30+
31+
32+
class DenseAxis(Axis):
33+
pass
34+
35+
36+
class SparseAxis(Axis):
37+
pass
38+
39+
40+
@tvm._ffi.register_object("tir.sparse.DenseFixedAxis")
41+
class DenseFixedAxis(DenseAxis):
42+
"""DenseFixedAxis node
43+
44+
Parameters
45+
----------
46+
name : str
47+
The name of the axis
48+
49+
length : PrimExpr
50+
The length of the axis
51+
"""
52+
53+
name: str
54+
length: PrimExpr
55+
56+
def __init__(self, name, length):
57+
self.__init_handle_by_constructor__(
58+
_ffi_api.DenseFixedAxis, name, length # type: ignore
59+
)
60+
61+
62+
@tvm._ffi.register_object("tir.sparse.DenseVariableAxis")
63+
class DenseVariableAxis(DenseAxis):
64+
"""DenseVariableAxis node
65+
66+
Parameters
67+
----------
68+
name : str
69+
The name of the axis
70+
71+
length : PrimExpr
72+
The length of the axis
73+
74+
indptr : Buffer
75+
The indptr buffer of the axis
76+
"""
77+
78+
name: str
79+
length: PrimExpr
80+
indptr: Buffer
81+
82+
def __init__(self, name, length, indptr):
83+
self.__init_handle_by_constructor__(
84+
_ffi_api.DenseVariableAxis, name, length, indptr # type: ignore
85+
)
86+
87+
88+
@tvm._ffi.register_object("tir.sparse.SparseFixedAxis")
89+
class SparseFixedAxis(DenseAxis):
90+
"""SparseFixedAxis node
91+
92+
Parameters
93+
----------
94+
name : str
95+
The name of the axis
96+
97+
length : PrimExpr
98+
The length of the axis
99+
100+
indices : Buffer
101+
The indices buffer of the axis
102+
103+
num_cols : PrimExpr
104+
The number of non-zero elements along the axis
105+
"""
106+
107+
name: str
108+
length: PrimExpr
109+
indices: Buffer
110+
num_cols: PrimExpr
111+
112+
def __init__(self, name, length, indices, num_cols):
113+
self.__init_handle_by_constructor__(
114+
_ffi_api.SparseFixedAxis, name, length, indices, num_cols # type: ignore
115+
)
116+
117+
118+
@tvm._ffi.register_object("tir.sparse.SparseVariableAxis")
119+
class SparseVariableAxis(DenseAxis):
120+
"""SparseVariableAxis node
121+
122+
Parameters
123+
----------
124+
name : str
125+
The name of the axis
126+
127+
length : PrimExpr
128+
The length of the axis
129+
130+
indptr : Buffer
131+
The indptr buffer of the axis
132+
133+
indices : Buffer
134+
The indices buffer of the axis
135+
"""
136+
137+
name: str
138+
length: PrimExpr
139+
indptr: Buffer
140+
indices: Buffer
141+
142+
def __init__(self, name, length, indptr, indices):
143+
self.__init_handle_by_constructor__(
144+
_ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore
145+
)
146+
147+
148+
@tvm._ffi.register_object("tir.sparse.AxisTree")
149+
class AxisTree:
150+
# Todo(@ruihang): to do later
151+
pass
152+
153+
154+
@tvm._ffi.register_object("tir.sparse.SparseBuffer")
155+
class SparseBuffer:
156+
"""SparseBuffer node
157+
158+
Parameters
159+
----------
160+
root : AxisTree
161+
The root of the axis dependency tree of the sparse buffer
162+
163+
axes : List[Axis]
164+
The axes of the sparse buffer
165+
166+
ndim : int
167+
The number of dimensions of the sparse buffer
168+
169+
data : Buffer
170+
The data of the sparse buffer
171+
"""
172+
173+
root: AxisTree
174+
axes: List[Axis]
175+
ndim: int
176+
data: Buffer
177+
178+
def __init__(self, root, axes, ndim, data):
179+
self.__init_handle_by_constructor__(
180+
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
181+
)

0 commit comments

Comments
 (0)