Commit 5e4f27b

tkonolige authored and Tushar Dey committed

Faster sparse_dense on GPUs (apache#6580)
* Faster sparse_dense on GPUs. This new sparse_dense requires a padded matrix, so a new op `sparse_dense_padded` has been added. AlterOpLayout should transform `sparse_dense` to `sparse_dense_padded` when possible on the GPU.
* formatting
* more formatting
* Check that AlterOpLayout is defined before using it
* Check if FTVMAlterOpLayout exists before using it
* formatting
* restore message passing
* Fix sparse_dense and sparse_dense_padded docs
* Fix old sparse_dense; autotvm and sparse_dense don't play well together
* Remove unused imports
* Clarify warp count in cuda_transpose
* Document multidimensional access
* Warn users not to use sparse_dense_padded
* Rename nn.sparse_dense_padded to nn.internal.sparse_dense_padded
1 parent eef3495 commit 5e4f27b
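
The intended flow described above is that models keep calling nn.sparse_dense and the layout pass rewrites it on CUDA. A minimal sketch of that flow, assuming illustrative shapes, block size, and density (none of these values come from this commit's tests), and assuming a CUDA-enabled TVM build:

import scipy.sparse as sp
import tvm
from tvm import relay

# Dense activations (M, K) and a sparse weight (N, K) in BSR form; the 16x1
# block size is an assumed value, chosen only to make the sketch concrete.
x = relay.var("x", shape=(128, 256), dtype="float32")
w_sp = sp.random(512, 256, density=0.1, format="csr", dtype="float32").tobsr(blocksize=(16, 1))

# The weight is passed as constants so the layout pass can inspect (and pad) it.
weight = (relay.const(w_sp.data), relay.const(w_sp.indices), relay.const(w_sp.indptr))
y = relay.nn.sparse_dense(x, weight)
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# On a CUDA target, AlterOpLayout (run at opt_level 3) may rewrite the call to
# nn.internal.sparse_dense_padded; users should not call the padded op directly.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")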

File tree

16 files changed: +537 −60 lines


python/tvm/relay/op/nn/_nn.py

Lines changed: 16 additions & 0 deletions

@@ -75,6 +75,22 @@ def compute_sparse_dense(attrs, inputs, out_type):
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_alter_op_layout("nn.sparse_dense")
+def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
+    """Alternate the layout of sparse_dense"""
+    return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)
+
+
+@reg.register_compute("nn.internal.sparse_dense_padded")
+def compute_sparse_dense_padded(attrs, inputs, out_type):
+    """Compute definition of sparse_dense_padded"""
+    raise NotImplementedError("nn.internal.sparse_dense_padded is only available on cuda")
+
+
+reg.register_strategy("nn.internal.sparse_dense_padded", strategy.sparse_dense_padded_strategy)
+reg.register_pattern("nn.internal.sparse_dense_padded", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # sparse_transpose
 @reg.register_compute("nn.sparse_transpose")
 def compute_sparse_transpose(attrs, inputs, out_type):
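
The alter-layout registration above delegates to a topi hook, topi.nn.sparse_dense_alter_layout, whose GPU override (added elsewhere in this commit) performs the padding rewrite. As a rough sketch of the hook's shape, not the commit's actual topi code, the generic fallback would look something like the following, where returning None means "leave the op unchanged":

import tvm

@tvm.target.generic_func
def sparse_dense_alter_layout(attrs, inputs, tinfos, out_type):
    """Generic fallback: no layout change for nn.sparse_dense.

    Targets (e.g. CUDA) register an override that rewrites the call to
    nn.internal.sparse_dense_padded when the weight can be padded.
    """
    return None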

python/tvm/relay/op/nn/nn.py

Lines changed: 4 additions & 2 deletions

@@ -2016,15 +2016,17 @@ def sparse_dense(data, weight):
     data : tvm.relay.Expr
         The input data for the matrix multiplication
 
-    weight : namedtuple.
+    weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
         The sparse weight matrix for the matrix multiplication.
 
     Returns
     -------
     result: tvm.relay.Expr
         The computed result.
     """
-    return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
+    if hasattr(weight, "indices"):
+        return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
+    return _make.sparse_dense(data, weight[0], weight[1], weight[2])
 
 
 def sparse_transpose(x):
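
With this change, sparse_dense accepts the weight either as an object with data/indices/indptr attributes or as a plain 3-tuple. A small sketch of both call forms (the CSR weight and shapes are made up for illustration):

from collections import namedtuple

import scipy.sparse as sp
from tvm import relay

x = relay.var("x", shape=(8, 64), dtype="float32")
w_sp = sp.random(32, 64, density=0.25, format="csr", dtype="float32")

# Form 1: anything exposing .data/.indices/.indptr, e.g. a namedtuple.
CSRWeight = namedtuple("CSRWeight", ["data", "indices", "indptr"])
w_nt = CSRWeight(relay.const(w_sp.data), relay.const(w_sp.indices), relay.const(w_sp.indptr))
y1 = relay.nn.sparse_dense(x, w_nt)

# Form 2 (new in this commit): a plain (data, indices, indptr) tuple.
w_tuple = (relay.const(w_sp.data), relay.const(w_sp.indices), relay.const(w_sp.indptr))
y2 = relay.nn.sparse_dense(x, w_tuple)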

python/tvm/relay/op/strategy/cuda.py

Lines changed: 13 additions & 0 deletions

@@ -633,6 +633,19 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@sparse_dense_padded_strategy.register(["cuda", "gpu"])
+def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
+    """sparse dense cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded),
+        wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded),
+        name="sparse_dense_padded.cuda",
+        plevel=10,
+    )
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""

python/tvm/relay/op/strategy/generic.py

Lines changed: 6 additions & 0 deletions

@@ -724,6 +724,12 @@ def sparse_dense_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+@override_native_generic_func("sparse_dense_padded_strategy")
+def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
+    """sparse dense padded generic strategy"""
+    raise NotImplementedError("sparse_dense_padded is only implemented for cuda")
+
+
 # sparse_transpose
 @generic_func
 def schedule_sparse_transpose(attrs, outs, target):
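
The generic strategy exists only so that unregistered targets fail loudly; the CUDA registration earlier overrides it. A toy, self-contained sketch of this dispatch pattern in plain Python (deliberately not TVM's actual generic-func machinery):

def generic_func(default):
    """Minimal stand-in for target-generic function dispatch."""
    registry = {}

    def dispatch(target, *args, **kwargs):
        # Fall back to the default (which raises) for unregistered targets.
        return registry.get(target, default)(*args, **kwargs)

    def register(target):
        def _register(fn):
            registry[target] = fn
            return fn
        return _register

    dispatch.register = register
    return dispatch


@generic_func
def sparse_dense_padded_strategy(*args):
    raise NotImplementedError("sparse_dense_padded is only implemented for cuda")


@sparse_dense_padded_strategy.register("cuda")
def _sparse_dense_padded_cuda(*args):
    return "cuda strategy"


print(sparse_dense_padded_strategy("cuda"))  # -> cuda strategy
# sparse_dense_padded_strategy("llvm") would raise NotImplementedError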

python/tvm/tir/ir_builder.py

Lines changed: 34 additions & 5 deletions

@@ -42,6 +42,9 @@ class BufferVar(ObjectGeneric):
 
     Do not create it directly, create use IRBuilder.
 
+    BufferVars support array access either via a linear index, or, if given a
+    shape, via a multidimensional index.
+
     Examples
     --------
     In the follow example, x is BufferVar.
@@ -55,16 +58,23 @@ class BufferVar(ObjectGeneric):
         x = ib.pointer("float32")
         x[0] = x[10] + 1
 
+        y = ib.allocate("float32", (32, 32))
+        # Array access using a linear index
+        y[(2*32) + 31] = 0.
+        # The same array access using a multidimensional index
+        y[2, 31] = 0.
+
     See Also
     --------
     IRBuilder.pointer
     IRBuilder.buffer_ptr
     IRBuilder.allocate
     """
 
-    def __init__(self, builder, buffer_var, content_type):
+    def __init__(self, builder, buffer_var, shape, content_type):
         self._builder = builder
         self._buffer_var = buffer_var
+        self._shape = shape
         self._content_type = content_type
 
     def asobject(self):
@@ -74,8 +84,23 @@ def asobject(self):
     def dtype(self):
         return self._content_type
 
+    def _linear_index(self, index):
+        if not isinstance(index, tuple) or self._shape is None:
+            return index
+        assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
+            len(index),
+            len(self._shape),
+        )
+        dim_size = 1
+        lidx = 0
+        for dim, idx in zip(reversed(self._shape), reversed(index)):
+            lidx += idx * dim_size
+            dim_size *= dim
+        return lidx
+
     def __getitem__(self, index):
         t = DataType(self._content_type)
+        index = self._linear_index(index)
         if t.lanes > 1:
             base = index * t.lanes
             index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
@@ -87,6 +112,7 @@ def __setitem__(self, index, value):
             raise ValueError(
                 "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
             )
+        index = self._linear_index(index)
         t = DataType(self._content_type)
         if t.lanes > 1:
             base = index * t.lanes
@@ -341,7 +367,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
         if scope:
             self.scope_attr(buffer_var, "storage_scope", scope)
         self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
-        return BufferVar(self, buffer_var, dtype)
+        return BufferVar(self, buffer_var, shape, dtype)
 
     def pointer(self, content_type, name="ptr"):
         """Create pointer variable with content type.
@@ -360,22 +386,25 @@ def pointer(self, content_type, name="ptr"):
             The buffer var representing the buffer.
         """
         buffer_var = _expr.Var(name, dtype="handle")
-        return BufferVar(self, buffer_var, content_type)
+        return BufferVar(self, buffer_var, None, content_type)
 
-    def buffer_ptr(self, buf):
+    def buffer_ptr(self, buf, shape=None):
         """Create pointer variable corresponds to buffer ptr.
 
         Parameters
         ----------
         buf : Buffer
            The buffer to be extracted.
 
+        shape : Tuple
+            Optional shape of the buffer. Overrides existing buffer shape.
+
         Returns
         -------
         ptr : BufferVar
            The buffer var representing the buffer.
         """
-        return BufferVar(self, buf.data, buf.dtype)
+        return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype)
 
     def likely(self, expr):
         """Add likely tag for expression.
