@@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr):
54
54
return _ffi_api .search_conv2d_op_weight (expr )
55
55
56
56
57
- def process_params (expr , params , block_size , sparsity_threshold , layout ):
57
+ def process_params (
58
+ expr , params , block_size , sparsity_threshold , layout , kernel_size , reg_task_input = True
59
+ ):
58
60
"""Process parameters of conv2d from dense to sparse.
59
61
60
62
Parameters
@@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
86
88
for name in weight_names :
87
89
name = str (name )
88
90
w_np = params [name ].numpy ()
89
- # currently only support conv2d_1*1
90
- if not (
91
- (w_np .shape [0 ] == 1 and w_np .shape [1 ] == 1 )
92
- or (w_np .shape [2 ] == 1 and w_np .shape [3 ] == 1 )
93
- ):
91
+
92
+ if layout == "NHWC" : # HWIO
93
+ weight_kernel = (w_np .shape [0 ], w_np .shape [1 ])
94
+ elif layout == "NCHW" : # OIHW
95
+ weight_kernel = (w_np .shape [2 ], w_np .shape [3 ])
96
+ if weight_kernel [0 ] != weight_kernel [1 ]:
94
97
continue
95
- sparsity = 1.0 - (np .count_nonzero (w_np ) / w_np .size )
96
- if sparsity >= sparsity_threshold :
98
+
99
+ if weight_kernel [0 ] == kernel_size == 1 :
100
+ sparsity = 1.0 - (np .count_nonzero (w_np ) / w_np .size )
101
+ if sparsity < sparsity_threshold :
102
+ continue
97
103
if layout == "NHWC" :
98
104
w_np = w_np .squeeze ().T
99
105
elif layout == "NCHW" :
@@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
108
114
)
109
115
else :
110
116
sparse_weight_data = sparse_weight .data
117
+ elif weight_kernel [0 ] == kernel_size == 3 :
118
+ if layout == "NHWC" : # HWIO
119
+ w_np = w_np .reshape ((- 1 , w_np .shape [- 1 ])).T
120
+ elif layout == "NCHW" : # OIHW
121
+ w_np = w_np .reshape ((w_np .shape [0 ], - 1 ))
122
+ sparse_weight = sp .bsr_matrix (w_np , blocksize = block_size )
123
+ if 1 - (sparse_weight .nnz / w_np .size ) < sparsity_threshold :
124
+ continue
125
+ sparse_weight_data = sparse_weight .data
126
+ else :
127
+ continue
111
128
112
- # remove dense weight
113
- del params [name ]
114
- memo .weight_name .append (name )
115
- memo .weight_shape .append (
116
- list (sparse_weight_data .shape )
117
- + list (sparse_weight .indices .shape )
118
- + list (sparse_weight .indptr .shape )
119
- )
120
- params [name + ".data" ] = tvm .nd .array (sparse_weight_data )
121
- params [name + ".indices" ] = tvm .nd .array (sparse_weight .indices )
122
- params [name + ".indptr" ] = tvm .nd .array (sparse_weight .indptr )
123
-
129
+ # remove dense weight
130
+ del params [name ]
131
+ memo .weight_name .append (name )
132
+ memo .weight_shape .append (
133
+ list (sparse_weight_data .shape )
134
+ + list (sparse_weight .indices .shape )
135
+ + list (sparse_weight .indptr .shape )
136
+ )
137
+ params [name + ".data" ] = tvm .nd .array (sparse_weight_data )
138
+ params [name + ".indices" ] = tvm .nd .array (sparse_weight .indices )
139
+ params [name + ".indptr" ] = tvm .nd .array (sparse_weight .indptr )
140
+
141
+ if reg_task_input :
124
142
prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
125
143
w_np .shape [0 ],
126
144
w_np .shape [1 ],
0 commit comments