@@ -62,10 +62,6 @@ def batch_matmul(
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    if cfg.is_fallback and not transpose_a and transpose_b:
-        B, N, K = get_const_tuple(tensor_a.shape)
-        _default_batch_matmul_config(cfg, B, N, K)
-
     return nn.batch_matmul(
         tensor_a,
         tensor_b,
@@ -145,20 +141,32 @@ def _default_batch_matmul_config(cfg, M, N, K):
     cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


-def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.
+def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib):
+    """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and
+    `tensor_b` are data in batch, using one of BLAS libraries. Supports broadcasting in batch
+    dimension.

     Parameters
     ----------
     cfg : ConfigSpace
         Autotvm tuning space config file
-    x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
-    y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
-    out_shape : tuple or None
-        Shape of the output
+
+    tensor_a : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    tensor_b : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : Optional[List[int]]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    trans_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    trans_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
     lib : A contrib module which implements batch_matmul function
         cblas and mkl are supported

@@ -167,23 +175,33 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
-    XB, M, XK = get_const_tuple(x.shape)
-    YB, N, YK = get_const_tuple(y.shape)
+    assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+    if trans_a:
+        XB, XK, M = get_const_tuple(tensor_a.shape)
+    else:
+        XB, M, XK = get_const_tuple(tensor_a.shape)
+    if trans_b:
+        YB, N, YK = get_const_tuple(tensor_b.shape)
+    else:
+        YB, YK, N = get_const_tuple(tensor_b.shape)
     assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistent"
     if out_shape is not None:
         assert out_shape[0] in (XB, YB), "got invalid output shape"
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
-    return lib.batch_matmul(x, y, False, True)
+    return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b)


 @autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y, out_shape=None):
+def batch_matmul_cblas(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using cblas"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas
+    )


 @autotvm.register_topi_schedule("batch_matmul_cblas.x86")
@@ -193,9 +211,13 @@ def schedule_batch_matmul_cblas(_, outs):


 @autotvm.register_topi_compute("batch_matmul_mkl.x86")
-def batch_matmul_mkl(cfg, x, y, out_shape=None):
+def batch_matmul_mkl(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using mkl"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl
+    )


 @autotvm.register_topi_schedule("batch_matmul_mkl.x86")
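For reference, the shape convention assumed by the new trans_a/trans_b handling can be checked against a plain NumPy model. The sketch below (a hypothetical batch_matmul_ref helper, not part of this patch) mirrors the logic added above: tensor_a is [batch, M, K], or [batch, K, M] when trans_a is set; tensor_b is [batch, K, N], or [batch, N, K] when trans_b is set; a size-1 batch dimension broadcasts; and the output is always [batch, M, N].

import numpy as np

def batch_matmul_ref(tensor_a, tensor_b, trans_a=False, trans_b=True):
    # Undo the transposed layouts so that a is [XB, M, K] and b is [YB, K, N].
    a = np.swapaxes(tensor_a, 1, 2) if trans_a else tensor_a
    b = np.swapaxes(tensor_b, 1, 2) if trans_b else tensor_b
    assert a.ndim == 3 and b.ndim == 3, "only support 3-dim batch_matmul"
    assert a.shape[0] == b.shape[0] or 1 in (a.shape[0], b.shape[0]), "batch dimension doesn't match"
    assert a.shape[2] == b.shape[1], "reduction dimensions are inconsistent"
    return np.matmul(a, b)  # np.matmul broadcasts the size-1 batch dimension

# trans_a layout [batch, K, M], trans_b layout [batch, N, K]:
a = np.ones((4, 16, 8))   # K=16, M=8
b = np.ones((1, 32, 16))  # N=32, K=16; batch of 1 broadcasts against batch of 4
assert batch_matmul_ref(a, b, trans_a=True, trans_b=True).shape == (4, 8, 32)

Note that the trans_b=True default in the sketch matches the transpose_b=True default of the cblas and mkl entry points above, i.e. TOPI's usual [batch, N, K] layout for the second operand.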