Skip to content

Commit

Permalink
[bug fix] fix pow composite (#52645)
Browse files Browse the repository at this point in the history
* [bug fix] fix pow composite

* [bug fix] for ci
  • Loading branch information
wangzhen38 authored Apr 10, 2023
1 parent 58d5af0 commit f2d1f28
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
5 changes: 2 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_pow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ REGISTER_OPERATOR(elementwise_pow,
ops::ElementwisePowOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_pow_grad,
ops::ElementwiseOpGrad,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>,
ops::ElementwisePowCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad);

REGISTER_OP_VERSION(elementwise_pow)
.AddCheckpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ void elementwise_pow_grad(const Tensor& x,
// dy = lnx * x^y
auto lnx = log<T>(x);
auto x_pow_y = elementwise_pow<T>(x, y);
auto dy_res = lnx * x_pow_y;
auto dy_res = lnx * x_pow_y * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
Expand All @@ -418,7 +418,7 @@ void elementwise_pow_grad(const Tensor& x,
// dx = y * x^(y-1)
auto tmp_z = y - 1.0;
auto x_pow_z = elementwise_pow<T>(x, tmp_z);
auto dx_res = y * x_pow_z;
auto dx_res = y * x_pow_z * out_grad;
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
Expand Down

0 comments on commit f2d1f28

Please sign in to comment.