@@ -30,6 +30,7 @@ namespace tir {
3030
3131namespace sparse {
3232
33+
3334// DenseFixedAxis
3435DenseFixedAxis::DenseFixedAxis (String name, PrimExpr length) {
3536 ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
@@ -40,12 +41,14 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
4041
4142TVM_REGISTER_NODE_TYPE (DenseFixedAxisNode);
4243
43- TVM_REGISTER_GLOBAL (" tir.sparse.DenseFixedAxis" ).set_body_typed([](String name, PrimExpr length) {
44- return DenseFixedAxis (name, length);
45- });
44+ TVM_REGISTER_GLOBAL (" tir.sparse.DenseFixedAxis" )
45+ .set_body_typed([](String name, PrimExpr length) {
46+ return DenseFixedAxis (name, length);
47+ });
4648
4749// DenseVariableAxis
48- DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length, Buffer indptr) {
50+ DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length,
51+ Buffer indptr) {
4952 ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
5053 node->name = std::move (name);
5154 node->length = std::move (length);
@@ -61,7 +64,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
6164 });
6265
6366// SparseFixedAxis
64- SparseFixedAxis::SparseFixedAxis (String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
67+ SparseFixedAxis::SparseFixedAxis (String name, PrimExpr length, Buffer indices,
68+ PrimExpr num_cols) {
6569 ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
6670 node->name = std::move (name);
6771 node->length = std::move (length);
@@ -73,14 +77,16 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, P
7377TVM_REGISTER_NODE_TYPE (SparseFixedAxisNode);
7478
7579TVM_REGISTER_GLOBAL (" tir.sparse.SparseFixedAxis" )
76- .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
80+ .set_body_typed([](String name, PrimExpr length, Buffer indices,
81+ PrimExpr num_cols) {
7782 return SparseFixedAxis (name, length, indices, num_cols);
7883 });
7984
8085// SparseVariableAxis
81- SparseVariableAxis::SparseVariableAxis (String name, PrimExpr length, Buffer indptr,
82- Buffer indices) {
83- ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
86+ SparseVariableAxis::SparseVariableAxis (String name, PrimExpr length,
87+ Buffer indptr, Buffer indices) {
88+ ObjectPtr<SparseVariableAxisNode> node =
89+ make_object<SparseVariableAxisNode>();
8490 node->name = std::move (name);
8591 node->length = std::move (length);
8692 node->indptr = std::move (indptr);
@@ -91,14 +97,61 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indp
9197TVM_REGISTER_NODE_TYPE (SparseVariableAxisNode);
9298
9399TVM_REGISTER_GLOBAL (" tir.sparse.SparseVariableAxis" )
94- .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
100+ .set_body_typed([](String name, PrimExpr length, Buffer indptr,
101+ Buffer indices) {
95102 return SparseVariableAxis (name, length, indptr, indices);
96103 });
97104
105+ // AxisTree
106+ AxisTree::AxisTree (Array<Axis> axes,
107+ Array<Optional<String>> axis_parent_names) {
108+ CHECK_EQ (axes.size (), axis_parent_names.size ())
109+ << " ValueError: The axes array should have the same length as axis_parent_names "
110+ " array." ;
111+ ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
112+ Axis root = Downcast<Axis>(RootAxis ());
113+ for (const Axis& axis : axes) {
114+ // update axis map
115+ String name = axis->name ;
116+ CHECK (node->axis_map .find (name) != node->axis_map .end ()) << " ValueError: duplicate axis names." ;
117+ node->axis_map [name] = axis;
118+ }
119+ for (size_t i = 0 ; i < axes.size (); i++) {
120+ // update parent map & children map
121+ Axis axis = axes[i];
122+ Optional<String> parent_name = axis_parent_names[i];
123+ if (parent_name.get () != nullptr ) {
124+ CHECK (node->axis_map .find (parent_name.value ()) != node->axis_map .end ())
125+ << " ValueError: Parent axis name doesn't exist." ;
126+ }
127+ Axis parent_axis = (parent_name.get () != nullptr )
128+ ? node->axis_map [parent_name.value ()]
129+ : root;
130+ node->parent [axis] = parent_axis;
131+ if (node->children .find (parent_axis) != node->children .end ()) {
132+ node->children [parent_axis].push_back (axis);
133+ } else {
134+ Array<Axis> children;
135+ children.push_back (axis);
136+ node->children [parent_axis] = std::move (children);
137+ }
138+ }
139+ data_ = std::move (node);
140+ }
141+
142+ TVM_REGISTER_NODE_TYPE (AxisTreeNode);
143+
144+ TVM_REGISTER_GLOBAL (" tir.sparse.AxisTree" )
145+ .set_body_typed([](Array<Axis> axes,
146+ Array<Optional<String>> axis_parent_names) {
147+ return AxisTree (axes, axis_parent_names);
148+ });
149+
98150// SparseBuffer
99- SparseBuffer::SparseBuffer (AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
151+ SparseBuffer::SparseBuffer (AxisTree tree, Array<Axis> axes, int ndim,
152+ Buffer data) {
100153 ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
101- node->root = std::move (root );
154+ node->tree = std::move (tree );
102155 node->axes = std::move (axes);
103156 node->ndim = ndim;
104157 node->data = std::move (data);
0 commit comments