
Commit b434dde

Move bmm OP from fluid to phi
1 parent 98e9685 commit b434dde

File tree

16 files changed: +411 additions, -209 deletions

paddle/fluid/operators/bmm_op.cc

Lines changed: 14 additions & 91 deletions

@@ -16,6 +16,11 @@
 
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
+
 namespace paddle {
 namespace operators {
 
@@ -24,62 +29,6 @@ class BmmOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"),
-        true,
-        platform::errors::NotFound("Input(X) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Y"),
-        true,
-        platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("Out"),
-        true,
-        platform::errors::NotFound("Output(Out) of BmmOp should not be null."));
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    PADDLE_ENFORCE_EQ(x_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "Input(X) of BmmOp must be 3-dimensional in BmmOp, "
-                          "but received X's shape: [%s].",
-                          x_dims));
-    PADDLE_ENFORCE_EQ(y_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
-                          "but received Y's shape: [%s].",
-                          y_dims));
-    PADDLE_ENFORCE_EQ(
-        x_dims[0],
-        y_dims[0],
-        platform::errors::InvalidArgument(
-            "Input(X) and Input(Y) must have the same batch size in BmmOp, "
-            "but received X's batch size: [%s],"
-            "Y's batch size [%s]",
-            x_dims[0],
-            y_dims[0]));
-    PADDLE_ENFORCE_EQ(
-        x_dims[2],
-        y_dims[1],
-        platform::errors::InvalidArgument(
-            "Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
-            "but receive X's width: [%s],"
-            "Y's height: [%s].",
-            x_dims[2],
-            y_dims[1]));
-
-    std::vector<int64_t> dim_out;
-    dim_out.push_back(x_dims[0]);
-    dim_out.push_back(x_dims[1]);
-    dim_out.push_back(y_dims[2]);
-    ctx->SetOutputDim("Out", phi::make_ddim(dim_out));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -110,33 +59,6 @@ class BmmOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"),
-        true,
-        platform::errors::NotFound("Input(X) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Y"),
-        true,
-        platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
-                      true,
-                      platform::errors::NotFound(
-                          "Output(Out@GRAD) of BmmOp should not be null."));
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
@@ -166,15 +88,16 @@ class BmmOpGradMaker : public framework::SingleGradOpMaker<T> {
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(bmm,
+                            BmmInferShapeFunctor,
+                            PD_INFER_META(phi::BmmInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(bmm_grad,
+                            BmmGradInferShapeFunctor,
+                            PD_INFER_META(phi::BmmGradInferMeta));
 REGISTER_OPERATOR(bmm,
                   ops::BmmOp,
                   ops::BmmOpMaker,
                   ops::BmmOpGradMaker<paddle::framework::OpDesc>,
-                  ops::BmmOpGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad);
-REGISTER_OP_CPU_KERNEL(bmm,
-                       ops::BmmKernel<phi::CPUContext, float>,
-                       ops::BmmKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(bmm_grad,
-                       ops::BmmGradKernel<phi::CPUContext, float>,
-                       ops::BmmGradKernel<phi::CPUContext, double>);
+                  ops::BmmOpGradMaker<paddle::imperative::OpBase>,
+                  BmmInferShapeFunctor);
+REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad, BmmGradInferShapeFunctor);
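
Note that the removed REGISTER_OP_CPU_KERNEL lines are not replaced in this file: after the move, the compute kernels are registered on the phi side, in files that are not part of this excerpt. A rough sketch of what a phi-side CPU registration for bmm typically looks like is shown below; the file path, header name, and exact contents are assumptions for illustration, not taken from this commit.

// Hypothetical sketch (e.g. a file like paddle/phi/kernels/cpu/bmm_kernel.cc);
// the real registration in this commit may differ.
#include "paddle/phi/kernels/bmm_kernel.h"  // assumed forward-kernel header

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

// Registers phi::BmmKernel for the CPU backend, any layout, float and double.
PD_REGISTER_KERNEL(bmm, CPU, ALL_LAYOUT, phi::BmmKernel, float, double) {}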

paddle/fluid/operators/bmm_op.cu

Lines changed: 0 additions & 29 deletions
This file was deleted.

paddle/fluid/operators/bmm_op.h

Lines changed: 0 additions & 89 deletions

@@ -58,95 +58,6 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
   ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
 }
 
-template <typename DeviceContext, typename T>
-class BmmKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    const Tensor &x = *context.Input<Tensor>("X");
-    const Tensor &y = *context.Input<Tensor>("Y");
-    Tensor *out = context.Output<Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-
-    if (x.numel() == 0 || y.numel() == 0) {
-      return;
-    }
-
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(x.dims(), 0, false);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(y.dims(), 0, false);
-
-    // auto scale = static_cast<T>(context.Attr<float>("alpha"));
-    blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
-  }
-};
-
-template <typename DeviceContext, typename T>
-class BmmGradKernel : public framework::OpKernel<T> {
- public:
-  void MatMul(const framework::ExecutionContext &context,
-              const framework::Tensor &a,
-              bool trans_a,
-              const framework::Tensor &b,
-              bool trans_b,
-              framework::Tensor *out) const {
-    out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
-
-    blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
-  }
-  void CalcInputGrad(const framework::ExecutionContext &context,
-                     const framework::Tensor &a,
-                     bool trans_a,
-                     const framework::Tensor &b,
-                     bool trans_b,
-                     framework::Tensor *out) const {
-    if (out == nullptr) return;
-    MatMul(context, a, trans_a, b, trans_b, out);
-  }
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto x = *context.Input<framework::Tensor>("X");
-    auto y = *context.Input<framework::Tensor>("Y");
-    auto dout =
-        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
-
-    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false);
-    framework::DDim dx_dims;
-    if (dx) {
-      dx_dims = dx->dims();
-      if (dx_dims != x.dims()) {
-        dx->Resize(x.dims());
-      }
-    }
-
-    framework::DDim dy_dims;
-    if (dy) {
-      dy_dims = dy->dims();
-      if (dy_dims != y.dims()) {
-        dy->Resize(y.dims());
-      }
-    }
-
-    CalcInputGrad(context, dout, false, y, true, dx);
-    CalcInputGrad(context, x, true, dout, false, dy);
-
-    if (dx) {
-      if (dx_dims != x.dims()) {
-        dx->Resize(dx_dims);
-      }
-    }
-    if (dy) {
-      if (dy_dims != y.dims()) {
-        dy->Resize(dy_dims);
-      }
-    }
-  }
-};
-
 } // namespace operators
 } // namespace paddle
 #endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_
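
For reference, the deleted BmmKernel delegated the batched product to blas.MatMul. A minimal standalone sketch of the same forward semantics, written in plain C++ over flat vectors purely for illustration (this is not the Paddle implementation):

#include <cassert>
#include <cstddef>
#include <vector>

// out[b, m, n] = sum_k x[b, m, k] * y[b, k, n], for B batches.
// This is what blas.MatMul computes per batch in the deleted kernel.
std::vector<float> NaiveBmm(const std::vector<float>& x,  // B*M*K elements
                            const std::vector<float>& y,  // B*K*N elements
                            std::size_t B, std::size_t M,
                            std::size_t K, std::size_t N) {
  assert(x.size() == B * M * K && y.size() == B * K * N);
  std::vector<float> out(B * M * N, 0.0f);
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t k = 0; k < K; ++k)
          out[(b * M + m) * N + n] +=
              x[(b * M + m) * K + k] * y[(b * K + k) * N + n];
  return out;
}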

paddle/phi/infermeta/backward.cc

Lines changed: 11 additions & 0 deletions

@@ -73,6 +73,17 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
   }
 }
 
+void BmmGradInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      const MetaTensor& out_grad,
+                      MetaTensor* x_grad,
+                      MetaTensor* y_grad) {
+  x_grad->set_dims(x.dims());
+  y_grad->set_dims(y.dims());
+  x_grad->set_dtype(x.dtype());
+  y_grad->set_dtype(y.dtype());
+}
+
 void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  int groups,
                                  const std::string& data_format,
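
The dims and dtypes set by BmmGradInferMeta simply mirror the forward inputs, which follows from the bmm gradient identities: for Out = bmm(X, Y) with X of shape [B, M, K] and Y of shape [B, K, N], dX = bmm(dOut, Y^T) has shape [B, M, K] and dY = bmm(X^T, dOut) has shape [B, K, N]. A tiny self-contained check of that shape algebra, with example dims chosen only for illustration:

#include <array>
#include <cassert>

int main() {
  // Example dims: X: [B, M, K], Y: [B, K, N].
  const std::array<int, 3> x_dims{8, 4, 6}, y_dims{8, 6, 5};
  // dOut has the forward output shape [B, M, N].
  const std::array<int, 3> dout_dims{x_dims[0], x_dims[1], y_dims[2]};
  // dX = bmm(dOut, Y^T): [B, M, N] x [B, N, K] -> [B, M, K], i.e. X's shape.
  const std::array<int, 3> dx_dims{dout_dims[0], dout_dims[1], y_dims[1]};
  // dY = bmm(X^T, dOut): [B, K, M] x [B, M, N] -> [B, K, N], i.e. Y's shape.
  const std::array<int, 3> dy_dims{x_dims[0], x_dims[2], dout_dims[2]};
  assert(dx_dims == x_dims);
  assert(dy_dims == y_dims);
  return 0;
}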

paddle/phi/infermeta/backward.h

Lines changed: 6 additions & 0 deletions

@@ -41,6 +41,12 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
                                         MetaTensor* dweight,
                                         MetaTensor* dbias);
 
+void BmmGradInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      const MetaTensor& out_grad,
+                      MetaTensor* x_grad,
+                      MetaTensor* y_grad);
+
 void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  int groups,
                                  const std::string& data_format,

paddle/phi/infermeta/binary.cc

Lines changed: 47 additions & 0 deletions

@@ -260,6 +260,53 @@ void BincountInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
+  std::vector<int64_t> x_dims = phi::vectorize(x.dims());
+  std::vector<int64_t> y_dims = phi::vectorize(y.dims());
+  std::size_t x_ndims = x_dims.size();
+  std::size_t y_ndims = y_dims.size();
+
+  PADDLE_ENFORCE_EQ(x_ndims,
+                    3,
+                    phi::errors::InvalidArgument(
+                        "Input(X) of BmmOp must be 3-dimensional in BmmOp, "
+                        "but received X's shape: [%s].",
+                        x_ndims));
+  PADDLE_ENFORCE_EQ(y_ndims,
+                    3,
+                    phi::errors::InvalidArgument(
+                        "Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
+                        "but received Y's shape: [%s].",
+                        y_ndims));
+  PADDLE_ENFORCE_EQ(
+      x_dims[0],
+      y_dims[0],
+      phi::errors::InvalidArgument(
+          "Input(X) and Input(Y) must have the same batch size in BmmOp, "
+          "but received X's batch size: [%s],"
+          "Y's batch size [%s]",
+          x_dims[0],
+          y_dims[0]));
+  PADDLE_ENFORCE_EQ(
+      x_dims[2],
+      y_dims[1],
+      phi::errors::InvalidArgument(
+          "Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
+          "but receive X's width: [%s],"
+          "Y's height: [%s].",
+          x_dims[2],
+          y_dims[1]));
+
+  std::vector<int64_t> dim_out;
+  dim_out.push_back(x_dims[0]);
+  dim_out.push_back(x_dims[1]);
+  dim_out.push_back(y_dims[2]);
+  out->set_dims(phi::make_ddim(dim_out));
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 void CholeskySolveInferMeta(const MetaTensor& x,
                             const MetaTensor& y,
                             bool upper,
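
As a concrete walk-through of BmmInferMeta (the numbers below are illustrative, not from the commit): X of shape [4, 2, 3] and Y of shape [4, 3, 5] pass all four checks, since both are 3-D, the batch sizes match, and X's width 3 equals Y's height 3, so the inferred output shape is [4, 2, 5]. A minimal standalone restatement of that rule:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the checks and output-shape rule of phi::BmmInferMeta above,
// without the MetaTensor/PADDLE_ENFORCE machinery (illustration only).
std::vector<int64_t> BmmOutShape(const std::vector<int64_t>& x_dims,
                                 const std::vector<int64_t>& y_dims) {
  assert(x_dims.size() == 3 && y_dims.size() == 3);  // both inputs 3-D
  assert(x_dims[0] == y_dims[0]);                    // same batch size
  assert(x_dims[2] == y_dims[1]);                    // X's width == Y's height
  return {x_dims[0], x_dims[1], y_dims[2]};          // [B, M, N]
}

int main() {
  const auto out = BmmOutShape({4, 2, 3}, {4, 3, 5});
  assert((out == std::vector<int64_t>{4, 2, 5}));
  return 0;
}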

paddle/phi/infermeta/binary.h

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ void BincountInferMeta(const MetaTensor& x,
                        int minlength,
                        MetaTensor* out);
 
+void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
+
 void CholeskySolveInferMeta(const MetaTensor& x,
                             const MetaTensor& y,
                             bool upper,
paddle/phi/kernels/bmm_grad_kernel.h

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BmmGradKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad,
+                   DenseTensor* y_grad);
+
+} // namespace phi
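
This header only declares the kernel; its CPU and GPU definitions live in other files of the commit that are not shown in this excerpt. Purely as an illustration of the math such an implementation computes, using the same transposes as the deleted fluid BmmGradKernel (dX = dOut * Y^T, dY = X^T * dOut), here is a naive plain-C++ reference rather than the actual phi kernel:

#include <cstddef>
#include <vector>

// Naive reference for the bmm backward pass (illustration only):
//   x_grad[b] = out_grad[b] * y[b]^T   -> shape [B, M, K]
//   y_grad[b] = x[b]^T * out_grad[b]   -> shape [B, K, N]
void NaiveBmmGrad(const std::vector<float>& x,         // [B, M, K]
                  const std::vector<float>& y,         // [B, K, N]
                  const std::vector<float>& out_grad,  // [B, M, N]
                  std::size_t B, std::size_t M, std::size_t K, std::size_t N,
                  std::vector<float>* x_grad,          // [B, M, K]
                  std::vector<float>* y_grad) {        // [B, K, N]
  x_grad->assign(B * M * K, 0.0f);
  y_grad->assign(B * K * N, 0.0f);
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t n = 0; n < N; ++n) {
        const float g = out_grad[(b * M + m) * N + n];
        for (std::size_t k = 0; k < K; ++k) {
          // dX[b, m, k] += dOut[b, m, n] * Y[b, k, n]
          (*x_grad)[(b * M + m) * K + k] += g * y[(b * K + k) * N + n];
          // dY[b, k, n] += X[b, m, k] * dOut[b, m, n]
          (*y_grad)[(b * K + k) * N + n] += g * x[(b * M + m) * K + k];
        }
      }
}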
