
Commit 4e509f4

add cumprod op (#35185)
* add test_cumprod_op
* Revert "add test_cumprod_op". This reverts commit c96cf6d.
* recommit
* add error message
* test input(x) initialize
* test use cpu
* update test code
* add test type
* add test case
* solve ci problem
* add complex case test
* add complex case test
* fix review problem
* fix conflict
* fix some docs
* change test case
* change test case
* fix review problems again
* fix docs
* fix inclusivescan bug
1 parent 5bdca05 commit 4e509f4

File tree

10 files changed: +1154 −0 lines changed


paddle/fluid/framework/data_type.h

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,12 @@ struct complex;
 namespace paddle {
 namespace framework {

+template <typename T>
+struct IsComplex : public std::false_type {};
+
+template <typename T>
+struct IsComplex<platform::complex<T>> : public std::true_type {};
+
 template <typename T>
 struct DataTypeTrait {};
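The IsComplex trait added here is the usual false_type/true_type partial specialisation, so kernel code can branch at compile time on whether an element type is platform::complex. A self-contained sketch of that pattern, with std::complex standing in for platform::complex (nothing below is Paddle code from this commit):

#include <complex>
#include <cstdio>
#include <type_traits>

// The usual trait pattern: false by default, true for complex<T>.
template <typename T>
struct IsComplex : std::false_type {};

template <typename T>
struct IsComplex<std::complex<T>> : std::true_type {};

// Example use: conjugate complex values, pass real values through unchanged.
template <typename T>
T ConjIfComplex(const T &x) {
  if constexpr (IsComplex<T>::value) {
    return std::conj(x);
  } else {
    return x;
  }
}

int main() {
  std::complex<double> c{1.0, 2.0};
  std::printf("%g%+gi\n", ConjIfComplex(c).real(), ConjIfComplex(c).imag());  // prints 1-2i
  std::printf("%g\n", ConjIfComplex(3.5));                                    // prints 3.5
  return 0;
}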

paddle/fluid/operators/cumprod_op.cc

Lines changed: 102 additions & 0 deletions (new file)
@@ -0,0 +1,102 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cumprod_op.h"

namespace paddle {
namespace operators {

class CumprodOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod");

    ctx->ShareDim("X", "Out");
    ctx->ShareLoD("X", "Out");
  }
};

class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of cumprod op.");
    AddOutput("Out", "(Tensor), The output tensor of cumprod op.");
    AddAttr<int>(
        "dim",
        "(int), The dim along which the input tensors will be cumproded");
    AddComment(
        R"DOC(Cumprod operator. Return the cumprod results of the input elements along the dim.
For example, if input X is a tensor with rank 1 and N elements, the output will also be a tensor
with rank 1 and N elements, and elements y[i] = x[0] * x[1] * x[2] *...* x[i] (0<=i<N))DOC");
  }
};

template <typename T>
class CumprodGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("cumprod_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput("Out", this->Output("Out"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class CumprodGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CumprodGrad");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CumprodGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "framework::GradVarName(\"Out\")", "CumprodGrad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   "framework::GradVarName(\"X\")", "CumprodGrad");
    ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
    ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
                  ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
                  ops::CumprodGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);

REGISTER_OP_CPU_KERNEL(
    cumprod, ops::CumprodOpCPUKernel<float>, ops::CumprodOpCPUKernel<double>,
    ops::CumprodOpCPUKernel<int>, ops::CumprodOpCPUKernel<int64_t>,
    ops::CumprodOpCPUKernel<paddle::platform::complex<float>>,
    ops::CumprodOpCPUKernel<paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
    cumprod_grad, ops::CumprodGradOpCPUKernel<float>,
    ops::CumprodGradOpCPUKernel<double>, ops::CumprodGradOpCPUKernel<int>,
    ops::CumprodGradOpCPUKernel<int64_t>,
    ops::CumprodGradOpCPUKernel<paddle::platform::complex<float>>,
    ops::CumprodGradOpCPUKernel<paddle::platform::complex<double>>);
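The DOC string above defines the 1-D case as y[i] = x[0] * x[1] * ... * x[i]. A standalone C++ reference of that forward formula, using a plain std::vector; this is only an illustration of the semantics, not the kernel from cumprod_op.h (which this commit page does not show):

#include <cstdio>
#include <vector>

// Reference cumulative product over a 1-D sequence: y[i] = x[0] * ... * x[i].
std::vector<double> CumprodRef(const std::vector<double> &x) {
  std::vector<double> y(x.size());
  double running = 1.0;
  for (size_t i = 0; i < x.size(); ++i) {
    running *= x[i];
    y[i] = running;
  }
  return y;
}

int main() {
  // {1, 2, 3, 4} -> {1, 2, 6, 24}
  for (double v : CumprodRef({1.0, 2.0, 3.0, 4.0})) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}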

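CumprodGradOpMaker wires X, Out, and Out@GRAD into cumprod_grad, and the grad op's InferShape mirrors the shape of the Out gradient onto the X gradient. The math behind that wiring: for y[i] = x[0] * ... * x[i], dy[i]/dx[j] = prod over k <= i, k != j of x[k] whenever i >= j (equivalently y[i] / x[j] when x[j] != 0). A deliberately naive, standalone reference of that sum, only to make the formula concrete; it is not the kernel this commit registers:

#include <cstdio>
#include <vector>

// Naive reference gradient of 1-D cumprod: with y[i] = x[0] * ... * x[i],
// dL/dx[j] = sum_{i >= j} dL/dy[i] * prod_{k <= i, k != j} x[k].
// When x[j] != 0 that inner product equals y[i] / x[j], which lets a faster
// backward reuse the saved forward output (the grad op above takes Out).
std::vector<double> CumprodGradRef(const std::vector<double> &x,
                                   const std::vector<double> &dout) {
  const size_t n = x.size();
  std::vector<double> dx(n, 0.0);
  for (size_t j = 0; j < n; ++j) {
    for (size_t i = j; i < n; ++i) {
      double partial = 1.0;  // prod_{k <= i, k != j} x[k]
      for (size_t k = 0; k <= i; ++k) {
        if (k != j) partial *= x[k];
      }
      dx[j] += dout[i] * partial;
    }
  }
  return dx;
}

int main() {
  // x = {1, 2, 3}, upstream gradient all ones:
  // dx[0] = 1 + 2 + 6 = 9, dx[1] = 1 + 3 = 4, dx[2] = 2.
  for (double v : CumprodGradRef({1.0, 2.0, 3.0}, {1.0, 1.0, 1.0}))
    std::printf("%g ", v);
  std::printf("\n");
  return 0;
}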