Skip to content

Commit

Permalink
Add Matmul op (PaddlePaddle#26411)
Browse files Browse the repository at this point in the history
* add matmul_v2
  • Loading branch information
ForFishes authored Aug 22, 2020
1 parent 65ac1ef commit c609066
Show file tree
Hide file tree
Showing 10 changed files with 1,290 additions and 170 deletions.
158 changes: 82 additions & 76 deletions paddle/fluid/operators/dot_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,86 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);

if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();

if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}

if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
#endif
}

template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -84,83 +164,9 @@ class DotGradKernel : public framework::OpKernel<T> {

if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);

if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();

if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}

if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
#endif
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
}
};

Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/math/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;

template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T** A, const T** B, T beta, T** C,
int batchCount) const;

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,17 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#endif // CUDA_VERSION >= 9010
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
Expand Down
21 changes: 21 additions & 0 deletions paddle/fluid/operators/math/blas_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
Expand Down Expand Up @@ -655,6 +656,26 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}

template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
#ifdef PADDLE_WITH_MKLML
const int lda = std::max((transA == CblasNoTrans) ? K : M, 1);
const int ldb = std::max((transB == CblasNoTrans) ? N : K, 1);
const int ldc = std::max(N, 1);
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A,
&lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */,
&batchCount);
#else
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
#endif
}

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <>
template <typename T>
Expand Down
Loading

0 comments on commit c609066

Please sign in to comment.