Skip to content

Commit 2069eea

Browse files
committed
Fix AxisTree (#3)
* fix axis tree * upd
1 parent abf6205 commit 2069eea

File tree

6 files changed

+141
-30
lines changed

6 files changed

+141
-30
lines changed

include/tvm/tir/sparse.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,23 @@ class SparseVariableAxis : public SparseAxis {
268268
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
269269
};
270270

271+
271272
/*!
272273
* \brief Axis Dependency Tree.
273274
*/
274275
class AxisTreeNode : public Object {
275276
public:
276-
// parent refers to the parent axis of current axis tree.
277-
Optional<AxisTree> parent;
278-
Axis axis;
279-
Array<AxisTree> children;
277+
// mapping from names to axes.
278+
std::unordered_map<String, Axis> axis_map;
279+
// unordered map that stores the parent relationship between axes.
280+
std::unordered_map<Axis, Axis, ObjectPtrHash, ObjectPtrEqual> parent;
281+
// unordered map that stores the children relationship between axes.
282+
std::unordered_map<Axis, Array<Axis>, ObjectPtrHash, ObjectPtrEqual> children;
283+
// The root axis.
284+
Axis root;
285+
286+
void VisitAttrs(AttrVisitor* v) {}
287+
280288
static constexpr const char* _type_key = "tir.sparse.AxisTree";
281289
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
282290
};
@@ -287,6 +295,7 @@ class AxisTreeNode : public Object {
287295
*/
288296
class AxisTree : public ObjectRef {
289297
public:
298+
TVM_DLL AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names);
290299
TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
291300
};
292301

@@ -296,7 +305,7 @@ class AxisTree : public ObjectRef {
296305
class SparseBufferNode : public Object {
297306
public:
298307
/* Root of Axis Dependency Tree. */
299-
AxisTree root;
308+
AxisTree tree;
300309
/* Axes */
301310
Array<Axis> axes;
302311
/* Number of dimensions */
@@ -305,25 +314,25 @@ class SparseBufferNode : public Object {
305314
Buffer data;
306315

307316
void VisitAttrs(AttrVisitor* v) {
308-
v->Visit("name", &root);
317+
v->Visit("name", &tree);
309318
v->Visit("length", &axes);
310319
v->Visit("indptr", &ndim);
311320
v->Visit("num_cols", &data);
312321
}
313322

314323
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
315-
return equal(root, other->root) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
324+
return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
316325
equal(data, other->data);
317326
}
318327

319328
void SHashReduce(SHashReducer hash_reduce) const {
320-
hash_reduce(root);
329+
hash_reduce(tree);
321330
hash_reduce(axes);
322331
hash_reduce(ndim);
323332
hash_reduce(data);
324333
}
325334

326-
static constexpr const char* _type_key = "tir.sparse.SparseBufferNode";
335+
static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
327336
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
328337
};
329338

@@ -333,7 +342,7 @@ class SparseBufferNode : public Object {
333342
*/
334343
class SparseBuffer : public ObjectRef {
335344
public:
336-
TVM_DLL explicit SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data);
345+
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim, Buffer data);
337346

338347
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
339348
};

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@
5555
from . import transform
5656
from . import analysis
5757
from . import stmt_functor
58+
from . import sparse

python/tvm/tir/_ffi_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919

2020

2121
tvm._ffi._init_api("tir", __name__)
22+
tvm._ffi._init_api("tir.sparse", __name__)

python/tvm/tir/sparse.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
"""SparseTIR axes and SparseBuffer
1818
"""
19-
from typing import List
19+
from typing import List, Dict, Optional
2020
import tvm._ffi
2121
from tvm.ir import PrimExpr
2222
from tvm.runtime import Object, const
@@ -146,9 +146,23 @@ def __init__(self, name, length, indptr, indices):
146146

147147

148148
@tvm._ffi.register_object("tir.sparse.AxisTree")
149-
class AxisTree:
150-
# Todo(@ruihang): to do later
151-
pass
149+
class AxisTree(Object):
150+
"""AxisTree node
151+
152+
Parameters
153+
----------
154+
axis_parent_map: Dict
155+
A dictionary that maps Axis to parent axis name, value is None if there is not parent axis.
156+
"""
157+
158+
axis_parent_map: Dict[Axis, Optional[str]]
159+
160+
def __init__(self, axis_parent_map) -> None:
161+
keys = list(axis_parent_map.keys())
162+
values = list(axis_parent_map.values())
163+
self.__init_handle_by_constructor__(
164+
_ffi_api.AxisTree, keys, values # type:ignore
165+
)
152166

153167

154168
@tvm._ffi.register_object("tir.sparse.SparseBuffer")
@@ -157,8 +171,8 @@ class SparseBuffer:
157171
158172
Parameters
159173
----------
160-
root : AxisTree
161-
The root of the axis dependency tree of the sparse buffer
174+
tree : AxisTree
175+
The axis dependency tree of the sparse buffer
162176
163177
axes : List[Axis]
164178
The axes of the sparse buffer
@@ -170,12 +184,12 @@ class SparseBuffer:
170184
The data of the sparse buffer
171185
"""
172186

173-
root: AxisTree
187+
tree: AxisTree
174188
axes: List[Axis]
175189
ndim: int
176190
data: Buffer
177191

178-
def __init__(self, root, axes, ndim, data):
192+
def __init__(self, tree, axes, ndim, data):
179193
self.__init_handle_by_constructor__(
180194
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
181195
)

src/tir/ir/sparse.cc

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace tir {
3030

3131
namespace sparse {
3232

33+
3334
// DenseFixedAxis
3435
DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
3536
ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
@@ -40,12 +41,14 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
4041

4142
TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);
4243

43-
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
44-
return DenseFixedAxis(name, length);
45-
});
44+
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
45+
.set_body_typed([](String name, PrimExpr length) {
46+
return DenseFixedAxis(name, length);
47+
});
4648

4749
// DenseVariableAxis
48-
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
50+
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length,
51+
Buffer indptr) {
4952
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
5053
node->name = std::move(name);
5154
node->length = std::move(length);
@@ -61,7 +64,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
6164
});
6265

6366
// SparseFixedAxis
64-
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
67+
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
68+
PrimExpr num_cols) {
6569
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
6670
node->name = std::move(name);
6771
node->length = std::move(length);
@@ -73,14 +77,16 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, P
7377
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
7478

7579
TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
76-
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
80+
.set_body_typed([](String name, PrimExpr length, Buffer indices,
81+
PrimExpr num_cols) {
7782
return SparseFixedAxis(name, length, indices, num_cols);
7883
});
7984

8085
// SparseVariableAxis
81-
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
82-
Buffer indices) {
83-
ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
86+
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
87+
Buffer indptr, Buffer indices) {
88+
ObjectPtr<SparseVariableAxisNode> node =
89+
make_object<SparseVariableAxisNode>();
8490
node->name = std::move(name);
8591
node->length = std::move(length);
8692
node->indptr = std::move(indptr);
@@ -91,14 +97,61 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indp
9197
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
9298

9399
TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
94-
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
100+
.set_body_typed([](String name, PrimExpr length, Buffer indptr,
101+
Buffer indices) {
95102
return SparseVariableAxis(name, length, indptr, indices);
96103
});
97104

105+
// AxisTree
106+
AxisTree::AxisTree(Array<Axis> axes,
107+
Array<Optional<String>> axis_parent_names) {
108+
CHECK_EQ(axes.size(), axis_parent_names.size())
109+
<< "ValueError: The axes array should have the same length as axis_parent_names "
110+
"array.";
111+
ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
112+
Axis root = Downcast<Axis>(RootAxis());
113+
for (const Axis& axis : axes) {
114+
// update axis map
115+
String name = axis->name;
116+
CHECK(node->axis_map.find(name) != node->axis_map.end()) << "ValueError: duplicate axis names.";
117+
node->axis_map[name] = axis;
118+
}
119+
for (size_t i = 0; i < axes.size(); i++) {
120+
// update parent map & children map
121+
Axis axis = axes[i];
122+
Optional<String> parent_name = axis_parent_names[i];
123+
if (parent_name.get() != nullptr) {
124+
CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end())
125+
<< "ValueError: Parent axis name doesn't exist.";
126+
}
127+
Axis parent_axis = (parent_name.get() != nullptr)
128+
? node->axis_map[parent_name.value()]
129+
: root;
130+
node->parent[axis] = parent_axis;
131+
if (node->children.find(parent_axis) != node->children.end()) {
132+
node->children[parent_axis].push_back(axis);
133+
} else {
134+
Array<Axis> children;
135+
children.push_back(axis);
136+
node->children[parent_axis] = std::move(children);
137+
}
138+
}
139+
data_ = std::move(node);
140+
}
141+
142+
TVM_REGISTER_NODE_TYPE(AxisTreeNode);
143+
144+
TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
145+
.set_body_typed([](Array<Axis> axes,
146+
Array<Optional<String>> axis_parent_names) {
147+
return AxisTree(axes, axis_parent_names);
148+
});
149+
98150
// SparseBuffer
99-
SparseBuffer::SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
151+
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim,
152+
Buffer data) {
100153
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
101-
node->root = std::move(root);
154+
node->tree = std::move(tree);
102155
node->axes = std::move(axes);
103156
node->ndim = ndim;
104157
node->data = std::move(data);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
import tvm.tir as tir
19+
20+
def test_format_tree_creation():
21+
i = tir.sparse.DenseFixedAxis('i', 128)
22+
j = tir.sparse.DenseFixedAxis('j', 128)
23+
k = tir.sparse.DenseFixedAxis('k', 128)
24+
tree = tir.sparse.AxisTree({
25+
i: None,
26+
j: None,
27+
k: None
28+
})
29+
print(tree)
30+
31+
32+
if __name__ == "__main__":
33+
test_format_tree_creation()

0 commit comments

Comments
 (0)