Support for axis parameter in linalg.gemm (apache#10864)
asmushetzel authored and piiswrong committed May 29, 2018
1 parent 9ab0d2b commit 4ac76c8
Showing 7 changed files with 447 additions and 188 deletions.
7 changes: 7 additions & 0 deletions src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);

// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
template<typename xpu, typename DType>
void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);
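
// In other words (illustrative summary; the shape names b, m, x, k, n are
// assumptions, and tA = tB = false): given row-major tensors A of shape
// (b, m, x, k), B of shape (b, k, x, n) and C of shape (b, m, x, n), this
// overload performs one gemm per pair of outer batch index i (axis 0) and
// extra axis index j (axis 2):
//
//   C[i, :, j, :] = alpha * A[i, :, j, :] * B[i, :, j, :] + beta * C[i, :, j, :]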


template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
276 changes: 187 additions & 89 deletions src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DTyp
<< "Non compatible matrix dimensions between inputs A and B for gemm";
}

template<typename xpu, typename DType>
void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);

#if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)

#define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
} \
}

// Batched gemm where the batch coordinate is given by the second axis.
#define LINALG_CPU_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<cpu> *s) { \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
for (index_t i = 0; i < A.size(1); ++i) { \
cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
(tB ? CblasTrans : CblasNoTrans), \
C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
} \
}

LINALG_CPU_GEMM_AXIS(sgemm, float)
LINALG_CPU_GEMM_AXIS(dgemm, double)
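
// The leading dimension A.size(1)*A.stride_ used above may be easier to see
// spelled out. A minimal sketch of the address arithmetic (illustrative only;
// slice_elem is a made-up helper, not introduced by this commit):
template<typename DType>
DType* slice_elem(DType* base, index_t r, index_t i, index_t c,
                  index_t batch, index_t stride) {
  // Element (r, i, c) of a row-major (rows, batch, cols) tensor whose last
  // dimension is padded to `stride` elements. For a fixed batch index i,
  // consecutive rows are batch*stride elements apart -- exactly the leading
  // dimension handed to cblas_*gemm above.
  return base + (r * batch + i) * stride + c;
}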

// Version where matrix rows are given by the second axis.
#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
template<> inline \
void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<xpu> *s) { \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
for (index_t i = 0; i < A.size(0); ++i) { \
linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
} \
}
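
// A naive reference for what the two macros above compute together, for
// tA = tB = false (illustrative sketch only; batch_gemm_axis_ref, the shape
// names b, m, x, k, n and the use of unpadded plain pointers instead of
// Tensor are assumptions made for clarity, not part of this commit):
template<typename DType>
void batch_gemm_axis_ref(const DType* A,  // shape (b, m, x, k), contiguous
                         const DType* B,  // shape (b, k, x, n), contiguous
                         DType* C,        // shape (b, m, x, n), contiguous
                         index_t b, index_t m, index_t x, index_t k, index_t n,
                         DType alpha, DType beta) {
  for (index_t i = 0; i < b; ++i)          // outer batch, axis 0
    for (index_t j = 0; j < x; ++j)        // extra axis, axis 2
      for (index_t r = 0; r < m; ++r)      // matrix rows, axis 1
        for (index_t c = 0; c < n; ++c) {  // matrix columns, axis 3
          DType acc = 0;
          for (index_t p = 0; p < k; ++p)
            acc += A[((i * m + r) * x + j) * k + p] * B[((i * k + p) * x + j) * n + c];
          C[((i * m + r) * x + j) * n + c] =
              alpha * acc + beta * C[((i * m + r) * x + j) * n + c];
        }
}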

#else

#define LINALG_CPU_GEMM(fname, DType) \
@@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
}

#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
template<> inline \
void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<xpu> *s) { \
LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
}

#endif // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1

LINALG_CPU_GEMM(sgemm, float)
@@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double)
LINALG_XPU_BATCH_GEMM(cpu, float)
LINALG_XPU_BATCH_GEMM(cpu, double)

LINALG_XPU_BATCH_GEMM_AXIS(cpu, float)
LINALG_XPU_BATCH_GEMM_AXIS(cpu, double)

// Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A,
@@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
LINALG_GPU_GEMM(Sgemm, float)
LINALG_GPU_GEMM(Dgemm, double)

// Version where matrix rows are given by the first axis.
#define LINALG_GPU_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<gpu> *s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
B.dptr_, B.size(1)*B.stride_, B.stride_, \
A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
}
LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
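
// The operand swap above is the usual row-major/column-major trick: cuBLAS
// expects column-major data, and a row-major matrix reinterpreted as
// column-major is its transpose. Computing
//
//   C^T = op(B)^T * op(A)^T
//
// in column-major therefore yields the row-major C = op(A) * op(B), which is
// why B is passed first, the transpose flags are exchanged, and the output
// dimensions appear as (C.size(2), C.size(0)).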

// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
@@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
#if CUDA_VERSION < 8000
LINALG_XPU_BATCH_GEMM(gpu, float)
LINALG_XPU_BATCH_GEMM(gpu, double)
LINALG_XPU_BATCH_GEMM_AXIS(gpu, float)
LINALG_XPU_BATCH_GEMM_AXIS(gpu, double)
#else
#define LINALG_GPU_BATCH_GEMM(fname, DType) \
template<> inline \
@@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)

// Version where matrix rows are given by the second axis.
#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \
const Tensor<gpu, 4, DType>& B, \
const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<gpu> *s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
for (index_t i = 0; i < A.size(2); ++i) { \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \
B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \
A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
}\
}

LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double)
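
// Launch structure of the macro above (illustrative summary): for every index
// j along axis 2 one strided-batched call multiplies all A.size(0) outer
// batches at once,
//
//   for j in [0, A.size(2)):    // one cuBLAS call per j
//     for i in [0, A.size(0)):  // batched inside that single call
//       C[i, :, j, :] = alpha * op(A)[i, :, j, :] * op(B)[i, :, j, :] + beta * C[i, :, j, :]
//
// so the number of kernel launches grows with the extra axis only, not with
// the outer batch size.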

#endif // CUDA < 8000

#endif // __CUDACC__

/*!
* \brief Performs gemm, setting alpha and beta as appropriate for `req`.
*
* \param A the first operand of the gemm
* \param B the second operand of the gemm
* \param C the data to be assigned
* \param tA whether the `A` operand should be transposed first.
* \param tB whether the `B` operand should be transposed first.
* \param s the stream to perform the operation
* \param req the assignment request
*/
template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
const Tensor<xpu, 2, DType>& C,
bool tA, bool tB, Stream<xpu> *s,
mxnet::OpReqType req) {
using namespace mxnet;
switch (req) {
case kNullOp:
break;
case kWriteTo:
case kWriteInplace:
linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
break;
case kAddTo:
linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
break;
default:
LOG(FATAL) << "not reached";
}
}
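
// A short usage sketch of this request-aware overload (illustrative only;
// linalg_gemm_req_example is a made-up function, not part of this commit, and
// it assumes A, B and C were allocated and filled by the caller):
template<typename DType>
inline void linalg_gemm_req_example(const Tensor<cpu, 2, DType>& A,   // (m, k)
                                    const Tensor<cpu, 2, DType>& B,   // (k, n)
                                    const Tensor<cpu, 2, DType>& C,   // (m, n)
                                    Stream<cpu>* s) {
  // kWriteTo (and kWriteInplace) overwrite C, i.e. alpha = 1, beta = 0 ...
  linalg_gemm(A, B, C, false, false, s, mxnet::kWriteTo);   // C  = A * B
  // ... while kAddTo accumulates into the existing contents (beta = 1).
  linalg_gemm(A, B, C, false, false, s, mxnet::kAddTo);     // C += A * B
  // kNullOp would leave C untouched.
}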

#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)

// A template for a cpu linalg_gemm implementation using mshadow::dot()
#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
template<> inline \
void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
const Tensor<cpu, 2, DType>& B, \
const Tensor<cpu, 2, DType>& C, \
bool tA, bool tB, Stream<cpu> *s, \
mxnet::OpReqType req) { \
using namespace mxnet; \
using mshadow::cpu; \
switch (req) { \
case kNullOp: \
break; \
case kWriteTo: \
case kWriteInplace: \
if (tA) { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
} \
} else { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
} \
} \
break; \
case kAddTo: \
if (tA) { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
} \
} else { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
} \
} \
break; \
default: \
LOG(FATAL) << "not reached"; \
} \
}

LINALG_CPU_GEMM_NO_CBLAS(float)
LINALG_CPU_GEMM_NO_CBLAS(double)

#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)

//////////////////////////////// TRSM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
@@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double)

#endif // __CUDACC__


//////////////////////////////// TRMM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation