Skip to content

Commit a52ab12

Browse files
antinucleondhruvaray
authored andcommitted
[Blocksparse] Pipeline for lowering dense model to sparse-dense (apache#5377)
1 parent f7ca70d commit a52ab12

File tree

14 files changed

+798
-0
lines changed

14 files changed

+798
-0
lines changed

python/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def get_package_data_files():
156156
zip_safe=False,
157157
install_requires=[
158158
'numpy',
159+
'scipy',
159160
'decorator',
160161
'attrs',
161162
'psutil',

python/tvm/relay/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from . import frontend
5454
from . import backend
5555
from . import quantize
56+
from . import data_dep_optimization
5657

5758
# Dialects
5859
from . import qnn

python/tvm/relay/analysis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@
2828

2929
# Feature
3030
from . import feature
31+
from . import sparse_dense

python/tvm/relay/analysis/analysis.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,21 @@ def extract_fused_functions(mod):
333333
for hash_, func in ret_mod.functions.items():
334334
ret[hash_] = func
335335
return ret
336+
337+
338+
def search_fc_transpose(expr):
339+
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
340+
341+
This function is used in the data_dep_optimization.simplify_fc_transpose method
342+
343+
Parameters
344+
----------
345+
expr : tvm.relay.Expr
346+
347+
Returns
348+
-------
349+
ret : Array[String]
350+
Array of weight variable name in pattern y = nn.dense(x, transpose(w, [1, 0]))
351+
"""
352+
ret = _ffi_api.search_fc_transpose(expr)
353+
return ret
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=no-else-return
18+
# pylint: disable=unidiomatic-typecheck
19+
"""
20+
This file contains helper functions for convert dense model
21+
to block sparse model
22+
"""
23+
from collections import namedtuple
24+
import numpy as np
25+
import scipy.sparse as sp
26+
import tvm
27+
from . import _ffi_api
28+
29+
30+
SparseAnalysisResult = namedtuple("SparseAnalysisResult", [
31+
"weight_name",
32+
"weight_shape",
33+
])
34+
35+
def _search_dense_op_weight(expr):
36+
"""Search name of weight in all ```nn.dense``` operator
37+
This is a helpful function to determine which param need
38+
to be converted to sparse
39+
40+
Parameters
41+
----------
42+
expr : relay.Expr
43+
Expr will be searched
44+
45+
Returns
46+
-------
47+
ret : Array[String]
48+
name of weight in all ``nn.dense``` operator
49+
"""
50+
return _ffi_api.search_dense_op_weight(expr)
51+
52+
53+
def process_params(expr, params, block_size, sparsity_threshold):
54+
"""[summary]
55+
56+
Parameters
57+
----------
58+
expr : Relay.Expr
59+
Expr of the network
60+
params : Dict[String, tvm.nd.array]
61+
parameters of the network
62+
block_size : Tuple(int, int)
63+
Blocksize in BSR matrix
64+
sparsity_threshold : float
65+
Minimal sparsity requirement for converting to sparse operation
66+
67+
Returns
68+
-------
69+
ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
70+
return names of qualified dense weight and the shape in BSR format
71+
"""
72+
memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
73+
weight_names = _search_dense_op_weight(expr)
74+
for name in weight_names:
75+
name = str(name)
76+
w_np = params[name].asnumpy()
77+
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
78+
if sparsity >= sparsity_threshold:
79+
sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
80+
# remove dense weight
81+
del params[name]
82+
memo.weight_name.append(name)
83+
memo.weight_shape.append(list(sparse_weight.data.shape) +
84+
list(sparse_weight.indices.shape) +
85+
list(sparse_weight.indptr.shape))
86+
params[name + ".data"] = tvm.nd.array(sparse_weight.data)
87+
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
88+
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
89+
ret = SparseAnalysisResult(
90+
weight_name=tvm.runtime.convert(memo.weight_name),
91+
weight_shape=tvm.runtime.convert(memo.weight_shape)
92+
)
93+
return ret
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#pylint: disable=unused-argument, not-context-manager
18+
"""Optimizations involves changing of paramters"""
19+
20+
from . import bsr_dense
21+
from . import simplify_fc_transpose
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#pylint: disable=unused-argument, not-context-manager
18+
"""Automatic convert model from dense to block sparse"""
19+
20+
from tvm import relay
21+
from tvm.relay.analysis.sparse_dense import process_params
22+
23+
from .utils import _run_opt_pass
24+
25+
def convert(func, params, blocksize, sparsity_threshold):
26+
"""Convert a dense func and according parameters to block sparse
27+
28+
Parameters
29+
----------
30+
func : relay.Expr
31+
Expr will be optimized to sparse operation
32+
params : Dict[Srting, tvm.nd.array]
33+
Parameters of the Expr
34+
blocksize : Tuple(int, int)
35+
Blocksize for BSR matrix
36+
sparsity_threshold : float
37+
Minimal sparsity requirement for converting.
38+
If weight sparsity is lower than this threshold,
39+
the dense operation will be kept.
40+
41+
Returns
42+
-------
43+
new_func: relay.Expr
44+
Mutated Expr with sparse operations
45+
46+
params: Dict[Srting, tvm.nd.array]
47+
New params with BSR matrix for mutated Expr
48+
"""
49+
weight_info = process_params(func, params, blocksize, sparsity_threshold)
50+
new_func = _run_opt_pass(
51+
func,
52+
relay.transform.DenseToSparse(
53+
weight_info.weight_name,
54+
weight_info.weight_shape
55+
)
56+
)
57+
return new_func, params
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#pylint: disable=unused-argument, not-context-manager
18+
"""Automatic optimize fc tranpose"""
19+
import numpy as np
20+
21+
import tvm
22+
from tvm import relay
23+
from tvm.relay.analysis import search_fc_transpose
24+
25+
from .utils import _run_opt_pass
26+
27+
28+
def convert(func, params):
29+
"""convert all ```y = nn.dense(x, transpose(w, [1, 0]))``` to
30+
```y = nn.dense(x, wt)```
31+
32+
Parameters
33+
----------
34+
func : relay.Expr
35+
Expr will be optimized
36+
params : Dict[String, tvm.nd.array]
37+
Parameters of Expr
38+
39+
Returns
40+
-------
41+
new_func : relay.Expr
42+
Mutated Expr from ```y = nn.dense(x, transpose(w, [1, 0]))``` to
43+
```y = nn.dense(x, wt)```
44+
params: Dict[String, tvm.nd.array]
45+
Parameters of mutated Expr, with weights pre-transposed
46+
"""
47+
weight_info = search_fc_transpose(func)
48+
for item in weight_info:
49+
name = str(item)
50+
w_np = params[name].asnumpy()
51+
new_w = np.transpose(w_np, axes=[1, 0])
52+
params[name + ".T"] = tvm.nd.array(new_w)
53+
del params[name]
54+
new_func = _run_opt_pass(
55+
func,
56+
relay.transform.SimplifyFCTranspose(
57+
weight_info,
58+
)
59+
)
60+
return new_func, params
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#pylint: disable=unused-argument, not-context-manager
18+
"""Utils functions for optimizations"""
19+
20+
import tvm
21+
22+
def _run_opt_pass(expr, opt_pass):
23+
"""Helper function to run pass
24+
25+
Parameters
26+
----------
27+
expr : relay.Expr
28+
Expr will be optimized
29+
opt_pass : relay.Pass
30+
Optimization pass
31+
32+
Returns
33+
-------
34+
ret: relay.Expr
35+
Optimized Expr by running opt_pass
36+
"""
37+
assert isinstance(opt_pass, tvm.transform.Pass)
38+
mod = tvm.IRModule.from_expr(expr)
39+
mod = opt_pass(mod)
40+
return mod["main"]

python/tvm/relay/transform/transform.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,43 @@ def visit_var(self, var):
839839
return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
840840
return var
841841
return ChangeBatchMutator().visit(func)
842+
843+
844+
def DenseToSparse(weight_name, weight_shape):
845+
"""
846+
Rewrite qualified ```nn.dense operation``` to ```nn.sparse_dense```
847+
This pass is used in ```data_dep_optimization.bsr_dense```
848+
Parameters of this pass is generated by ```analysis.sparse_dense.process_params```
849+
850+
Parameters
851+
----------
852+
weight_name: Array[String]
853+
Names of weights which qualified sparse contrains
854+
855+
weight_shape: Array[Array[IntImm]]
856+
Weights shape in BSR format.
857+
858+
Returns
859+
-------
860+
ret : tvm.transform.Pass
861+
The registered DenseToSparse pass.
862+
"""
863+
return _ffi_api.DenseToSparse(weight_name, weight_shape)
864+
865+
def SimplifyFCTranspose(target_weight_name):
866+
"""
867+
Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
868+
This pass is used in ```data_dep_optimization.simplify_fc_transpose```
869+
870+
Parameters
871+
----------
872+
weight_name: Array[String]
873+
Names of weights which qualified ```y = nn.dense(x, transpose(w, [1, 0]))```
874+
This parameter is generated by ```analysis.search_fc_transpose``` function
875+
876+
Returns
877+
-------
878+
ret : tvm.transform.Pass
879+
The registered SimplifyFCTranspose pass.
880+
"""
881+
return _ffi_api.SimplifyFCTranspose(target_weight_name)

0 commit comments

Comments
 (0)