diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index a6fbbda19a91..935da5a8f259 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -44,6 +44,10 @@ class AxisNode : public Object { * the current axis. */ PrimExpr length; + String GetName() const { return name; } + PrimExpr GetLength() const { return length; } + DataType GetIndexType() const { return length->dtype; } + static constexpr const char* _type_key = "tir.sparse.Axis"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -139,8 +143,10 @@ class DenseVariableAxisNode : public DenseAxisNode { v->Visit("indptr", &indptr); } - bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); + bool SEqualReduce(const DenseVariableAxisNode* other, + SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && + equal(indptr, other->indptr); } void SHashReduce(SHashReducer hash_reduce) const { @@ -159,9 +165,11 @@ class DenseVariableAxisNode : public DenseAxisNode { */ class DenseVariableAxis : public DenseAxis { public: - TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr); + TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, + Buffer indptr); - TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, + DenseVariableAxisNode); }; /*! @@ -198,7 +206,8 @@ class SparseFixedAxisNode : public SparseAxisNode { v->Visit("num_cols", &num_cols); } - bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { + bool SEqualReduce(const SparseFixedAxisNode* other, + SEqualReducer equal) const { return equal(name, other->name) && equal(length, other->length) && equal(indices, other->indices) && equal(num_cols, other->num_cols); } @@ -220,9 +229,11 @@ class SparseFixedAxisNode : public SparseAxisNode { */ class SparseFixedAxis : public SparseAxis { public: - TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols); + TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, + PrimExpr num_cols); - TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, + SparseFixedAxisNode); }; /*! @@ -240,7 +251,8 @@ class SparseVariableAxisNode : public SparseAxisNode { v->Visit("indices", &indices); } - bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { + bool SEqualReduce(const SparseVariableAxisNode* other, + SEqualReducer equal) const { return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr) && equal(indices, other->indices); } @@ -262,9 +274,11 @@ class SparseVariableAxisNode : public SparseAxisNode { */ class SparseVariableAxis : public SparseAxis { public: - TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices); + TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, + Buffer indptr, Buffer indices); - TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, + SparseVariableAxisNode); }; /*! @@ -272,14 +286,13 @@ class SparseVariableAxis : public SparseAxis { */ class AxisTreeNode : public Object { public: - // mapping from names to axes. - std::unordered_map axis_map; // unordered map that stores the parent relationship between axes. - std::unordered_map parent; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + parent; // unordered map that stores the children relationship between axes. - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> children; - // The root axis. - Axis root; + std::unordered_map, Array, ObjectPtrHash, + ObjectPtrEqual> + children; void VisitAttrs(AttrVisitor* v) {} @@ -293,7 +306,9 @@ class AxisTreeNode : public Object { */ class AxisTree : public ObjectRef { public: - TVM_DLL AxisTree(Array axes, Array> axis_parent_names); + TVM_DLL AxisTree(Array axis_names, + Array> axis_parent_names); + TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode); }; @@ -302,38 +317,30 @@ class AxisTree : public ObjectRef { */ class SparseBufferNode : public Object { public: - /* Root of Axis Dependency Tree. */ - AxisTree tree; /* Axes */ Array axes; /* Buffer corresponding to flattened value */ Buffer data; /* Buffer Name */ String name; - /* Data type */ - runtime::DataType dtype; inline int ndim() const { return static_cast(axes.size()); } void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &tree); v->Visit("length", &axes); v->Visit("num_cols", &data); v->Visit("name", &name); - v->Visit("dtype", &dtype); } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { - return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) && - equal(name, other->name) && equal(dtype, other->dtype); + return equal(axes, other->axes) && equal(data, other->data) && + equal(name, other->name); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(tree); hash_reduce(axes); hash_reduce(data); hash_reduce(name); - hash_reduce(dtype); } static constexpr const char* _type_key = "tir.sparse.SparseBuffer"; @@ -346,8 +353,7 @@ class SparseBufferNode : public Object { */ class SparseBuffer : public ObjectRef { public: - TVM_DLL explicit SparseBuffer(AxisTree tree, Array axes, Buffer data, String name, - DataType dtype); + TVM_DLL explicit SparseBuffer(Array axes, Buffer data, String name); TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); }; @@ -380,8 +386,8 @@ class SpIterVarNode : public Object { bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const { return equal(var, other->var) && equal(max_extent, other->max_extent) && - equal(axis, other->axis) && equal(is_reduction, other->is_reduction) && - equal(kind, other->kind); + equal(axis, other->axis) && + equal(is_reduction, other->is_reduction) && equal(kind, other->kind); } void SHashReduce(SHashReducer hash_reduce) const { @@ -400,8 +406,8 @@ class SpIterVarNode : public Object { class SpIterVar : public ObjectRef { public: - TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction, - Optional axis = NullOpt); + TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, + bool is_reduction, Optional axis = NullOpt); /*! * \return the corresponding var in the IterVar. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 12dc4bf34dc6..36069e1298ba 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -327,6 +327,28 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; +/*! + * \brief Sparse Block node. + */ +class SparseBlockNode : public StmtNode { + public: + /*! \brief The sparse iteration variables of the block. */ + Array sp_iter_vars; + /*! \brief The sparse buffers defined in the block. */ + Array sp_buffers; + /*! \brief The body of the block */ + Stmt body; + + static constexpr const char* _type_key = "tir.SparseBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode); +}; + +class SparseBlock : public Stmt { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); +}; + + /*! * \brief Store value to the high dimension sparse buffer. * diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 149e17bcc701..5938a1da6285 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -23,9 +23,11 @@ import tvm from tvm.ir import Span from tvm.ir.expr import Range +from tvm.script.tir.sparse import MatchSparseBuffer from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object from tvm.tir.expr import IterVar +from tvm.tir.sparse import Axis, SparseBuffer from .tir.node import BufferSlice @@ -74,6 +76,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" + axes: List[Axis] = [] + """List[Axis]: list of sparse axis created in the block signature.""" + match_sparse_buffers: List[MatchSparseBuffer] + """List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature.""" iter_values: List[PrimExpr] = [] """List[PrimExpr]: list of binding values for iter vars""" iter_vars: List[IterVar] = [] @@ -119,7 +125,7 @@ class ContextMaintainer: """List[BlockInfo]: The block info for the current block scope""" loop_stack: Dict[Var, Range] = {} """Dict[Var, Range]: The dict from loop var to its domain outside the block""" - symbols: List[Dict[str, Union[Var, Buffer]]] = [] + symbols: List[Dict[str, Union[Var, Buffer, SparseBuffer, Axis]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" # function context @@ -127,6 +133,8 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" + func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {} + """Mapping[Var, SparseBuffer]: The function sparse buffer map""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} @@ -151,6 +159,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # function context self.func_params = [] self.func_buffer_map = {} + self.func_sparse_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} # parser and analyzer @@ -208,9 +217,9 @@ def exit_block_scope(self): # Pop block_info self.block_info_stack.pop() - def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node): + def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node): """Append a symbol into current scope""" - if isinstance(symbol, Buffer): + if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)): if name in self.symbols[0]: self.report_error("Duplicate Buffer name: " + symbol.name, node.span) self.symbols[0][name] = symbol @@ -225,7 +234,7 @@ def remove_symbol(self, name: str): return raise RuntimeError("Internal error of tvm script parser: no symbol named " + name) - def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: + def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var, SparseBuffer, Axis]]: """Look up symbol by name""" for symbols in reversed(self.symbols): if name in symbols: diff --git a/python/tvm/script/tir/sparse.py b/python/tvm/script/tir/sparse.py new file mode 100644 index 000000000000..3a565545f575 --- /dev/null +++ b/python/tvm/script/tir/sparse.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script Interface for Sparse TIR""" +import synr +import tvm +from synr import ast +from tvm.ir.base import Span +from tvm.ir.expr import PrimExpr, Range + +from tvm.script.tir.node import BufferSlice +from tvm.script.tir.utils import buffer_slice_to_region +from tvm.tir.expr import PrimExprWithOp +from .scope_handler import ScopeHandler, LoopInfo +from .intrin import Intrin +from ..context_maintainer import BlockInfo, ContextMaintainer +from .special_stmt import SpecialStmt +from tvm.tir.sparse import Axis, AxisTree, DenseFixedAxis, DenseVariableAxis, SpIterVar, SparseFixedAxis, SparseVariableAxis +from typing import List, Mapping, Optional, Tuple, Any +from tvm.runtime.object import Object +from tvm.script.registry import register +from ..utils import ( + tvm_span_from_synr, + call_with_error_reporting, +) + + +@register +class DenseFixed(SpecialStmt): + """Special Stmt for creating dense fixed axis. + """ + + def __init__(self): + def dense_fixed( + name: str, + length: PrimExpr, + idtype: str = 'int32', + span: Optional[Span] = None + ): + var_name = self.node.lhs[0].id.name + axis = DenseFixedAxis(name, length, idtype=idtype) + self.context.update_symbol(var_name, axis, self.node) + super().__init__(dense_fixed, def_symbol=True) + + +@register +class DenseVariable(SpecialStmt): + """Special Stmt for creating dense variable axis. + """ + + def __init__(self): + def dense_variable( + name: str, + shape: Tuple[PrimExpr, PrimExpr], + indptr: tvm.tir.Var, + idtype: str = 'int32', + span: Optional[Span] = None + ): + indptr_len, length = shape + var_name = self.node.lhs[0].id.name + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), + dtype=idtype, + name=name + "_indptr", + span=span + ) + axis = DenseVariableAxis(name, length, indptr_buf, idtype=idtype) + self.context.func_buffer_map[indptr] = indptr_buf + self.context.update_symbol(var_name, axis, self.node) + super().__init__(dense_variable, def_symbol=True) + + +@register +class SparseFixed(SpecialStmt): + """Special Stmt for creating sparse fixed axis. + """ + + def __init__(self): + def sparse_fixed( + name: str, + shape: Tuple[PrimExpr, PrimExpr, PrimExpr], + indices: tvm.tir.Var, + idtype: str = 'int32', + span: Optional[Span] = None + ): + var_name = self.node.lhs[0].id.name + length, nnz, nnz_cols = shape + indices_buf = tvm.tir.decl_buffer( + (nnz,), + dtype=idtype, + name=name+"_indices", + span=span + ) + axis = SparseFixedAxis(name, length, indices_buf, nnz_cols, idtype=idtype) + self.context.func_buffer_map[indices] = indices_buf + self.context.update_symbol(var_name, axis, self.node) + super().__init__(sparse_fixed, def_symbol=True) + + +@register +class SparseVariable(SpecialStmt): + """Special Stmt for creating sparse variable axis: + """ + + def __init__(self): + def sparse_variable( + name: str, + shape: Tuple[PrimExpr, PrimExpr], + data: Tuple[tvm.tir.Var, tvm.tir.Var], + idtype: str = 'int32', + span: Optional[Span] = None + ): + var_name = self.node.lhs[0].id.name + length, indptr_len, nnz = shape + indptr, indices = data + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), + dtype=idtype, + name=name+"_indptr", + span=span + ) + indices_buf = tvm.tir.decl_buffer( + (nnz,), + dtype=idtype, + name=name+"_indices", + span=span + ) + axis = SparseVariableAxis(name, length, indptr_buf, indices_buf, idtype=idtype) + self.context.func_buffer_map[indices] = indices_buf + self.context.func_buffer_map[indptr] = indptr_buf + self.context.update_symbol(var_name, axis, self.node) + super().__init__(sparse_variable, def_symbol=True) + + +@register +class MatchSparseBuffer(SpecialStmt): + """Special Stmt match_sparse_buffer() + """ + + def __init__(self): + def match_sparse_buffer( + param: tvm.tir.Var, + axes: List[Axis], + dtype: str = 'float32', + span: Optional[Span] = None, + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`match_sparse_buffer` must be assigned to a single sparse buffer, " + "e.g. A = match_sparse_buffer(...)" + ) + + buffer_name: str = self.node.lhs[0].id.name + if not isinstance(param, tvm.tir.Var): + self.context.report_error( + "The source of match_sparse_buffer expected Var, but got" + + str(type(param)), + self.node.rhs.params[0].span + ) + + if param in self.context.func_params: + buffer = tvm.tir.sparse.decl_buffer( + axes, + param, + buffer_name, + dtype, + span=span + ) + self.context.func_sparse_buffer_map[param] = buffer + self.context.update_symbol(buffer_name, buffer, self.node) + else: + self.context.report_error( + "Can not bind non-input param to sparse buffer", self.node.rhs.params[0].span + ) + + super().__init__(match_sparse_buffer, def_symbol=True) + + +@register +def to_dense(axis: Axis, span: Optional[Span] = None): + if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): + return DenseFixedAxis(axis.name, axis.length, axis.idtype) + else: + return axis + + +@register +def cord(axis: Axis, span: Optional[Span] = None): + return 'cord', axis + + +@register +def pos(axis: Axis, span: Optional[Span] = None): + return 'pos', axis diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 09cf6a3e9f8d..11302a14b1d8 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -28,6 +28,17 @@ class Axis(Object): """Base class of all the sparse axes.""" + @property + def name(self): + return _ffi_api.GetAxisName(self) + + @property + def length(self): + return _ffi_api.GetAxisLength(self) + + @property + def idtype(self): + return _ffi_api.GetAxisIndexType(self) class DenseAxis(Axis): @@ -153,10 +164,10 @@ class AxisTree(Object): Parameters ---------- axis_parent_map: Dict - A dictionary that maps Axis to parent axis name, value is None if there is not parent axis. + A dictionary that maps axis name to parent axis name, value is None if there is not parent axis. """ - axis_parent_map: Dict[Axis, Optional[str]] + axis_parent_map: Dict[str, Optional[str]] def __init__(self, axis_parent_map) -> None: keys = list(axis_parent_map.keys()) @@ -172,9 +183,6 @@ class SparseBuffer(Object): Parameters ---------- - tree : AxisTree - The axis dependency tree of the sparse buffer - axes : List[Axis] The axes of the sparse buffer @@ -183,20 +191,15 @@ class SparseBuffer(Object): name : str The name of the sparse buffer - - dtype : Optional[str] - The data type of the sparse buffer """ - tree: AxisTree axes: List[Axis] data: Buffer name: str - def __init__(self, tree, axes, data, name, dtype=None): - dtype = "float32" if dtype is None else dtype + def __init__(self, axes, data, name): self.__init_handle_by_constructor__( - _ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore + _ffi_api.SparseBuffer, axes, data, name # type: ignore ) @@ -214,7 +217,7 @@ class SpIterVar(Object): kind : int The kind of the SpIterVar - + is_reduction : bool Whether the SpIterVar is a reduction iterator @@ -237,4 +240,3 @@ def __init__(self, var, max_extent, kind, axis=None): self.__init_handle_by_constructor__( _ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore ) - diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index f9c9203ed369..95dcfa3d7a2a 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -29,6 +29,19 @@ namespace tvm { namespace tir { +// Axis +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisName").set_body_typed([](Axis axis) { + return axis->GetName(); +}); + +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisLength").set_body_typed([](Axis axis) { + return axis->GetLength(); +}); + +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) { + return DLDataType2String(axis->GetIndexType()); +}); + // DenseFixedAxis DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { ObjectPtr node = make_object(); @@ -39,12 +52,14 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); -TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { - return DenseFixedAxis(name, length); -}); +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") + .set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(name, length); + }); // DenseVariableAxis -DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { +DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, + Buffer indptr) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -56,11 +71,13 @@ TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") .set_body_typed([](String name, PrimExpr length, Buffer indptr) { - return DenseVariableAxis(name, length, indptr); + return DenseVariableAxis( + name, length, indptr); }); // SparseFixedAxis -SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { +SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, + PrimExpr num_cols) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -72,14 +89,16 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, P TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { + .set_body_typed([](String name, PrimExpr length, Buffer indices, + PrimExpr num_cols) { return SparseFixedAxis(name, length, indices, num_cols); }); // SparseVariableAxis -SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, - Buffer indices) { - ObjectPtr node = make_object(); +SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, + Buffer indptr, Buffer indices) { + ObjectPtr node = + make_object(); node->name = std::move(name); node->length = std::move(length); node->indptr = std::move(indptr); @@ -90,39 +109,31 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indp TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) { - return SparseVariableAxis(name, length, indptr, indices); + .set_body_typed([](String name, PrimExpr length, Buffer indptr, + Buffer indices) { + return SparseVariableAxis( + name, length, indptr, indices); }); // AxisTree -AxisTree::AxisTree(Array axes, Array> axis_parent_names) { - CHECK_EQ(axes.size(), axis_parent_names.size()) - << "ValueError: The axes array should have the same length as axis_parent_names " +AxisTree::AxisTree(Array axis_names, + Array> axis_parent_names) { + CHECK_EQ(axis_names.size(), axis_parent_names.size()) + << "ValueError: The axis_names array should have the same length as " + "axis_parent_names " "array."; ObjectPtr node = make_object(); - Axis root = Downcast(RootAxis()); - for (const Axis& axis : axes) { - // update axis map - String name = axis->name; - CHECK(node->axis_map.find(name) != node->axis_map.end()) << "ValueError: duplicate axis names."; - node->axis_map[name] = axis; - } - for (size_t i = 0; i < axes.size(); i++) { + for (size_t i = 0; i < axis_names.size(); i++) { // update parent map & children map - Axis axis = axes[i]; + String axis_name = axis_names[i]; Optional parent_name = axis_parent_names[i]; - if (parent_name.get() != nullptr) { - CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end()) - << "ValueError: Parent axis name doesn't exist."; - } - Axis parent_axis = (parent_name.get() != nullptr) ? node->axis_map[parent_name.value()] : root; - node->parent[axis] = parent_axis; - if (node->children.find(parent_axis) != node->children.end()) { - node->children[parent_axis].push_back(axis); + node->parent[axis_name] = parent_name; + if (node->children.find(parent_name) != node->children.end()) { + node->children[parent_name].push_back(axis_name); } else { - Array children; - children.push_back(axis); - node->children[parent_axis] = std::move(children); + Array children; + children.push_back(axis_name); + node->children[parent_name] = std::move(children); } } data_ = std::move(node); @@ -131,32 +142,30 @@ AxisTree::AxisTree(Array axes, Array> axis_parent_names) TVM_REGISTER_NODE_TYPE(AxisTreeNode); TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") - .set_body_typed([](Array axes, Array> axis_parent_names) { - return AxisTree(axes, axis_parent_names); + .set_body_typed([](Array axis_names, + Array> axis_parent_names) { + return AxisTree(axis_names, axis_parent_names); }); // SparseBuffer -SparseBuffer::SparseBuffer(AxisTree tree, Array axes, Buffer data, String name, - DataType dtype) { +SparseBuffer::SparseBuffer(Array axes, Buffer data, String name) { ObjectPtr node = make_object(); - node->tree = std::move(tree); node->axes = std::move(axes); node->data = std::move(data); node->name = std::move(name); - node->dtype = dtype; data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(SparseBufferNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") - .set_body_typed([](AxisTree tree, Array axes, Buffer data, String name, DataType dtype) { - return SparseBuffer(tree, axes, data, name, dtype); + .set_body_typed([](Array axes, Buffer data, String name) { + return SparseBuffer(axes, data, name); }); // SpIterVar -SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction, - Optional axis) { +SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, + bool is_reduction, Optional axis) { ObjectPtr node = make_object(); arith::Analyzer ana; diff --git a/tests/python/unittest/test_tir_sparse_buffer.py b/tests/python/unittest/test_tir_sparse_buffer.py index fae6dde1af7a..4bc7423e31d6 100644 --- a/tests/python/unittest/test_tir_sparse_buffer.py +++ b/tests/python/unittest/test_tir_sparse_buffer.py @@ -22,11 +22,12 @@ def test_format_tree_creation(): j = tir.sparse.DenseFixedAxis('j', 128) k = tir.sparse.DenseFixedAxis('k', 128) tree = tir.sparse.AxisTree({ - i: None, - j: None, - k: None + 'i': None, + 'j': None, + 'k': None }) print(tree) + print(i, j, k) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_sparse_scripts.py b/tests/python/unittest/test_tir_sparse_scripts.py new file mode 100644 index 000000000000..4a80f21164a0 --- /dev/null +++ b/tests/python/unittest/test_tir_sparse_scripts.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir +import tvm.te as te +from tvm.script import tir as T + + +@T.prim_func +def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + k = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed("I", n, "int32") + J = T.sparse_variable("J", (m, nnz), (indptr, indices), "int32") + K = T.dense_fixed("K", k, "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter((T.cord(I), T.cord(J), T.cord(K)), "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0. + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed("I", n, "int32") + J = T.sparse_variable("J", (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter((tir.cord(I), tir.pos(J)), "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0. + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + nb = T.var("int32") + mb = T.var("int32") + nnzb = T.var("int32") + blk = T.var("int32") + feat_size = T.var("int32") + I = T.dense_fixed("I", nb, "int32") + J = T.sparse_variable("J", (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed("BI", blk, "int32") + BJ = T.dense_fixed("BJ", blk, "int32") + F = T.dense_fixed("F", feat_size, "int32") + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter((T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)), "SRSSS", "bsrmm") as [vi, vj, vbi, vbj, vf]: + with T.init(): + C[vi, vbi, vf] = 0. + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +def test_csrmm(): + pass + + +def test_csr_reduce(): + pass + + +def test_bsrmm(): + pass + + +if __name__ == "__main__": + test_csrmm() + test_csr_reduce() + test_bsrmm() + + + \ No newline at end of file