Skip to content

Commit 39f09e8

Browse files
committed
[Contrib] cblas batch_matmul
1 parent 78a0f47 commit 39f09e8

File tree

5 files changed

+349
-93
lines changed

5 files changed

+349
-93
lines changed

cmake/modules/contrib/BLAS.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
2727
if(NOT IS_DIRECTORY ${USE_MKL_PATH})
2828
set(USE_MKL_PATH /opt/intel/mkl)
2929
endif()
30-
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
30+
if(APPLE)
31+
find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
32+
elseif(UNIX)
33+
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
34+
endif()
3135
include_directories(${USE_MKL_PATH}/include)
3236
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
3337
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})

python/tvm/contrib/cblas.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
"""External function interface to BLAS libraries."""
1818
from __future__ import absolute_import as _abs
1919

20-
from .. import api as _api
21-
from .. import intrin as _intrin
20+
from .. import api as _api, intrin as _intrin
2221

23-
def matmul(lhs, rhs, transa=False, transb=False):
22+
23+
def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
2424
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
2525
2626
This function serves as an example on how to call external libraries.
@@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
4444
n = lhs.shape[1] if transa else lhs.shape[0]
4545
m = rhs.shape[0] if transb else rhs.shape[1]
4646
return _api.extern(
47-
(n, m), [lhs, rhs],
47+
(n, m),
48+
[lhs, rhs],
49+
lambda ins, outs: _intrin.call_packed(
50+
"tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
51+
),
52+
name="C",
53+
**kwargs
54+
)
55+
56+
57+
def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
    """Create an extern op that computes batched matrix mult of lhs and rhs with CBLAS.

    This function serves as an example on how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left operand, a 3-D tensor with the batch dimension first.
    rhs : Tensor
        The right operand, a 3-D tensor with the batch dimension first.
    transa : bool
        Whether to transpose each batch of lhs.
    transb : bool
        Whether to transpose each batch of rhs.
    iterative : bool
        If True, dispatch to the per-batch iterative CBLAS kernel
        ("batch_matmul_iterative") instead of the grouped batch kernel.
    **kwargs : dict
        Extra keyword arguments forwarded to `_api.extern`.

    Returns
    -------
    C : Tensor
        The result tensor, of shape (batch, n, m).
    """
    # Output dims: batch from lhs; n/m picked from the axis that survives
    # the (optional) per-batch transpose.
    b = lhs.shape[0]
    n = lhs.shape[2] if transa else lhs.shape[1]
    m = rhs.shape[1] if transb else rhs.shape[2]
    return _api.extern(
        (b, n, m),
        [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cblas.batch_matmul"
            if not iterative
            else "tvm.contrib.cblas.batch_matmul_iterative",
            ins[0],
            ins[1],
            outs[0],
            transa,
            transb,
        ),
        name="C",
        **kwargs
    )

src/contrib/cblas/cblas.cc

Lines changed: 131 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,12 +21,11 @@
2121
* Copyright (c) 2017 by Contributors
2222
* \file Use external cblas library call.
2323
*/
24+
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>

#include <vector>

#include "gemm_common.h"
2828

29-
3029
extern "C" {
3130
#if USE_MKL_BLAS == 1
3231
#include <mkl_cblas.h>
@@ -40,56 +39,148 @@ namespace contrib {
4039

4140
using namespace runtime;
4241

43-
// Translate a boolean "transpose?" flag into the enum value the
// cblas_*gemm family of calls expects.
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
  if (trans) {
    return CblasTrans;
  }
  return CblasNoTrans;
}
4643

4744
struct CblasSgemmOp {
4845
typedef float TDatatype;
49-
void operator()(bool ta, bool tb,
50-
int M, int N, int K,
51-
float alpha, float* A, int lda,
52-
float* B, int ldb,
53-
float beta, float* C, int ldc) {
54-
cblas_sgemm(CblasColMajor,
55-
BooleanToTranspose(ta),
56-
BooleanToTranspose(tb),
57-
M, N, K,
58-
alpha, A, lda,
59-
B, ldb,
60-
beta, C, ldc);
46+
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
47+
int ldb, float beta, float* C, int ldc) {
48+
cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
49+
lda, B, ldb, beta, C, ldc);
6150
}
6251
};
6352

6453
struct CblasDgemmOp {
6554
typedef double TDatatype;
66-
void operator()(bool ta, bool tb,
67-
int M, int N, int K,
68-
double alpha, double* A, int lda,
69-
double* B, int ldb,
70-
double beta, double* C, int ldc) {
71-
cblas_dgemm(CblasColMajor,
72-
BooleanToTranspose(ta),
73-
BooleanToTranspose(tb),
74-
M, N, K,
75-
alpha, A, lda,
76-
B, ldb,
77-
beta, C, ldc);
55+
void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
56+
double* B, int ldb, double beta, double* C, int ldc) {
57+
cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
58+
lda, B, ldb, beta, C, ldc);
7859
}
7960
};
8061

62+
struct CblasSgemmBatchOp {
63+
typedef float TDatatype;
64+
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
65+
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
66+
int c_stride, int ldc) {
67+
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
68+
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
69+
#if USE_MKL_BLAS == 1
70+
std::vector<const float*> A_array(batch_size);
71+
std::vector<const float*> B_array(batch_size);
72+
std::vector<float*> C_array(batch_size);
73+
for (int i = 0; i < batch_size; ++i) {
74+
A_array[i] = A + i * a_stride;
75+
B_array[i] = B + i * b_stride;
76+
C_array[i] = C + i * c_stride;
77+
}
78+
cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
79+
B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
80+
#else
81+
for (int i = 0; i < batch_size; ++i) {
82+
cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
83+
A += a_stride;
84+
B += b_stride;
85+
C += c_stride;
86+
}
87+
#endif
88+
}
89+
};
90+
91+
struct CblasSgemmBatchIterativeOp {
92+
typedef float TDatatype;
93+
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
94+
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
95+
int c_stride, int ldc) {
96+
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
97+
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
98+
for (int i = 0; i < batch_size; ++i) {
99+
cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
100+
A += a_stride;
101+
B += b_stride;
102+
C += c_stride;
103+
}
104+
}
105+
};
106+
107+
struct CblasDgemmBatchOp {
108+
typedef double TDatatype;
109+
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
110+
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
111+
int c_stride, int ldc) {
112+
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
113+
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
114+
#if USE_MKL_BLAS == 1
115+
std::vector<const double*> A_array(batch_size);
116+
std::vector<const double*> B_array(batch_size);
117+
std::vector<double*> C_array(batch_size);
118+
for (int i = 0; i < batch_size; ++i) {
119+
A_array[i] = A + i * a_stride;
120+
B_array[i] = B + i * b_stride;
121+
C_array[i] = C + i * c_stride;
122+
}
123+
cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
124+
B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
125+
#else
126+
for (int i = 0; i < batch_size; ++i) {
127+
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
128+
A += a_stride;
129+
B += b_stride;
130+
C += c_stride;
131+
}
132+
#endif
133+
}
134+
};
135+
136+
struct CblasDgemmBatchIterativeOp {
137+
typedef double TDatatype;
138+
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
139+
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
140+
int c_stride, int ldc) {
141+
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
142+
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
143+
for (int i = 0; i < batch_size; ++i) {
144+
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
145+
A += a_stride;
146+
B += b_stride;
147+
C += c_stride;
148+
}
149+
}
150+
};
81151

82152
// matrix multiplication for row major
83153
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
84-
.set_body([](TVMArgs args, TVMRetValue *ret) {
85-
DLTensor* A = args[0];
86-
CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
87-
TypeMatch(A->dtype, kDLFloat, 64));
154+
.set_body([](TVMArgs args, TVMRetValue* ret) {
155+
DLTensor* A = args[0];
156+
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
157+
158+
if (TypeMatch(A->dtype, kDLFloat, 32))
159+
CallGemm(args, ret, CblasSgemmOp());
160+
else
161+
CallGemm(args, ret, CblasDgemmOp());
162+
});
88163

89-
if (TypeMatch(A->dtype, kDLFloat, 32))
90-
CallGemm(args, ret, CblasSgemmOp());
91-
else
92-
CallGemm(args, ret, CblasDgemmOp());
93-
});
164+
// Batched matmul entry point using the grouped/batch CBLAS wrappers.
// Dispatches on the dtype of the first input (fp32 vs fp64).
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* A = args[0];
  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
  if (TypeMatch(A->dtype, kDLFloat, 32))
    CallBatchGemm(args, ret, CblasSgemmBatchOp());
  else
    CallBatchGemm(args, ret, CblasDgemmBatchOp());
});
174+
175+
// Batched matmul entry point that forces the per-batch iterative wrappers
// (one GEMM call per batch entry). Dispatches on the first input's dtype.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* A = args[0];
  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
  if (TypeMatch(A->dtype, kDLFloat, 32))
    CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
  else
    CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
});
94185
} // namespace contrib
95186
} // namespace tvm

0 commit comments

Comments
 (0)