Skip to content

Commit

Permalink
【Prim】Add gather vjp (#50305)
Browse files Browse the repository at this point in the history
* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* fix ut
  • Loading branch information
JiabinYang authored Feb 22, 2023
1 parent 613a3ff commit 4db8e5c
Show file tree
Hide file tree
Showing 11 changed files with 836 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ paddle/phi/api/lib/tensor_operants.cc
paddle/phi/extension.h
paddle/phi/include/*
paddle/phi/infermeta/generated.*

paddle/fluid/prim/api/generated_prim/*.cc
paddle/fluid/prim/api/generated_prim/*.h
*.DS_Store
*.vs
build/
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class ElementwiseAddCompositeGradOpMaker
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(6) << "Runing add_grad composite func";
Expand Down
30 changes: 30 additions & 0 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
Expand Down Expand Up @@ -132,6 +134,33 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};

class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

protected:
void Apply() override {
paddle::experimental::Tensor index = this->GetSingleForwardInput("Index");
paddle::optional<paddle::experimental::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("Axis");
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor dout = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(*dx_ptr);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing gather_grad composite func";
if (tensor_axis.is_initialized()) {
PADDLE_THROW(platform::errors::Unimplemented(
"We don't support dynamic index from tensor for gather composite "
"grad for now. "));
} else {
prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
}
this->RecoverOutputName(dx, dx_name);
}
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");

} // namespace operators
Expand All @@ -146,6 +175,7 @@ REGISTER_OPERATOR(gather,
ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>,
ops::GatherCompositeGradOpMaker,
GatherInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad,
GatherGradInferShapeFunctor,
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
- scatter
- scatter_nd_add
- tile
- transpose
- subtract
Loading

0 comments on commit 4db8e5c

Please sign in to comment.