This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support for axis parameter in linalg.gemm #10864

Merged · 1 commit · May 29, 2018
7 changes: 7 additions & 0 deletions src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);

// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
template<typename xpu, typename DType>
void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);
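// Illustrative usage sketch for the overload above -- a hedged example, not
// part of this patch; the helper name, shapes, and values are assumptions.
// Rows (m) sit at axis 1, columns (n) at axis 3; axes 0 and 2 are batched.
template<typename xpu>
void example_batch_gemm_axis(Stream<xpu> *s) {
  Tensor<xpu, 4, float> A(Shape4(2, 3, 5, 4));  // (batch, m = 3, axis_batch, k = 4)
  Tensor<xpu, 4, float> B(Shape4(2, 4, 5, 6));  // (batch, k = 4, axis_batch, n = 6)
  Tensor<xpu, 4, float> C(Shape4(2, 3, 5, 6));  // (batch, m = 3, axis_batch, n = 6)
  AllocSpace(&A); AllocSpace(&B); AllocSpace(&C);
  // For every (axis-0, axis-2) pair (i, j): C[i][:, j, :] = A[i][:, j, :] * B[i][:, j, :]
  linalg_batch_gemm(A, B, C, 1.0f, 0.0f, false, false, s);
  FreeSpace(&A); FreeSpace(&B); FreeSpace(&C);
}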


template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
276 changes: 187 additions & 89 deletions src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
<< "Non compatible matrix dimensions between inputs A and B for gemm";
}

template<typename xpu, typename DType>
void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);

#if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)

#define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
} \
}

// Batched gemm where the batch coordinate is given by the second axis.
#define LINALG_CPU_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<cpu> *s) { \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
for (index_t i = 0; i < A.size(1); ++i) { \
cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
(tB ? CblasTrans : CblasNoTrans), \
C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
} \
}

LINALG_CPU_GEMM_AXIS(sgemm, float)
LINALG_CPU_GEMM_AXIS(dgemm, double)
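// Hedged reference sketch (not from this patch): a naive loop that computes
// the same thing as linalg_gemm_axis for tA == tB == false, making the axis-1
// batching and the stride arithmetic above explicit.
template<typename DType>
void naive_gemm_axis(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B,
                     const Tensor<cpu, 3, DType>& C, DType alpha, DType beta) {
  for (index_t b = 0; b < A.size(1); ++b) {        // batch coordinate at axis 1
    for (index_t m = 0; m < C.size(0); ++m) {      // rows at axis 0
      for (index_t n = 0; n < C.size(2); ++n) {    // columns at axis 2
        DType acc = DType(0);
        for (index_t k = 0; k < A.size(2); ++k) {
          acc += A[m][b][k] * B[k][b][n];          // A: (m, b, k), B: (k, b, n)
        }
        C[m][b][n] = alpha * acc + beta * C[m][b][n];
      }
    }
  }
}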

// Version where matrix rows are given by the second axis.
#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
template<> inline \
void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<xpu> *s) { \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
for (index_t i = 0; i < A.size(0); ++i) { \
linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
} \
}

#else

#define LINALG_CPU_GEMM(fname, DType) \
@@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
}

#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
template<> inline \
void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<xpu> *s) { \
LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
}

#endif // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1

LINALG_CPU_GEMM(sgemm, float)
@@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double)
LINALG_XPU_BATCH_GEMM(cpu, float)
LINALG_XPU_BATCH_GEMM(cpu, double)

LINALG_XPU_BATCH_GEMM_AXIS(cpu, float)
LINALG_XPU_BATCH_GEMM_AXIS(cpu, double)

// Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A,
@@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B,
LINALG_GPU_GEMM(Sgemm, float)
LINALG_GPU_GEMM(Dgemm, double)

// Version where matrix rows are given by the first axis.
#define LINALG_GPU_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<gpu> *s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
B.dptr_, B.size(1)*B.stride_, B.stride_, \
A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
}
LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
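// Note (hedged): cuBLAS is column-major while these tensors are row-major, so
// the call above appears to compute C^T = op(B)^T * op(A)^T -- B is passed
// before A, the transpose flags are swapped, and the m/n extents exchanged;
// the column-major C^T that cuBLAS writes is exactly the desired row-major C.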

// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
@@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
#if CUDA_VERSION < 8000
LINALG_XPU_BATCH_GEMM(gpu, float)
LINALG_XPU_BATCH_GEMM(gpu, double)
LINALG_XPU_BATCH_GEMM_AXIS(gpu, float)
LINALG_XPU_BATCH_GEMM_AXIS(gpu, double)
#else
#define LINALG_GPU_BATCH_GEMM(fname, DType) \
template<> inline \
@@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)

// Version where matrix rows are given by the second axis.
#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
template<> inline \
void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \
const Tensor<gpu, 4, DType>& B, \
const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<gpu> *s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
for (index_t i = 0; i < A.size(2); ++i) { \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \
B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \
A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
  } \
}

LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double)
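// Note (hedged): the outer loop runs over the extra batch axis (axis 2), while
// each iteration folds the leading batch axis (axis 0) into a single
// strided-batched call with A.size(0) as the batch count; the
// size(1)*size(2)*stride_ terms are the distance between consecutive
// axis-0 slices of each tensor.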

#endif // CUDA < 8000

#endif // __CUDACC__

/*!
* \brief Performs gemm, setting alpha and beta as appropriate for `req`.
*
* \param A the first operand of the gemm
* \param B the second operand of the gemm
* \param C the data to be assigned
* \param tA whether the `A` operand should be transposed first.
* \param tB whether the `B` operand should be transposed first.
* \param s the stream to perform the operation
* \param req the assignment request
*/
template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
const Tensor<xpu, 2, DType>& C,
bool tA, bool tB, Stream<xpu> *s,
mxnet::OpReqType req) {
using namespace mxnet;
switch (req) {
case kNullOp:
break;
case kWriteTo:
case kWriteInplace:
linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
break;
case kAddTo:
linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
break;
default:
LOG(FATAL) << "not reached";
}
}
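// Illustrative call (not from this patch; names and shapes are assumptions):
// accumulate a product into an existing buffer by passing req = kAddTo.
inline void example_req_gemm() {
  Tensor<cpu, 2, float> A(Shape2(3, 4)), B(Shape2(4, 5)), dC(Shape2(3, 5));
  AllocSpace(&A); AllocSpace(&B); AllocSpace(&dC);
  // kAddTo selects alpha = 1, beta = 1, i.e. dC += A * B.
  linalg_gemm(A, B, dC, false, false, static_cast<Stream<cpu>*>(nullptr), mxnet::kAddTo);
  FreeSpace(&A); FreeSpace(&B); FreeSpace(&dC);
}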

#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)

// A template for a cpu linalg_gemm implementation using mshadow::dot()
#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
template<> inline \
void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
const Tensor<cpu, 2, DType>& B, \
const Tensor<cpu, 2, DType>& C, \
bool tA, bool tB, Stream<cpu> *s, \
mxnet::OpReqType req) { \
using namespace mxnet; \
using mshadow::cpu; \
switch (req) { \
case kNullOp: \
break; \
case kWriteTo: \
case kWriteInplace: \
if (tA) { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
} \
} else { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
} \
} \
break; \
case kAddTo: \
if (tA) { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
} \
} else { \
if (tB) { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
} else { \
const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
} \
} \
break; \
default: \
LOG(FATAL) << "not reached"; \
} \
}

LINALG_CPU_GEMM_NO_CBLAS(float)
LINALG_CPU_GEMM_NO_CBLAS(double)

#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
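// Note (hedged): the no-BLAS fallback covers only this req-based entry point,
// presumably because mshadow::dot() exposes no general alpha/beta; plain
// assignment and += correspond exactly to the (alpha, beta) pairs (1, 0)
// and (1, 1) used by the req dispatch above.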

//////////////////////////////// TRSM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
@@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double)

#endif // __CUDACC__


//////////////////////////////// TRMM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation