From 4db8e5c76ee6675ccfaab61a0e0416f6fef15cea Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Wed, 22 Feb 2023 10:27:13 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Add=20gather=20vjp=20(#50?= =?UTF-8?q?305)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tmp gather vjp * support gather * remove useless code * fix compiling error * fix ut * add eager test * add eager test * add seed * fix cpu error * fix transpose op compat * remove tensor index case * fix prim_cinn * fix ut --- .gitignore | 3 +- .../elementwise/elementwise_add_op.cc | 4 +- paddle/fluid/operators/gather_op.cc | 30 ++ paddle/fluid/prim/api/api.yaml | 1 + .../prim/api/auto_code_generated/prim_base.py | 344 ++++++++++++++++++ .../composite_backward_api.h | 34 ++ .../utils/static/composite_grad_desc_maker.h | 84 +++-- paddle/phi/api/yaml/op_compat.yaml | 3 + .../dygraph_to_static/test_cinn_prim.py | 9 + .../vjp/eager/test_comp_eager_gather_grad.py | 116 ++++++ .../prim/vjp/static/test_comp_gather_grad.py | 236 ++++++++++++ 11 files changed, 836 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/prim/api/auto_code_generated/prim_base.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py diff --git a/.gitignore b/.gitignore index eef92a0488cd22..f88ab47d9bd2e8 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,8 @@ paddle/phi/api/lib/tensor_operants.cc paddle/phi/extension.h paddle/phi/include/* paddle/phi/infermeta/generated.* - +paddle/fluid/prim/api/generated_prim/*.cc +paddle/fluid/prim/api/generated_prim/*.h *.DS_Store *.vs build/ diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index c122a07c9b1d49..700a69fa3ce487 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -61,10 +61,10 @@ class ElementwiseAddCompositeGradOpMaker paddle::experimental::Tensor y = this->GetSingleForwardInput("Y"); paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out"); paddle::experimental::Tensor dx = this->GetSingleInputGrad("X"); - auto dx_ptr = this->GetOutputPtr(&dx); + auto* dx_ptr = this->GetOutputPtr(&dx); std::string dx_name = this->GetOutputName(dx); paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y"); - auto dy_ptr = this->GetOutputPtr(&dy); + auto* dy_ptr = this->GetOutputPtr(&dy); std::string dy_name = this->GetOutputName(dy); int axis = static_cast(this->Attr("axis")); VLOG(6) << "Runing add_grad composite func"; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 4b85dee9a270be..de0c6409b4615d 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -19,6 +19,8 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/backward.h" @@ -132,6 +134,33 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker { } }; +class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { + public: + using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; + + protected: + void Apply() override { + paddle::experimental::Tensor index = this->GetSingleForwardInput("Index"); + paddle::optional tensor_axis = + this->GetOptionalSingleForwardInput("Axis"); + paddle::experimental::Tensor x = this->GetSingleForwardInput("X"); + paddle::experimental::Tensor dout = this->GetSingleOutputGrad("Out"); + paddle::experimental::Tensor dx = this->GetSingleInputGrad("X"); + auto* dx_ptr = this->GetOutputPtr(&dx); + std::string dx_name = this->GetOutputName(*dx_ptr); + int axis = static_cast(this->Attr("axis")); + VLOG(3) << "Runing gather_grad composite func"; + if (tensor_axis.is_initialized()) { + PADDLE_THROW(platform::errors::Unimplemented( + "We don't support dynamic index from tensor for gather composite " + "grad for now. ")); + } else { + prim::gather_grad(x, index, dout, axis, false, dx_ptr); + } + this->RecoverOutputName(dx, dx_name); + } +}; + DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X"); } // namespace operators @@ -146,6 +175,7 @@ REGISTER_OPERATOR(gather, ops::GatherOpMaker, ops::GatherGradOpMaker, ops::GatherGradOpMaker, + ops::GatherCompositeGradOpMaker, GatherInferShapeFunctor); DECLARE_INFER_SHAPE_FUNCTOR(gather_grad, GatherGradInferShapeFunctor, diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index 3d5a92a398cc53..fc64b758adee02 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -23,4 +23,5 @@ - scatter - scatter_nd_add - tile +- transpose - subtract diff --git a/paddle/fluid/prim/api/auto_code_generated/prim_base.py b/paddle/fluid/prim/api/auto_code_generated/prim_base.py new file mode 100644 index 00000000000000..3e45fbc6419b5e --- /dev/null +++ b/paddle/fluid/prim/api/auto_code_generated/prim_base.py @@ -0,0 +1,344 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# prim api list +white_ops_list = [ + "pow", + "scale", + "multiply", + "unsqueeze", + "expand", + "full", + "reshape", + "divide", + "sum", + "exp", + "scatter", + "transpose", +] + +inplace_out_type_map = { + "Tensor": "Tensor&", + "std::vector": "std::vector&", +} + +inplace_optional_out_type_map = { + "Tensor": "paddle::optional&", + "std::vector": "paddle::optional>&", +} + + +class BaseAPI: + def __init__(self, api_item_yaml, prims=tuple()): + # self.api = api_item_yaml['op'] + self.api = api_item_yaml['name'] + + self.is_prim_api = False + if api_item_yaml['name'] in prims: + self.is_prim_api = True + + ####################################### + # inputs: + # names : [], list of input names + # input_info : {input_name : type} + # attrs: + # names : [], list of attribute names + # attr_info : { attr_name : (type, default_values)} + # outputs: + # names : [], list of output names + # types : [], list of output types + # out_size_expr : [], expression for getting size of vector + ######################################## + if self.is_prim_api: + ( + self.inputs, + self.attrs, + self.outputs, + self.optional_vars, + ) = self.parse_args(self.api, api_item_yaml) + + self.inplace_map = api_item_yaml['inplace'] + + def get_api_func_name(self): + return self.api + + # def is_inplace(self): + # if self.inplace_map + # return True + # return False + + def get_input_tensor_args(self, inplace_flag=False): + input_args = [] + inplace_type_map = { + "const Tensor&": "Tensor&", + "const paddle::optional&": "paddle::optional&", + "const std::vector&": "std::vector&", + "const paddle::optional>&": "paddle::optional>&", + } + for name in self.inputs['names']: + name = name.split('@')[0] + if inplace_flag and name in self.inplace_map.values(): + input_args.append( + inplace_type_map[self.inputs['input_info'][name]] + + ' ' + + name + ) + else: + input_args.append(self.inputs['input_info'][name] + ' ' + name) + return input_args + + def get_declare_args(self, inplace_flag=False): + declare_args = self.get_input_tensor_args(inplace_flag) + for name in self.attrs['names']: + default_value = '' + if self.attrs['attr_info'][name][1] is not None: + default_value = ' = ' + self.attrs['attr_info'][name][1] + declare_args.append( + self.attrs['attr_info'][name][0] + ' ' + name + default_value + ) + + return ", ".join(declare_args) + + def get_declare_args_nodefault(self, inplace_flag=False): + declare_args = self.get_input_tensor_args(inplace_flag) + for name in self.attrs['names']: + declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name) + + return ", ".join(declare_args) + + def get_return_type(self, inplace_flag=False): + out_type_list = [] + for i, out_type in enumerate(self.outputs['types']): + out_name = self.outputs['names'][i].split('@')[0] + if inplace_flag and out_name in self.inplace_map: + if self.inplace_map[out_name] in self.optional_vars: + out_type_list.append( + inplace_optional_out_type_map[out_type] + ) + else: + out_type_list.append(inplace_out_type_map[out_type]) + else: + out_type_list.append(out_type) + if len(out_type_list) == 1: + return out_type_list[0] + else: + return "std::tuple<" + ", ".join(out_type_list) + ">" + + def parse_args(self, api_name, api_item_yaml): + optional_vars = [] + for input_dict in api_item_yaml['inputs']: + if input_dict['optional']: + optional_vars.append(input_dict['name']) + + inputs, attrs = self.parse_input_and_attr( + api_item_yaml['inputs'], api_item_yaml['attrs'] + ) + + output_type_list, output_names, out_size_expr = self.parse_output( 
+ api_item_yaml['outputs'] + ) + return ( + inputs, + attrs, + { + 'names': output_names, + 'types': output_type_list, + 'out_size_expr': out_size_expr, + }, + optional_vars, + ) + + def parse_input_and_attr(self, inputs_list, attrs_list): + input_types_map = { + 'Tensor': 'const Tensor&', + 'Tensor[]': 'const std::vector&', + } + attr_types_map = { + 'IntArray': 'const IntArray&', + 'Scalar': 'const Scalar&', + 'Scalar(int)': 'const Scalar&', + 'Scalar(int64_t)': 'const Scalar&', + 'Scalar(float)': 'const Scalar&', + 'Scalar(dobule)': 'const Scalar&', + 'Scalar[]': 'const std::vector&', + 'int': 'int', + 'int32_t': 'int32_t', + 'int64_t': 'int64_t', + 'long': 'long', + 'size_t': 'size_t', + 'float': 'float', + 'float[]': 'const std::vector&', + 'double': 'double', + 'bool': 'bool', + 'bool[]': 'const std::vector&', + 'str': 'const std::string&', + 'str[]': 'const std::vector&', + 'Place': 'const Place&', + 'DataLayout': 'DataLayout', + 'DataType': 'DataType', + 'int64_t[]': 'const std::vector&', + 'int[]': 'const std::vector&', + } + optional_types_trans = { + 'Tensor': 'const paddle::optional&', + 'Tensor[]': 'const paddle::optional>&', + 'int': 'paddle::optional', + 'int32_t': 'paddle::optional', + 'int64_t': 'paddle::optional', + 'float': 'paddle::optional', + 'double': 'paddle::optional', + 'bool': 'paddle::optional', + 'Place': 'paddle::optional', + 'DataLayout': 'paddle::optional', + 'DataType': 'paddle::optional', + } + + inputs = {'names': [], 'input_info': {}} + for input_dict in inputs_list: + inputs['names'].append(input_dict['name']) + if input_dict['optional']: + inputs['input_info'][input_dict['name']] = optional_types_trans[ + input_dict['typename'] + ] + else: + inputs['input_info'][input_dict['name']] = input_types_map[ + input_dict['typename'] + ] + + attrs = {'names': [], 'attr_info': {}} + for attr_dict in attrs_list: + attrs['names'].append(attr_dict['name']) + if 'default_value' in attr_dict.keys(): + default_value = attr_dict['default_value'] + else: + default_value = None + + if 'optional' in attr_dict.keys(): + attrs['attr_info'][attr_dict['name']] = ( + optional_types_trans[attr_dict['typename']], + default_value, + ) + else: + attrs['attr_info'][attr_dict['name']] = ( + attr_types_map[attr_dict['typename']], + default_value, + ) + return inputs, attrs + + def parse_output(self, outputs_list): + + out_type_list = [] + out_name_list = [] + out_size_expr_list = [] + for output_dict in outputs_list: + if output_dict['intermediate']: + continue + out_type_list.append(output_dict['typename']) + out_name_list.append(output_dict['name']) + if 'size' in output_dict.keys(): + out_size_expr_list.append(output_dict['size']) + else: + out_size_expr_list.append(None) + return out_type_list, out_name_list, out_size_expr_list + + +class EagerPrimAPI(BaseAPI): + def __init__(self, api_item_yaml, prims=tuple()): + super().__init__(api_item_yaml, prims) + + def get_api__func_name(self): + api_func_name = self.api + # if self.is_inplace: + # if api_func_name[-1] != '_': + # api_func_name += '_' + # print("after api name", api_func_name) + return api_func_name + + def gene_prim_api_declaration(self): + api_declaration = "" + api_func_name = self.get_api__func_name() + if api_func_name[-1] != '_': + api_declaration = f""" +template +{self.get_return_type()} {api_func_name}({self.get_declare_args()}); +""" + else: + api_declaration = ( + api_declaration + + f""" +template +{self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)}); +""" + ) + 
+ return api_declaration + + def get_ad_func_input_args(self, inplace_flag=False): + input_args = [] + for name in self.inputs['names']: + name = name.split('@')[0] + if inplace_flag and name in self.inplace_map.values(): + input_args.append(name) + else: + input_args.append(name) + return input_args + + def get_ad_func_args(self, inplace_flag=False): + ad_func_args = self.get_ad_func_input_args(inplace_flag) + for name in self.attrs['names']: + default_value = '' + if self.attrs['attr_info'][name][1] is not None: + default_value = ' = ' + self.attrs['attr_info'][name][1] + ad_func_args.append(name) + + ad_func_args_str = ", ".join(ad_func_args) + return ad_func_args_str + + def gene_ad_func_call(self): + api_func_name = self.get_api__func_name() + + dygraph_ad_func_name = '::' + api_func_name + '_ad_func' + dygraph_ad_func_parameters = self.get_ad_func_args() + + ad_func_call_str = f""" +VLOG(4) << "Eager Prim API {api_func_name}_ad_func call"; +return {dygraph_ad_func_name}({dygraph_ad_func_parameters}); +""" + # print("ad_func_call_str: ", ad_func_call_str) + return ad_func_call_str + + def gene_eager_prim_api_code(self): + api_code = "" + indent = " " + api_func_name = self.get_api__func_name() + template = '' + # func decalaration + if api_func_name[-1] != '_': + api_code = f""" +template <> +{self.get_return_type()} {api_func_name}{template}({self.get_declare_args_nodefault()}) +""" + else: + api_code = f""" +template <> +{self.get_return_type(inplace_flag=True)} {api_func_name}{template}({self.get_declare_args_nodefault(inplace_flag=True)}) +""" + # func code + + api_code = api_code + '{' + api_code += f"""{self.gene_ad_func_call()}""" + api_code += '}' + '\n' + + return api_code diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 0b53fc71f97d2b..e99f816bf6da28 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -24,6 +24,40 @@ using IntArray = paddle::experimental::IntArrayBase; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h +template +void gather_grad(const Tensor& x, + const Tensor& index, + const Tensor& out_grad, + const Scalar& axis, + bool overwrite, + Tensor* grad_x) { + auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); + std::vector tmp_perm; + + // change axis to rank 0 + int axis_value = axis.to(); + tmp_perm.push_back(axis_value); + // make other ranks + for (int i = 0; i < x.dims().size(); ++i) { + if (i != axis_value) { + tmp_perm.push_back(i); + } + } + std::vector reverse_perm(tmp_perm); + // make origin ranks + for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { + reverse_perm[tmp_perm[i]] = i; + } + + // transpose out_grad and zero grad to target rank. 
+ auto tmp_zero_x_grad = transpose(zero_tensor, tmp_perm); + auto tmp_out_grad = transpose(out_grad, tmp_perm); + // scatter grad to grad_x + auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false); + auto tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm); + set_output(tmp_grad_x_tranposed, grad_x); +} + template void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { if (!grad_x) return; diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h index 6de5c1ae41c030..eb8ddfa865de93 100644 --- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h +++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h @@ -38,9 +38,9 @@ namespace prim { /* This functor class is responsible for creating the gradient ops for the given - operator fwd_op. After it is called (through operator()), the pairs of - (gradient variable, corresponding input variable of fwd_op) will be added to - grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its + operator fwd_op_. After it is called (through operator()), the pairs of + (gradient variable, corresponding input variable of fwd_op_) will be added to + grad_to_var. If an input variable of fwd_op_ is contained in no_grad_set, its gradient variable will be ignored or kEmptyVarName depending on the template argument DropEmptyIG in the derived classes. */ @@ -114,34 +114,40 @@ class CompositeGradOpMakerBase { paddle::optional GetOptionalSingleForwardOutput( const std::string& name) { paddle::optional output_opt; - framework::VarDesc* output_desc = this->SingleForwardOutput(name); - if (!output_desc) return output_opt; - paddle::experimental::Tensor output = - paddle::experimental::Tensor(std::make_shared(output_desc)); - output_opt = paddle::make_optional(output); + if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) { + framework::VarDesc* output_desc = this->SingleForwardOutput(name); + if (!output_desc) return output_opt; + paddle::experimental::Tensor output = paddle::experimental::Tensor( + std::make_shared(output_desc)); + output_opt = paddle::make_optional(output); + } return output_opt; } paddle::optional GetOptionalSingleForwardInput( const std::string& name) { paddle::optional input_opt; - framework::VarDesc* input_desc = this->SingleForwardInput(name); - if (!input_desc) return input_opt; - paddle::experimental::Tensor input = - paddle::experimental::Tensor(std::make_shared(input_desc)); - input_opt = paddle::make_optional(input); + if (fwd_op_.Inputs().find(name) != fwd_op_.Inputs().end()) { + framework::VarDesc* input_desc = this->SingleForwardInput(name); + if (!input_desc) return input_opt; + paddle::experimental::Tensor input = paddle::experimental::Tensor( + std::make_shared(input_desc)); + input_opt = paddle::make_optional(input); + } return input_opt; } paddle::optional GetOptionalSingleOutputGrad( const std::string& name) { paddle::optional output_grad_opt; - framework::VarDesc* output_grad_desc = this->SingleOutputGrad(name); - if (!output_grad_desc) return output_grad_opt; - paddle::experimental::Tensor output_grad = paddle::experimental::Tensor( - std::make_shared(output_grad_desc)); - output_grad_opt = - paddle::make_optional(output_grad); + if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) { + framework::VarDesc* output_grad_desc = this->SingleOutputGrad(name); + if (!output_grad_desc) return output_grad_opt; + paddle::experimental::Tensor output_grad = 
paddle::experimental::Tensor( + std::make_shared(output_grad_desc)); + output_grad_opt = + paddle::make_optional(output_grad); + } return output_grad_opt; } @@ -457,16 +463,44 @@ class CompositeGradOpMakerBase { framework::VarDesc* SingleForwardInput(const std::string& name) const { // Copy Var from original block to active block, or create a new one. - CopyVarFromOrig(fwd_op_.Input(name).at(0)); - return StaticCompositeContext::Instance().GetBlock()->FindVar( - fwd_op_.Input(name).at(0)); + auto fwd_in_names = fwd_op_.Input(name); + if (!fwd_in_names.empty()) { + PADDLE_ENFORCE_EQ( + fwd_in_names.size(), + 1, + phi::errors::InvalidArgument( + "When calling SingleForward for op: %s's Input: %s, we should " + "only get one input tensor, but we got %d instead.", + fwd_op_.Type(), + name, + fwd_in_names.size())); + CopyVarFromOrig(fwd_op_.Input(name).at(0)); + return StaticCompositeContext::Instance().GetBlock()->FindVar( + fwd_op_.Input(name).at(0)); + } else { + return nullptr; + } } framework::VarDesc* SingleForwardOutput(const std::string& name) const { // Copy Var from original block to active block, or create a new one. - CopyVarFromOrig(fwd_op_.Output(name).at(0)); - return StaticCompositeContext::Instance().GetBlock()->FindVar( - fwd_op_.Output(name).at(0)); + auto fwd_out_names = fwd_op_.Output(name); + if (!fwd_out_names.empty()) { + PADDLE_ENFORCE_EQ( + fwd_out_names.size(), + 1, + phi::errors::InvalidArgument( + "When calling SingleForward for op: %s's Output: %s, we should " + "only get one input tensor, but we got %d instead.", + fwd_op_.Type(), + name, + fwd_out_names.size())); + CopyVarFromOrig(fwd_op_.Output(name).at(0)); + return StaticCompositeContext::Instance().GetBlock()->FindVar( + fwd_op_.Output(name).at(0)); + } else { + return nullptr; + } } std::vector MultiForwardInput( diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index a34ee5471a5aa5..6a89f0e7d2ce9e 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1675,7 +1675,10 @@ - op : transpose (transpose2) backward : transpose_grad (transpose2_grad) + attrs: + perm : axis extra : + outputs : [XShape] attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false, str mkldnn_data_type = "float32"] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py index b5af3a19d49afc..1a0fe1a6938cbe 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py @@ -129,8 +129,17 @@ def check_prim(self, net, use_prim): if not use_prim: return fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + all_ops = [ + op.type + for op in net.forward.program_cache.last()[-1][-1] + .train_program.block(0) + .ops + ] # Ensure that softmax is splitted into small ops self.assertTrue('softmax' not in fwd_ops) + for op in all_ops: + if op != "matmul_v2_grad": + self.assertTrue("_grad" not in op) def test_cinn_prim(self): dy_res = self.train(use_prim=False) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py new file mode 100644 index 00000000000000..2da25afb7a9b8c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py @@ -0,0 +1,116 @@ +# 
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v'), + [ + ( + np.random.rand(100), + np.array([1, 3, 5]), + 0, + np.float32, + np.int32, + np.random.rand(3), + ), + ( + np.random.rand(10, 20), + np.array([1, 3, 5]), + 0, + np.float64, + np.int64, + np.random.rand(3, 20), + ), + ( + np.random.rand(10, 20), + np.array([1, 1, 3]), + 0, + np.float32, + np.int32, + np.random.rand(3, 20), + ), + ( + np.random.rand(3, 88, 30), + np.array([1, 3, 5]), + 1, + np.float32, + np.int32, + np.random.rand(3, 3, 30), + ), + ( + np.random.rand(10, 88, 10), + np.array([1, 3, 5]), + 0, + np.float32, + np.int32, + np.random.rand(3, 88, 10), + ), + ], +) +class TestGatherGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.x_dtype) + cls.index = cls.index.astype(cls.index_dtype) + cls.v = cls.v.astype(cls.x_dtype) + + @classmethod + def tearDownClass(cls): + core._set_prim_backward_enabled(False) + + def test_exp_grad_comp(self): + def actual(primal0, index, axis): + core._set_prim_backward_enabled(True) + paddle.disable_static() + x = paddle.to_tensor( + primal0, dtype=primal0.dtype, stop_gradient=False + ) + index = paddle.to_tensor(index, dtype=index.dtype) + x.stop_gradient = False + index.stop_gradient = True + out = paddle.gather(x, index, axis) + res = paddle.grad(out, [x], create_graph=False, retain_graph=True) + return res[0].numpy() + + def desired(primal0, index, axis): + core._set_prim_backward_enabled(False) + paddle.disable_static() + x = paddle.to_tensor( + primal0, dtype=primal0.dtype, stop_gradient=False + ) + index = paddle.to_tensor(index, dtype=index.dtype) + x.stop_gradient = False + index.stop_gradient = True + out = paddle.gather(x, index, axis) + res = paddle.grad(out, [x], create_graph=False, retain_graph=True) + return res[0].numpy() + + np.testing.assert_allclose( + actual=actual(self.primal0, self.index, self.axis), + desired=desired(self.primal0, self.index, self.axis), + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py new file mode 100644 index 00000000000000..284620ba76ae36 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py @@ -0,0 +1,236 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core, framework + +np.random.seed(2023) + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +class PrimeNet(paddle.nn.Layer): + def __init__(self): + super(PrimeNet, self).__init__() + self.fc = paddle.nn.Linear(4, 4) + + def forward(self, x, index, axis): + tmp = self.fc(x) + out = paddle.gather(tmp, index, axis) + return out + + +@param.parameterized_class( + ('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v', "count"), + [ + ( + np.random.rand(100), + np.array([1, 3, 5]), + 0, + np.float32, + np.int32, + np.random.rand(3), + 0, + ), + ( + np.random.rand(10, 20), + np.array([1, 3, 5]), + 0, + np.float64, + np.int64, + np.random.rand(3, 20), + 1, + ), + ( + np.random.rand(10, 20), + np.array([1, 1, 3]), + 0, + np.float32, + np.int32, + np.random.rand(3, 20), + 2, + ), + ( + # Something wrong with gather grad cpu kernel + np.random.rand(3, 88, 30), + np.array([1, 3, 5]), + 1, + np.float32, + np.int32, + np.random.rand(3, 3, 30), + 3, + ), + ( + np.random.rand(10, 88, 10), + np.array([1, 3, 5]), + 0, + np.float16, + np.int32, + np.random.rand(3, 88, 10), + 4, + ), + ], +) +class TestGatherGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.x_dtype) + cls.index = cls.index.astype(cls.index_dtype) + cls.v = cls.v.astype(cls.x_dtype) + + def train(self, use_prim, use_cinn): + paddle.seed(2022) + self.x = paddle.randn([2, 4]) + self.index = paddle.to_tensor(np.array([0, 1])) + self.x.stop_gradient = False + net = PrimeNet() + core._set_prim_backward_enabled(use_prim) + net = apply_to_static(net, use_cinn) + out = net(self.x, self.index, 0) + res = paddle.autograd.grad(out, [self.x]) + + return res + + def test_cinn(self): + paddle.disable_static() + dy_res = self.train(use_prim=False, use_cinn=False) + # TODO(jiabin): CINN will crashed in this case open it when fixed + comp_st_cinn_res = self.train(use_prim=True, use_cinn=False) + + for i in range(len(dy_res)): + np.testing.assert_allclose( + comp_st_cinn_res[i].numpy(), + dy_res[i].numpy(), + rtol=1e-6, + atol=1e-6, + ) + paddle.enable_static() + + def test_tanh_grad_comp(self): + paddle.enable_static() + + def actual(primal0, index, axis, v): + core._set_prim_backward_enabled(True) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + index_tmp = paddle.static.data( + 'index', index.shape, index.dtype + ) + x.stop_gradient = False + index_tmp.stop_gradient = True + z = paddle.gather(x, index_tmp, axis) + z_grad = paddle.static.data('v', z.shape, z.dtype) + res = paddle.static.gradients([z], [x], [z_grad]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'index': index, + 'v': v, + }, + fetch_list=[res[0].name], + ) + return out[0] + + 
def desired(primal0, index, axis, v): + core._set_prim_backward_enabled(False) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + index_tmp = paddle.static.data( + 'index', index.shape, index.dtype + ) + x.stop_gradient = False + index_tmp.stop_gradient = True + z = paddle.gather(x, index_tmp, axis) + z_grad = paddle.static.data('v', z.shape, z.dtype) + res = paddle.static.gradients([z], [x], [z_grad]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'index': index, + 'v': v, + }, + fetch_list=[res[0].name], + ) + return out[0] + + dx = None + ddx = None + + # fp16 is not supported for cpu gather + if not ( + (self.count == 4) + and isinstance( + framework._current_expected_place(), framework.core.CPUPlace + ) + ): + dx = actual(self.primal0, self.index, self.axis, self.v) + + ddx = desired(self.primal0, self.index, self.axis, self.v) + + if (self.count >= 3) and isinstance( + framework._current_expected_place(), framework.core.CPUPlace + ): + # Scatter in phi has problem with cpu kernel of case 4, so skip this + pass + elif (self.count == 4) and ( + not isinstance( + framework._current_expected_place(), framework.core.CPUPlace + ) + ): + # FP16 test case + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-3, + atol=0, + ) + elif self.count == 1: + # FP64 test case + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-15, + atol=1e-15, + ) + else: + # FP32 test cases + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-5, + atol=0, + ) + core._set_prim_backward_enabled(False) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main()
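
The core of this patch is the composite `gather_grad` added to `composite_backward_api.h`: it permutes the gathered axis to rank 0, scatter-adds `out_grad` into a zero tensor at `index` (scatter with `overwrite=false`, so repeated indices accumulate), and permutes the result back with the inverse permutation. Below is a minimal NumPy sketch of that composition, checked against a direct per-index accumulation. The function name, shapes, and the reference check are illustrative assumptions for this note, not code from the patch.

```python
# Minimal sketch (assumed helper, not part of the patch) of the composite
# gather_grad: transpose `axis` to the front, scatter-add, transpose back.
import numpy as np


def gather_grad_composite(x_shape, index, out_grad, axis):
    rank = len(x_shape)
    perm = [axis] + [i for i in range(rank) if i != axis]  # move `axis` to rank 0
    inv_perm = np.argsort(perm)                            # permutation that undoes `perm`
    zero = np.zeros(x_shape, dtype=out_grad.dtype).transpose(perm)
    dout = out_grad.transpose(perm)
    np.add.at(zero, index, dout)                           # scatter-add; repeated indices accumulate
    return zero.transpose(inv_perm)                        # back to the original layout


# Quick check against direct accumulation over the gathered slices
# (shapes mirror one of the patch's test cases, including a repeated index).
x_shape = (3, 88, 30)
index = np.array([1, 1, 3])
axis = 1
out_grad = np.random.rand(3, 3, 30).astype(np.float32)

ref = np.zeros(x_shape, dtype=np.float32)
for k, idx in enumerate(index):
    ref[:, idx, :] += out_grad[:, k, :]

np.testing.assert_allclose(
    gather_grad_composite(x_shape, index, out_grad, axis), ref, rtol=1e-6
)
```

This mirrors why the patch passes `false` for the scatter overwrite flag: with a repeated index such as `[1, 1, 3]`, the gradient contributions for the duplicated row must be summed, not overwritten.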