 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
- *
+ *
 *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * Copyright (c) 2017 by Contributors
 * \file Use external cblas library call.
 */
+ #include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
- #include <dmlc/logging.h>
#include "gemm_common.h"

-
extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
@@ -40,56 +39,148 @@ namespace contrib {

using namespace runtime;

- inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
-   return trans ? CblasTrans : CblasNoTrans;
- }
+ inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

struct CblasSgemmOp {
  typedef float TDatatype;
- void operator()(bool ta, bool tb,
-                 int M, int N, int K,
-                 float alpha, float* A, int lda,
-                 float* B, int ldb,
-                 float beta, float* C, int ldc) {
-   cblas_sgemm(CblasColMajor,
-               BooleanToTranspose(ta),
-               BooleanToTranspose(tb),
-               M, N, K,
-               alpha, A, lda,
-               B, ldb,
-               beta, C, ldc);
+ void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
+                 int ldb, float beta, float* C, int ldc) {
+   cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+               lda, B, ldb, beta, C, ldc);
  }
};

struct CblasDgemmOp {
  typedef double TDatatype;
- void operator()(bool ta, bool tb,
-                 int M, int N, int K,
-                 double alpha, double* A, int lda,
-                 double* B, int ldb,
-                 double beta, double* C, int ldc) {
-   cblas_dgemm(CblasColMajor,
-               BooleanToTranspose(ta),
-               BooleanToTranspose(tb),
-               M, N, K,
-               alpha, A, lda,
-               B, ldb,
-               beta, C, ldc);
+ void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
+                 double* B, int ldb, double beta, double* C, int ldc) {
+   cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+               lda, B, ldb, beta, C, ldc);
  }
};

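+ // Batched GEMM in single precision. When built against MKL the whole batch is
+ // dispatched in one cblas_sgemm_batch call; otherwise it falls back to one
+ // cblas_sgemm call per batch element, advancing A/B/C by their batch strides.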
+ struct CblasSgemmBatchOp {
+   typedef float TDatatype;
+   void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                   int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                   int c_stride, int ldc) {
+     CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+     CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+ #if USE_MKL_BLAS == 1
+     std::vector<const float*> A_array(batch_size);
+     std::vector<const float*> B_array(batch_size);
+     std::vector<float*> C_array(batch_size);
+     for (int i = 0; i < batch_size; ++i) {
+       A_array[i] = A + i * a_stride;
+       B_array[i] = B + i * b_stride;
+       C_array[i] = C + i * c_stride;
+     }
+     cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+                       B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+ #else
+     for (int i = 0; i < batch_size; ++i) {
+       cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+       A += a_stride;
+       B += b_stride;
+       C += c_stride;
+     }
+ #endif
+   }
+ };
+
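+ // Single-precision batched GEMM that always iterates over plain cblas_sgemm
+ // calls, regardless of whether MKL is available.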
+ struct CblasSgemmBatchIterativeOp {
+   typedef float TDatatype;
+   void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                   int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                   int c_stride, int ldc) {
+     CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+     CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+     for (int i = 0; i < batch_size; ++i) {
+       cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+       A += a_stride;
+       B += b_stride;
+       C += c_stride;
+     }
+   }
+ };
+
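+ // Double-precision counterparts of the two batched ops above.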
+ struct CblasDgemmBatchOp {
+   typedef double TDatatype;
+   void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                   int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                   int c_stride, int ldc) {
+     CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+     CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+ #if USE_MKL_BLAS == 1
+     std::vector<const double*> A_array(batch_size);
+     std::vector<const double*> B_array(batch_size);
+     std::vector<double*> C_array(batch_size);
+     for (int i = 0; i < batch_size; ++i) {
+       A_array[i] = A + i * a_stride;
+       B_array[i] = B + i * b_stride;
+       C_array[i] = C + i * c_stride;
+     }
+     cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+                       B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+ #else
+     for (int i = 0; i < batch_size; ++i) {
+       cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+       A += a_stride;
+       B += b_stride;
+       C += c_stride;
+     }
+ #endif
+   }
+ };
+
+ struct CblasDgemmBatchIterativeOp {
+   typedef double TDatatype;
+   void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                   int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                   int c_stride, int ldc) {
+     CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+     CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+     for (int i = 0; i < batch_size; ++i) {
+       cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+       A += a_stride;
+       B += b_stride;
+       C += c_stride;
+     }
+   }
+ };

// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
-     DLTensor* A = args[0];
-     CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
-           TypeMatch(A->dtype, kDLFloat, 64));
+     .set_body([](TVMArgs args, TVMRetValue* ret) {
+       DLTensor* A = args[0];
+       CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+
+       if (TypeMatch(A->dtype, kDLFloat, 32))
+         CallGemm(args, ret, CblasSgemmOp());
+       else
+         CallGemm(args, ret, CblasDgemmOp());
+     });

-     if (TypeMatch(A->dtype, kDLFloat, 32))
-       CallGemm(args, ret, CblasSgemmOp());
-     else
-       CallGemm(args, ret, CblasDgemmOp());
-   });
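+ // batched matrix multiplication for row major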
+ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
+     .set_body([](TVMArgs args, TVMRetValue* ret) {
+       DLTensor* A = args[0];
+       CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+       if (TypeMatch(A->dtype, kDLFloat, 32)) {
+         CallBatchGemm(args, ret, CblasSgemmBatchOp());
+       } else {
+         CallBatchGemm(args, ret, CblasDgemmBatchOp());
+       }
+     });
+
+ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
+     .set_body([](TVMArgs args, TVMRetValue* ret) {
+       DLTensor* A = args[0];
+       CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+       if (TypeMatch(A->dtype, kDLFloat, 32)) {
+         CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
+       } else {
+         CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
+       }
+     });
} // namespace contrib
} // namespace tvm
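For reference, a minimal sketch (not part of this change) of how the globals registered above can be invoked from C++ through the TVM runtime registry. The helper name RunCblasMatmul is illustrative only; the argument order (A, B, C, transa, transb) is the one CallGemm in gemm_common.h reads out of TVMArgs.

    #include <dmlc/logging.h>
    #include <tvm/runtime/ndarray.h>
    #include <tvm/runtime/packed_func.h>
    #include <tvm/runtime/registry.h>

    // Computes C = A * B (no transposes) via the cblas-backed packed function.
    void RunCblasMatmul(tvm::runtime::NDArray A, tvm::runtime::NDArray B,
                        tvm::runtime::NDArray C) {
      const tvm::runtime::PackedFunc* matmul =
          tvm::runtime::Registry::Get("tvm.contrib.cblas.matmul");
      CHECK(matmul != nullptr) << "cblas support was not compiled into this TVM build";
      (*matmul)(A, B, C, false, false);
    }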