From de37c94b09eb0e506af7a4775b4dfaacc068e565 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:52:55 +0800 Subject: [PATCH] [PIR] Improve Dtype Transfer (add infermeta by value) (#60677) * update * fix null value * adapt manual op * update --- .gitignore | 1 + .../fluid/pir/dialect/op_generator/op_gen.py | 48 + .../dialect/op_generator/op_infermeta_gen.py | 738 +++++++ .../dialect/op_generator/op_interface_gen.py | 6 +- .../dialect/operator/interface/infermeta.h | 23 +- .../pir/dialect/operator/ir/manual_op.cc | 1848 ++++++++++++++++- .../fluid/pir/dialect/operator/ir/manual_op.h | 72 +- .../fluid/pir/dialect/operator/utils/utils.cc | 61 + .../fluid/pir/dialect/operator/utils/utils.h | 3 + .../pir/transforms/pd_op_to_kernel_pass.cc | 137 +- test/cpp/pir/core/ir_infershape_test.cc | 7 + 11 files changed, 2878 insertions(+), 66 deletions(-) create mode 100644 paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py diff --git a/.gitignore b/.gitignore index 008e4b06e5834..a9cd56d760724 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,7 @@ paddle/fluid/pir/dialect/operator/ir/op_decomp.cc paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc paddle/fluid/pir/dialect/operator/ir/pd_op.* paddle/fluid/pir/dialect/operator/ir/onednn_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.* paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.* paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused.* diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 5c9c4c97e0e78..3c4661a7caaf7 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -22,6 +22,10 @@ from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke +from op_infermeta_gen import ( + gen_infermeta_by_invoke_func_str, + gen_infermeta_func_str, +) from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, @@ -142,6 +146,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py" #include "{h_file}" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h" @@ -1712,6 +1717,48 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_info, op_class_name, op_info_items ) + op_infer_meta_from_type_str = "" + if op_infer_meta_map is not None: + muta_attr_is_input = ( + True + if len(op_mutable_attribute_name_list) > 0 + else False + ) + op_infer_meta_from_type_str = gen_infermeta_func_str( + op_class_name, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_output_optional_list, + op_infer_meta_map, + op_inplace_map, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + muta_attr_is_input, + attr_args_is_map=True, + ) + + if (op_invoke_map is not None) and ( + op_invoke_map['func'] in op_info_items + ): + op_invoke_class_name = ( + to_pascal_case(op_invoke_map['func']) + "Op" + ) + op_infer_meta_from_type_str = ( + gen_infermeta_by_invoke_func_str( + op_class_name, op_invoke_class_name + ) + ) + # =================================== # # gen Vjp func str # # =================================== # @@ -1753,6 +1800,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_meta_str) + ops_defined_list.append(op_infer_meta_from_type_str) ops_defined_list.append(op_get_kernel_type_for_var_str) ops_defined_list.append(parse_kernel_key_define_str) ops_defined_list.append(infer_symbolic_shape_define_str) diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py new file mode 100644 index 0000000000000..2c38dd43701aa --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py @@ -0,0 +1,738 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_build_gen import ( + _INFERMETA_NEED_META_CONFIG, + _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE, +) + +OP_INFERMETA_TEMPLATE = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ +{infermeta_inputs} +{get_attributes_str} +{infermeta_outputs} + return argument_outputs; +}} +""" + +CREATE_INPUT_VALUE_TEMPLATE = """ + pir::Value {input_name}_ = input_values[{index}]; (void){input_name}_;""" + +ENFORCE_INPUT_NUM_TEMPLATE = """ + IR_ENFORCE(input_values.size() == {op_input_name_list_size}, + "Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size()); +""" + +OP_INFERMETA_BY_INVOKE_TEMPLATE = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ + return {invoke_class}::InferMeta(input_values, attributes); +}} +""" + +GET_INPUT_TYPE_TEMPLATE = """ + {type} {name}; + if ({name}_.type().isa<{type}>()) {{ + {name} = {name}_.type().dyn_cast<{type}>(); (void){name}; + }} else if ({name}_.type().isa<{allocated_type}>()) {{ + {allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>(); + {name} = {type}::get(pir::IrContext::Instance(), + allocated_{name}.dtype(), + allocated_{name}.dims(), + allocated_{name}.data_layout(), + allocated_{name}.lod(), + allocated_{name}.offset()); + (void){name}; + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}")); + }} +""" + + +def get_infermeta_inputs_str( + inuse_infer_meta_args, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + mutable_attr_is_input, +): + op_input_name_list_size = len(op_input_name_list) + if mutable_attr_is_input: + op_input_name_list_size += len(op_mutable_attribute_name_list) + + infermeta_inputs_str = ENFORCE_INPUT_NUM_TEMPLATE.format( + op_input_name_list_size=str(op_input_name_list_size), + ) + + for i in range(len(op_input_name_list)): + if op_input_name_list[i] not in inuse_infer_meta_args: + continue + infermeta_inputs_str += CREATE_INPUT_VALUE_TEMPLATE.format( + input_name=op_input_name_list[i], index=str(i) + ) + + if mutable_attr_is_input: + # add mutable attributes as inputs + if len(op_mutable_attribute_name_list) > 0: + for i in range(len(op_mutable_attribute_name_list)): + if ( + op_mutable_attribute_name_list[i] + not in inuse_infer_meta_args + ): + continue + infermeta_inputs_str += CREATE_INPUT_VALUE_TEMPLATE.format( + input_name=op_mutable_attribute_name_list[i], + index=str(i + len(op_input_name_list)), + ) + infermeta_inputs_str += "\n" + + infermeta_inputs_str += ' VLOG(4) << "Builder construction outputs";\n' + # Prepar input type + for idx in range(len(op_input_name_list)): + if op_input_name_list[idx] not in inuse_infer_meta_args: + continue + # is a vector + if 'pir::VectorType' in op_input_type_list[idx]: + if op_input_optional_list[idx] == 'false': + infermeta_inputs_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) + # is a Tensor + else: + if op_input_optional_list[idx] == 'false': + type = op_input_type_list[idx] + allocated_type = type.replace( + 'DenseTensorType', 'AllocatedDenseTensorType' + ).replace("SelectedRowsType", "AllocatedSelectedRowsType") + infermeta_inputs_str += GET_INPUT_TYPE_TEMPLATE.format( + type=type, + name=op_input_name_list[idx], + allocated_type=allocated_type, + ) + + return infermeta_inputs_str + + +def GenBuildOutputsPart2( + op_class_name, + inuse_infer_meta_args, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_output_optional_list, + op_infer_meta_map, + op_inplace_map, + mutable_attr_is_input, +): + CREATE_INPUT_METATENSOR_TEMPLATE = """ + VLOG(4) << "Builder construction dense_{name}"; + paddle::dialect::IrTensor ir_tensor_{name}(paddle::dialect::TransToPhiDataType({name}.dtype()), + {name}.dims(), + {name}.data_layout(), + {name}.lod(), + {name}.offset()); + VLOG(4) << "Builder construction meta_{name}"; + paddle::dialect::IrMetaTensor meta_{name}(&ir_tensor_{name}); +""" + + CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE = """ + paddle::dialect::IrMetaTensor meta_{name}; + paddle::dialect::IrTensor ir_tensor_{name}; + + + if ({name}_.impl() != nullptr) {{ + VLOG(4) << "Builder construction dense_{name}"; + {type} {name}; + if ({name}_.type().isa<{type}>()) {{ + {name} = {name}_.type().dyn_cast<{type}>(); + }} else if ({name}_.type().isa<{allocated_type}>()) {{ + {allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>(); + {name} = {type}::get(pir::IrContext::Instance(), + allocated_{name}.dtype(), + allocated_{name}.dims(), + allocated_{name}.data_layout(), + allocated_{name}.lod(), + allocated_{name}.offset()); + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}")); + }} + ir_tensor_{name} = paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}.dtype()), + {name}.dims(), + {name}.data_layout(), + {name}.lod(), + {name}.offset()); + VLOG(4) << "Builder construction meta_{name}"; + meta_{name} = paddle::dialect::IrMetaTensor(&ir_tensor_{name}); + }} + +""" + + CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_tensor_{name}; + for (size_t i=0; i < static_cast({name}.size()); i++) {{ + if({name}[i].isa()) {{ + auto {name}_type = {name}[i].dyn_cast(); + vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), + {name}_type.dims(), + {name}_type.data_layout(), + {name}_type.lod(), + {name}_type.offset())); + }} else if({name}[i].isa()){{ + auto {name}_type = {name}[i].dyn_cast(); + vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), + {name}_type.dims(), + {name}_type.data_layout(), + {name}_type.lod(), + {name}_type.offset())); + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType")); + }} + }} + std::vector vec_meta_{name}; + for (size_t i=0; i < vec_ir_tensor_{name}.size(); i++) {{ + vec_meta_{name}.push_back(paddle::dialect::IrMetaTensor(&vec_ir_tensor_{name}[i])); + }} + + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} + """ + + CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_tensor_{name}; + if ({name}_.impl() != nullptr) {{ + pir::VectorType {name} = {name}_.type().dyn_cast(); + for (size_t i=0; i < static_cast({name}.size()); i++) {{ + if({name}[i].isa()) {{ + auto {name}_type = {name}[i].dyn_cast(); + vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), + {name}_type.dims(), + {name}_type.data_layout(), + {name}_type.lod(), + {name}_type.offset())); + }} else if({name}[i].isa()){{ + auto {name}_type = {name}[i].dyn_cast(); + vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), + {name}_type.dims(), + {name}_type.data_layout(), + {name}_type.lod(), + {name}_type.offset())); + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType")); + }} + }} + }} + + std::vector vec_meta_{name}; + for (size_t i=0; i < vec_ir_tensor_{name}.size(); i++) {{ + vec_meta_{name}.push_back(paddle::dialect::IrMetaTensor(&vec_ir_tensor_{name}[i])); + }} + + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} + +""" + + CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; + if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ + {name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() + ->dyn_cast() + .attribute("value")))); + }} else if ({name}_.type().isa()) {{ + size_t {name}_size = {name}_.type().dyn_cast().size(); + {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name}.SetFromTensor(true); + }} else if ({name}_.type().isa()) {{ + common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); + size_t {name}_size = common::product({name}_dim); + if (common::contain_unknown_dim({name}_dim)) {{ + {name}_size = 1; + }} + {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name}.SetFromTensor(true); + }} else if ({name}_.type().isa()) {{ + common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); + size_t {name}_size = common::product({name}_dim); + if (common::contain_unknown_dim({name}_dim)) {{ + {name}_size = 1; + }} + {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name}.SetFromTensor(true); + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType or AllocatedDenseTensorType")); + }}\n""" + + CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; + if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ + {name} = paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() + ->dyn_cast() + .attribute("value")); + }} else if ({name}_.type().isa()) {{ + size_t {name}_size = {name}_.type().dyn_cast().size(); + {name} = std::vector({name}_size, -1); + }} else if ({name}_.type().isa()) {{ + common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); + size_t {name}_size = common::product({name}_dim); + if (common::contain_unknown_dim({name}_dim)) {{ + {name}_size = 1; + }} + {name} = std::vector({name}_size, -1); + }} else if ({name}_.type().isa()) {{ + common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); + size_t {name}_size = common::product({name}_dim); + if (common::contain_unknown_dim({name}_dim)) {{ + {name}_size = 1; + }} + {name} = std::vector({name}_size, -1); + }} else {{ + PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType or AllocatedDenseTensorType")); + }}\n""" + + CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; + if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ + {name} = std::move(phi::Scalar({name}_.dyn_cast().owner() + ->dyn_cast() + .attribute("value") + .dyn_cast() + .data() + .to())); + }} + else {{ + {name} = std::move(phi::Scalar(-1)); + {name}.SetFromTensor(true); + }}\n""" + + CREATE_OUTPUT_METATENSOR_TEMPLATE = """ paddle::dialect::IrTensor dense_{name}; + paddle::dialect::IrMetaTensor meta_{name}(&dense_{name}); +""" + CREATE_OUTPUT_METASELETEROWS_TEMPLATE = """ paddle::dialect::IrSelectedRows dense_{name}; + paddle::dialect::IrMetaTensor meta_{name}(&dense_{name}); +""" + CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}(({output_size}), paddle::dialect::IrTensor()); + std::vector vec_meta_{name}; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + vec_meta_{name}.push_back(paddle::dialect::IrMetaTensor(&vec_dense_{name}[i])); + }} + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} +""" + build_output_str = "" + # Prepare mutable attributes + if mutable_attr_is_input: + for idx in range(len(op_mutable_attribute_name_list)): + if op_mutable_attribute_name_list[idx] not in inuse_infer_meta_args: + continue + attr_dtype = op_mutable_attribute_type_list[idx] + # int_array + if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": + if ( + op_class_name + in _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE + ): + build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx] + ) + else: + build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx] + ) + # scalar + elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": + build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx], + dtype=attr_dtype[1], + ) + # string + elif attr_dtype[0] == "pir::StrAttribute": + build_output_str += "" + else: + assert "mutable attribtue type is not right." + build_output_str += "\n" + + # Prepare inputs_meta_tensor & attributes for infer meta + infer_meta_args = [] + for idx in range(len(op_infer_meta_map['param'])): + # is input + if op_infer_meta_map['param'][idx] in op_input_name_list: + if ( + "meta_" + op_infer_meta_map['param'][idx] + ) not in infer_meta_args: + # is a vector + if ( + 'pir::VectorType' + in op_input_type_list[ + op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + ] + ): + input_index = op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + if op_input_optional_list[input_index] == 'true': + build_output_str += CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + else: + build_output_str += ( + CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) + # is a Tensor + else: + input_index = op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + if op_input_optional_list[input_index] == 'true': + type = op_input_type_list[idx] + allocated_type = type.replace( + 'DenseTensorType', 'AllocatedDenseTensorType' + ).replace( + "SelectedRowsType", "AllocatedSelectedRowsType" + ) + build_output_str += ( + CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx], + type=op_input_type_list[idx], + allocated_type=allocated_type, + ) + ) + else: + build_output_str += ( + CREATE_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) + + infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) + # is attribute + else: + infer_meta_args.append(op_infer_meta_map['param'][idx]) + + # Prepare outputs_meta_tensor for infer meta + for idx in range(len(op_output_name_list)): + # is a vector + if 'pir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + infer_meta_args.append(f"meta_{op_output_name_list[idx]}") + # is a Tensor + else: + if op_output_type_list[idx] == "paddle::dialect::DenseTensorType": + build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx] + ) + infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") + else: + build_output_str += ( + CREATE_OUTPUT_METASELETEROWS_TEMPLATE.format( + name=op_output_name_list[idx] + ) + ) + infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") + + # Execute infer meta function + CREATE_INFER_META_FUNC_TEMPLATE = """ + phi::{func}({args}); +""" + CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """ + phi::{func}({args}, phi::MetaConfig(false, false)); +""" + if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG: + build_output_str += ( + CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) + ) + else: + build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) + + # use dense_{name} or vec_dense_{name} to create Outputs type + build_output_str += "\n std::vector argument_outputs;" + + CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ + pir::Type {name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); + argument_outputs.push_back({name}_dense_tensor_type); +""" + + CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE = """ + if ({input_name}_.impl() != nullptr) {{ + pir::Type {output_name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{output_name}.dtype()), dense_{output_name}.dims(), dense_{output_name}.layout(), dense_{output_name}.lod(), dense_{output_name}.offset()); + argument_outputs.push_back({output_name}_dense_tensor_type); + }} else {{ + pir::Type {output_name}_type; + argument_outputs.push_back({output_name}_type); + }} + +""" + + CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ + std::vector {name}_types; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + {name}_types.push_back(paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + }} + pir::Type {name}_vector_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); + argument_outputs.push_back({name}_vector_type); +""" + for idx in range(len(op_output_name_list)): + # is a vector + if 'pir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + # is a Tensor + else: + output_name = op_output_name_list[idx] + has_input_inplace = ( + op_inplace_map is not None + and output_name in op_inplace_map.keys() + ) + if op_output_optional_list[idx] == 'true' and has_input_inplace: + # is a inplace optional output + build_output_str += ( + CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE.format( + input_name=op_inplace_map[output_name], + output_name=output_name, + type=op_output_type_list[idx], + ) + ) + else: + build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( + type=op_output_type_list[idx], name=output_name + ) + return build_output_str + + +def GetAttributes( + op_class_name, + muta_attr_is_input, + inuse_infer_meta_args, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + attr_args_is_map, +): + GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ + IR_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + "'{attribute_name}' Attribute is expected for {op_name}. "); + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); +""" + GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ + IR_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + "'{attribute_name}' Attribute is expected for {op_name}. "); + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); +""" + GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + IR_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + "'{attribute_name}' Attribute is expected for {op_name}. "); + {attr_type} {attribute_name}; + for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ + {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); + }} +""" + GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + IR_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + "'{attribute_name}' Attribute is expected for {op_name}. "); + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); +""" + GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + IR_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + "'{attribute_name}' Attribute is expected for {op_name}. "); + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); +""" + + get_attributes_str = "" + array_attr_str = "pir::ArrayAttribute" + + attr_names = [] + attr_types = [] + attr_build_arg_types = [] + if not muta_attr_is_input: + attr_names = op_attribute_name_list + attr_types = op_attribute_type_list + attr_build_arg_types = op_attribute_build_arg_type_list + else: + attr_names = op_non_mutable_attribute_name_list + attr_types = op_non_mutable_attribute_type_list + attr_build_arg_types = op_non_mutable_attribute_build_arg_type_list + if attr_args_is_map: + for idx in range(len(attr_names)): + if attr_names[idx] not in inuse_infer_meta_args: + continue + attr_type = attr_build_arg_types[idx] + attr_type = attr_type.replace("const ", "") + attr_type = attr_type.replace("&", "") + # if attr_build_arg_types[idx] == "const std::vector&": + # attr_type = "std::vector" + + if array_attr_str in attr_types[idx]: + inner_type = attr_types[idx][len(array_attr_str) + 1 : -1] + data_name = "data" + if inner_type == "pir::StrAttribute": + data_name = "AsString" + get_attributes_str += ( + GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, + attr_type=attr_type, + attribute_name=attr_names[idx], + inner_type=inner_type, + data_name=data_name, + ) + ) + elif "paddle::dialect::IntArrayAttribute" in attr_types[idx]: + get_attributes_str += ( + GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, + attr_type=attr_type, + attribute_name=attr_names[idx], + ) + ) + elif "paddle::dialect::ScalarAttribute" in attr_types[idx]: + get_attributes_str += ( + GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, + attr_type=attr_type, + attribute_name=attr_names[idx], + ) + ) + elif "pir::StrAttribute" in attr_types[idx]: + get_attributes_str += ( + GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, + attr_type=attr_type, + attribute_name=attr_names[idx], + attr_ir_type=attr_types[idx], + ) + ) + else: + get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, + attr_type=attr_type, + attribute_name=attr_names[idx], + attr_ir_type=attr_types[idx], + ) + return get_attributes_str + + +def gen_infermeta_func_str( + op_class_name, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_output_optional_list, + op_infer_meta_map, + op_inplace_map, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + muta_attr_is_input=False, + attr_args_is_map=True, +): + inuse_infer_meta_args = [] + for idx in range(len(op_infer_meta_map['param'])): + inuse_infer_meta_args.append(op_infer_meta_map['param'][idx]) + + # Prepare outputs_meta_tensor for infer meta + for idx in range(len(op_output_name_list)): + if op_output_name_list[idx].endswith('_grad'): + inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-5]}") + if op_output_name_list[idx].endswith('_grad_'): + inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-6]}") + inuse_infer_meta_args.append(f"{op_output_name_list[idx]}") + + infermeta_inputs_str = get_infermeta_inputs_str( + inuse_infer_meta_args, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + muta_attr_is_input, + ) + + get_attributes_str = GetAttributes( + op_class_name, + muta_attr_is_input, + inuse_infer_meta_args, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + attr_args_is_map, + ) + + infermeta_outputs_str = GenBuildOutputsPart2( + op_class_name, + inuse_infer_meta_args, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_output_optional_list, + op_infer_meta_map, + op_inplace_map, + muta_attr_is_input, + ) + + infermeta_func = OP_INFERMETA_TEMPLATE.format( + op_name=op_class_name, + infermeta_inputs=infermeta_inputs_str, + get_attributes_str=get_attributes_str, + infermeta_outputs=infermeta_outputs_str, + ) + + return infermeta_func + + +def gen_infermeta_by_invoke_func_str(op_class_name, invoke_class_name): + return OP_INFERMETA_BY_INVOKE_TEMPLATE.format( + op_name=op_class_name, invoke_class=invoke_class_name + ) diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 2a68a1ad43067..08090ba434bcf 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -294,12 +294,14 @@ def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str = "" if op_info.infer_meta_func: exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );" + " static void InferMeta( phi::InferMetaContext *infer_meta );\n" + " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" ) elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: if op_info_items[op_info.invoke_map['func']].infer_meta_func: exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );" + " static void InferMeta( phi::InferMetaContext *infer_meta );\n" + " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" ) if op_info.op_phi_name[0] not in vjp_interface_black_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index fe0f50a456008..4f7497ee97fc5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -23,9 +23,15 @@ class InferMetaInterface : public pir::OpInterfaceBase { public: /// Defined these methods with the interface. struct Concept { - explicit Concept(void (*infer_meta)(phi::InferMetaContext *)) - : infer_meta_(infer_meta) {} + explicit Concept(void (*infer_meta)(phi::InferMetaContext *), + std::vector (*infer_meta_by_value)( + const std::vector &, + const pir::AttributeMap &)) + : infer_meta_(infer_meta), infer_meta_by_value_(infer_meta_by_value) {} + void (*infer_meta_)(phi::InferMetaContext *); + std::vector (*infer_meta_by_value_)( + const std::vector &, const pir::AttributeMap &); }; template @@ -33,8 +39,12 @@ class InferMetaInterface : public pir::OpInterfaceBase { static inline void InferMeta(phi::InferMetaContext *infer_meta) { return ConcreteOp::InferMeta(infer_meta); } - - Model() : Concept(InferMeta) {} + static inline std::vector InferMetaByValue( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + return ConcreteOp::InferMeta(input_values, attributes); + } + Model() : Concept(InferMeta, InferMetaByValue) {} }; /// Constructor @@ -45,6 +55,11 @@ class InferMetaInterface : public pir::OpInterfaceBase { impl_->infer_meta_(infer_meta); } + std::vector InferMeta(const std::vector &input_values, + const pir::AttributeMap &attributes) { + return impl_->infer_meta_by_value_(input_values, attributes); + } + private: Concept *impl_; }; diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index ad2b836c8e220..5c7387dc22e6c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -177,6 +177,71 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector AddNOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta AddNOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value inputs_ = input_values[0]; + + VLOG(4) << "Builder construction outputs"; + pir::VectorType x = inputs_.type().dyn_cast(); + + std::vector vec_dense_x; + for (size_t i = 0; i < x.size(); i++) { + if (x[i].isa()) { + vec_dense_x.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + x[i].dyn_cast().dtype()), + x[i].dyn_cast().dims(), + x[i].dyn_cast().data_layout(), + x[i].dyn_cast().lod(), + x[i].dyn_cast().offset())); + } else if (x[i].isa()) { + vec_dense_x.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + x[i].dyn_cast() + .dtype()), + x[i].dyn_cast().dims(), + x[i].dyn_cast() + .data_layout(), + x[i].dyn_cast().lod(), + x[i].dyn_cast().offset())); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + } + std::vector vec_meta_x; + for (size_t i = 0; i < vec_dense_x.size(); i++) { + vec_meta_x.push_back(paddle::dialect::IrMetaTensor(&vec_dense_x[i])); + } + + std::vector meta_x; + for (size_t i = 0; i < static_cast(vec_meta_x.size()); i++) { + meta_x.push_back(&vec_meta_x[i]); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::AddNInferMeta(meta_x, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + OpInfoTuple AddN_Op::GetOpInfo() { std::vector inputs = { paddle::dialect::OpInputInfo( @@ -300,6 +365,77 @@ void AddN_Op::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector AddN_Op::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta AddN_Op"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value inputs_ = input_values[0]; + + VLOG(4) << "Builder construction outputs"; + pir::VectorType inputs = inputs_.type().dyn_cast(); + std::vector vec_dense_inputs; + for (size_t i = 0; i < static_cast(inputs.size()); i++) { + if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + paddle::dialect::TransToPhiDataType( + inputs[i].dyn_cast().dtype()), + inputs[i].dyn_cast().dims(), + inputs[i].dyn_cast().data_layout(), + inputs[i].dyn_cast().lod(), + inputs[i].dyn_cast().offset())); + } else if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + inputs[i] + .dyn_cast() + .dims(), + inputs[i] + .dyn_cast() + .data_layout(), + inputs[i].dyn_cast().lod(), + inputs[i] + .dyn_cast() + .offset())); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + } + + std::vector vec_meta_inputs; + for (size_t i = 0; i < vec_dense_inputs.size(); i++) { + vec_meta_inputs.push_back( + paddle::dialect::IrMetaTensor(&vec_dense_inputs[i])); + } + + std::vector meta_inputs; + for (size_t i = 0; i < static_cast(vec_meta_inputs.size()); i++) { + meta_inputs.push_back(&vec_meta_inputs[i]); + } + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::AddNInferMeta(meta_inputs, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + OpInfoTuple AddNWithKernelOp::GetOpInfo() { std::vector inputs = { paddle::dialect::OpInputInfo( @@ -425,6 +561,103 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector AddNWithKernelOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta AddNWithKernelOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value inputs_ = input_values[0]; + + VLOG(4) << "Builder construction outputs"; + pir::VectorType inputs = inputs_.type().dyn_cast(); + std::vector vec_dense_inputs; + for (size_t i = 0; i < static_cast(inputs.size()); i++) { + if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + paddle::dialect::TransToPhiDataType( + inputs[i].dyn_cast().dtype()), + inputs[i].dyn_cast().dims(), + inputs[i].dyn_cast().data_layout(), + inputs[i].dyn_cast().lod(), + inputs[i].dyn_cast().offset())); + } else if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + inputs[i] + .dyn_cast() + .dims(), + inputs[i] + .dyn_cast() + .data_layout(), + inputs[i].dyn_cast().lod(), + inputs[i] + .dyn_cast() + .offset())); + } else if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + paddle::dialect::TransToPhiDataType( + inputs[i].dyn_cast().dtype()), + inputs[i].dyn_cast().dims(), + inputs[i].dyn_cast().data_layout(), + inputs[i].dyn_cast().lod(), + inputs[i].dyn_cast().offset())); + } else if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + inputs[i] + .dyn_cast() + .dims(), + inputs[i] + .dyn_cast() + .data_layout(), + inputs[i] + .dyn_cast() + .lod(), + inputs[i] + .dyn_cast() + .offset())); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support DenseTensorType or AllocatedDenseTensorType or " + "SelectedRowsType or AllocatedSelectedRowsType")); + } + } + + std::vector vec_meta_inputs; + for (size_t i = 0; i < vec_dense_inputs.size(); i++) { + vec_meta_inputs.push_back( + paddle::dialect::IrMetaTensor(&vec_dense_inputs[i])); + } + + std::vector meta_inputs; + for (size_t i = 0; i < static_cast(vec_meta_inputs.size()); i++) { + meta_inputs.push_back(&vec_meta_inputs[i]); + } + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::AddNInferMeta(meta_inputs, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + OpInfoTuple AddNArrayOp::GetOpInfo() { std::vector inputs = { OpInputInfo("inputs", @@ -555,6 +788,75 @@ void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector AddNArrayOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta AddNArrayOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value inputs_ = input_values[0]; + VLOG(4) << "Builder construction outputs"; + pir::VectorType inputs = inputs_.type().dyn_cast(); + + std::vector vec_dense_inputs; + for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i].isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + {}, + inputs[i] + .dyn_cast() + .data_layout(), + {})); + } else if (inputs[i] + .isa()) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + {}, + inputs[i] + .dyn_cast() + .data_layout(), + {})); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + } + + std::vector vec_meta_inputs; + for (size_t i = 0; i < vec_dense_inputs.size(); i++) { + vec_meta_inputs.push_back( + paddle::dialect::IrMetaTensor(&vec_dense_inputs[i])); + } + + std::vector meta_inputs; + for (size_t i = 0; i < static_cast(vec_meta_inputs.size()); i++) { + meta_inputs.push_back(&vec_meta_inputs[i]); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::AddNTensorArrayInferMeta( + meta_inputs, &meta_out, phi::MetaConfig(false, false)); + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + const char *FusedGemmEpilogueOp::attributes_name[3] = { "trans_x", "trans_y", "activation"}; @@ -786,6 +1088,165 @@ void FusedGemmEpilogueOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector FusedGemmEpilogueOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta FusedGemmEpilogueOp"; + IR_ENFORCE(input_values.size() == 3, + "Num of inputs is expected to be 3 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + pir::Value y_ = input_values[1]; + pir::Value bias_ = input_values[2]; + + PADDLE_ENFORCE( + attributes.find("trans_x") != attributes.end(), + phi::errors::NotFound( + "'trans_x' Attribute is expected for FusedGemmEpilogueOp")); + bool trans_x = attributes.at("trans_x").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("trans_y") != attributes.end(), + phi::errors::NotFound( + "'trans_y' Attribute is expected for FusedGemmEpilogueOp")); + bool trans_y = attributes.at("trans_y").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("activation") != attributes.end(), + phi::errors::NotFound( + "'activation' Attribute is expected for FusedGemmEpilogueOp")); + std::string activation = + attributes.at("activation").dyn_cast().AsString(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_x = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_x.dtype(), + allocated_x.dims(), + allocated_x.data_layout(), + allocated_x.lod(), + allocated_x.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::DenseTensorType y; + if (y_.type().isa()) { + y = y_.type().dyn_cast(); + (void)y; + } else if (y_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_y = + y_.type().dyn_cast(); + y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_y.dtype(), + allocated_y.dims(), + allocated_y.data_layout(), + allocated_y.lod(), + allocated_y.offset()); + (void)y; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::DenseTensorType bias; + if (bias_.type().isa()) { + bias = bias_.type().dyn_cast(); + (void)bias; + } else if (bias_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_bias = + bias_.type().dyn_cast(); + bias = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_bias.dtype(), + allocated_bias.dims(), + allocated_bias.data_layout(), + allocated_bias.lod(), + allocated_bias.offset()); + (void)bias; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + VLOG(4) << "Builder construction dense_y"; + paddle::dialect::IrTensor dense_y( + paddle::dialect::TransToPhiDataType(y.dtype()), + y.dims(), + y.data_layout(), + y.lod(), + y.offset()); + VLOG(4) << "Builder construction meta_y"; + paddle::dialect::IrMetaTensor meta_y(&dense_y); + + VLOG(4) << "Builder construction dense_bias"; + paddle::dialect::IrTensor dense_bias( + paddle::dialect::TransToPhiDataType(bias.dtype()), + bias.dims(), + bias.data_layout(), + bias.lod(), + bias.offset()); + VLOG(4) << "Builder construction meta_bias"; + paddle::dialect::IrMetaTensor meta_bias(&dense_bias); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + paddle::dialect::IrTensor dense_reserve_space; + paddle::dialect::IrMetaTensor meta_reserve_space(&dense_reserve_space); + + phi::FusedGemmEpilogueInferMeta( + meta_x, + meta_y, + meta_bias, + trans_x, + trans_y, + activation, + &meta_out, + activation == "none" ? nullptr : &meta_reserve_space); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + + pir::Type reserve_space_dense_tensor_type = + activation == "none" + ? pir::Type() + : paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_reserve_space.dtype()), + dense_reserve_space.dims(), + dense_reserve_space.layout(), + dense_reserve_space.lod(), + dense_reserve_space.offset()); + argument_outputs.push_back(reserve_space_dense_tensor_type); + return argument_outputs; +} + const char *FusedGemmEpilogueGradOp::attributes_name[3] = { "trans_x", "trans_y", "activation_grad"}; @@ -999,24 +1460,238 @@ void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } -const char *SplitGradOp::attributes_name[1] = {"axis"}; +std::vector FusedGemmEpilogueGradOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + IR_ENFORCE(input_values.size() == 4, + "Num of inputs is expected to be 4 but got %d.", + input_values.size()); + + pir::Value x_ = input_values[0]; + pir::Value y_ = input_values[1]; + pir::Value reserve_space_ = input_values[2]; + pir::Value out_grad_ = input_values[3]; + VLOG(4) << "Start build FusedGemmEpilogueGradOp"; -OpInfoTuple SplitGradOp::GetOpInfo() { - std::vector inputs = { - OpInputInfo("out_grad", - "pir::VectorType", - false, - false, - false, - true), - OpInputInfo("axis", - "paddle::dialect::ScalarAttribute", - false, - false, - true, - false)}; - std::vector attributes = {}; - std::vector outputs = { + PADDLE_ENFORCE( + attributes.find("trans_x") != attributes.end(), + phi::errors::NotFound( + "'trans_x' Attribute is expected for FusedGemmEpilogueGradOp")); + bool trans_x = attributes.at("trans_x").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("trans_y") != attributes.end(), + phi::errors::NotFound( + "'trans_y' Attribute is expected for FusedGemmEpilogueGradOp")); + bool trans_y = attributes.at("trans_y").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("activation_grad") != attributes.end(), + phi::errors::NotFound("'activation_grad' Attribute is expected for" + "FusedGemmEpilogueGradOp")); + std::string activation_grad = + attributes.at("activation_grad").dyn_cast().AsString(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_x = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_x.dtype(), + allocated_x.dims(), + allocated_x.data_layout(), + allocated_x.lod(), + allocated_x.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::DenseTensorType y; + if (y_.type().isa()) { + y = y_.type().dyn_cast(); + (void)y; + } else if (y_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_y = + y_.type().dyn_cast(); + y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_y.dtype(), + allocated_y.dims(), + allocated_y.data_layout(), + allocated_y.lod(), + allocated_y.offset()); + (void)y; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::DenseTensorType reserve_space; + if (reserve_space_) { + if (reserve_space_.type().isa()) { + reserve_space = + reserve_space_.type().dyn_cast(); + (void)reserve_space; + } else if (reserve_space_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_reserve_space = + reserve_space_.type() + .dyn_cast(); + reserve_space = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + allocated_reserve_space.dtype(), + allocated_reserve_space.dims(), + allocated_reserve_space.data_layout(), + allocated_reserve_space.lod(), + allocated_reserve_space.offset()); + (void)reserve_space; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + } else { + reserve_space = paddle::dialect::DenseTensorType(); + (void)reserve_space; + } + + paddle::dialect::DenseTensorType out_grad; + if (out_grad_.type().isa()) { + out_grad = out_grad_.type().dyn_cast(); + (void)out_grad; + } else if (out_grad_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_out_grad = + out_grad_.type().dyn_cast(); + out_grad = + paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_out_grad.dtype(), + allocated_out_grad.dims(), + allocated_out_grad.data_layout(), + allocated_out_grad.lod(), + allocated_out_grad.offset()); + (void)out_grad; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + VLOG(4) << "Builder construction dense_y"; + paddle::dialect::IrTensor dense_y( + paddle::dialect::TransToPhiDataType(y.dtype()), + y.dims(), + y.data_layout(), + y.lod(), + y.offset()); + VLOG(4) << "Builder construction meta_y"; + paddle::dialect::IrMetaTensor meta_y(&dense_y); + + VLOG(4) << "Builder construction dense_reserve_space"; + std::unique_ptr dense_reserve_space = + reserve_space_ + ? std::make_unique( + paddle::dialect::TransToPhiDataType(reserve_space.dtype()), + reserve_space.dims(), + reserve_space.data_layout(), + reserve_space.lod(), + reserve_space.offset()) + : nullptr; + VLOG(4) << "Builder construction meta_reserve_space"; + paddle::dialect::IrMetaTensor meta_reserve_space(dense_reserve_space.get()); + + VLOG(4) << "Builder construction dense_out_grad"; + paddle::dialect::IrTensor dense_out_grad( + paddle::dialect::TransToPhiDataType(out_grad.dtype()), + out_grad.dims(), + out_grad.data_layout(), + out_grad.lod(), + out_grad.offset()); + VLOG(4) << "Builder construction meta_out_grad"; + paddle::dialect::IrMetaTensor meta_out_grad(&dense_out_grad); + paddle::dialect::IrTensor dense_x_grad; + paddle::dialect::IrMetaTensor meta_x_grad(&dense_x_grad); + paddle::dialect::IrTensor dense_y_grad; + paddle::dialect::IrMetaTensor meta_y_grad(&dense_y_grad); + paddle::dialect::IrTensor dense_bias_grad; + paddle::dialect::IrMetaTensor meta_bias_grad(&dense_bias_grad); + + phi::FusedGemmEpilogueGradInferMeta(meta_x, + meta_y, + meta_reserve_space, + meta_out_grad, + trans_x, + trans_y, + activation_grad, + &meta_x_grad, + &meta_y_grad, + &meta_bias_grad); + + std::vector argument_outputs; + pir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.dims(), + dense_x_grad.layout(), + dense_x_grad.lod(), + dense_x_grad.offset()); + argument_outputs.push_back(x_grad_dense_tensor_type); + + pir::Type y_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_y_grad.dtype()), + dense_y_grad.dims(), + dense_y_grad.layout(), + dense_y_grad.lod(), + dense_y_grad.offset()); + argument_outputs.push_back(y_grad_dense_tensor_type); + + pir::Type bias_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_bias_grad.dtype()), + dense_bias_grad.dims(), + dense_bias_grad.layout(), + dense_bias_grad.lod(), + dense_bias_grad.offset()); + argument_outputs.push_back(bias_grad_dense_tensor_type); + return argument_outputs; +} + +const char *SplitGradOp::attributes_name[1] = {"axis"}; + +OpInfoTuple SplitGradOp::GetOpInfo() { + std::vector inputs = { + OpInputInfo("out_grad", + "pir::VectorType", + false, + false, + false, + true), + OpInputInfo("axis", + "paddle::dialect::ScalarAttribute", + false, + false, + true, + false)}; + std::vector attributes = {}; + std::vector outputs = { OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)}; paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("ConcatInferMeta", @@ -1203,6 +1878,63 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector SplitGradOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta SplitGradOp"; + + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value out_grad_ = input_values[0]; + pir::Value axis_ = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + pir::VectorType out_grad = out_grad_.type().dyn_cast(); + int axis = axis_.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value") + .data() + .to(); + + std::vector vec_dense_out_grad; + for (size_t i = 0; i < static_cast(out_grad.size()); i++) { + vec_dense_out_grad.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + out_grad[i].dyn_cast().dtype()), + out_grad[i].dyn_cast().dims(), + out_grad[i].dyn_cast().data_layout(), + out_grad[i].dyn_cast().lod(), + out_grad[i].dyn_cast().offset())); + } + std::vector vec_meta_out_grad; + for (size_t i = 0; i < vec_dense_out_grad.size(); i++) { + vec_meta_out_grad.push_back( + paddle::dialect::IrMetaTensor(&vec_dense_out_grad[i])); + } + + std::vector meta_out_grad; + for (size_t i = 0; i < static_cast(vec_meta_out_grad.size()); i++) { + meta_out_grad.push_back(&vec_meta_out_grad[i]); + } + paddle::dialect::IrTensor dense_x_grad; + paddle::dialect::IrMetaTensor meta_x_grad(&dense_x_grad); + + phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad); + + std::vector argument_outputs; + pir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.dims(), + dense_x_grad.layout(), + dense_x_grad.lod(), + dense_x_grad.offset()); + argument_outputs.push_back(x_grad_dense_tensor_type); + return argument_outputs; +} + const char *CreateArrayOp::attributes_name[1] = {"dtype"}; OpInfoTuple CreateArrayOp::GetOpInfo() { @@ -1293,6 +2025,32 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector CreateArrayOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta CreateArrayOp"; + + PADDLE_ENFORCE( + attributes.find("dtype") != attributes.end(), + phi::errors::NotFound("'dtype' Attribute is expected for CreateArrayOp")); + phi::DataType dtype = attributes.at("dtype") + .dyn_cast() + .data(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::CreateArrayInferMeta(dtype, &meta_out); + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + const char *CreateArrayLikeOp::attributes_name[1] = {"val"}; OpInfoTuple CreateArrayLikeOp::GetOpInfo() { @@ -1359,6 +2117,7 @@ void CreateArrayLikeOp::Build(pir::Builder &builder, // NOLINT paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.layout()); argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -1401,6 +2160,60 @@ void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector CreateArrayLikeOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta CreateArrayLikeOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value input_ = input_values[0]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type; + if (input_.type().isa()) { + input_type = + input_.type().dyn_cast(); + (void)input_type; + } else if (input_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + input_.type() + .dyn_cast(); + input_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)input_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::CreateArrayLikeInferMeta(meta_input, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + argument_outputs.push_back(out_dense_tensor_type); + + return argument_outputs; +} + OpInfoTuple ArrayLengthOp::GetOpInfo() { std::vector inputs = { OpInputInfo("x", @@ -1498,6 +2311,57 @@ void ArrayLengthOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector ArrayLengthOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ArrayLengthOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + + paddle::dialect::DenseTensorArrayType x_type; + if (x_.type().isa()) { + x_type = x_.type().dyn_cast(); + (void)x_type; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + x_.type().dyn_cast(); + x_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)x_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x_type.dtype()), + {}, + x_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ArrayLengthInferMeta(meta_x, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + OpInfoTuple ArrayReadOp::GetOpInfo() { std::vector inputs = { OpInputInfo("array", @@ -1658,12 +2522,83 @@ void ArrayReadOp::VerifySig() { phi::errors::PreconditionNotMet( "Type validation failed for the 0th output.")); } - VLOG(4) << "End Verifying for: ArrayWrite_Op."; -} + VLOG(4) << "End Verifying for: ArrayWrite_Op."; +} + +void ArrayReadOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::ArrayReadInferMeta); + fn(infer_meta); +} + +std::vector ArrayReadOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ArrayLengthOp"; + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value array_ = input_values[0]; + pir::Value i_ = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType array_type; + if (array_.type().isa()) { + array_type = + array_.type().dyn_cast(); + (void)array_type; + } else if (array_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + array_.type() + .dyn_cast(); + array_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)array_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + paddle::dialect::IrTensor dense_array( + paddle::dialect::TransToPhiDataType(array_type.dtype()), + {}, + array_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_array(&dense_array); + + phi::Scalar i_scalar; + if (i_.dyn_cast() && + i_.dyn_cast().owner()->isa()) { + i_scalar = + std::move(phi::Scalar(i_.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value") + .dyn_cast() + .data() + .to())); + } else { + i_scalar = std::move(phi::Scalar(-1)); + i_scalar.SetFromTensor(true); + } -void ArrayReadOp::InferMeta(phi::InferMetaContext *infer_meta) { - auto fn = PD_INFER_META(phi::ArrayReadInferMeta); - fn(infer_meta); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ArrayReadInferMeta( + meta_array, i_scalar, &meta_out, phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod()); + argument_outputs.push_back(out_type); + return argument_outputs; } OpInfoTuple ArrayWrite_Op::GetOpInfo() { @@ -1796,6 +2731,88 @@ void ArrayWrite_Op::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector ArrayWrite_Op::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ArrayWrite_Op"; + IR_ENFORCE(input_values.size() == 3, + "Num of inputs is expected to be 3 but got %d.", + input_values.size()); + pir::Value array_ = input_values[0]; + pir::Value x_ = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType array_type; + if (array_.type().isa()) { + array_type = + array_.type().dyn_cast(); + (void)array_type; + } else if (array_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + array_.type() + .dyn_cast(); + array_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)array_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + + paddle::dialect::IrTensor dense_array( + paddle::dialect::TransToPhiDataType(array_type.dtype()), + {}, + array_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_array(&dense_array); + + paddle::dialect::DenseTensorType x_type; + if (x_.type().isa()) { + x_type = x_.type().dyn_cast(); + (void)x_type; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_input = + x_.type().dyn_cast(); + x_type = + paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.dims(), + allocated_input.data_layout(), + allocated_input.lod(), + allocated_input.offset()); + (void)x_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x_type.dtype()), + x_type.dims(), + x_type.data_layout(), + x_type.lod(), + x_type.offset()); + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ArrayWriteInferMeta( + meta_array, meta_x, &meta_out, phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + argument_outputs.push_back(out_type); + return argument_outputs; +} + const char *ArrayToTensorOp::attributes_name[2] = {"axis", "use_stack"}; OpInfoTuple ArrayToTensorOp::GetOpInfo() { @@ -1944,6 +2961,83 @@ void ArrayToTensorOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector ArrayToTensorOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ArrayToTensorOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + + VLOG(4) << "Builder construction attributes"; + IR_ENFORCE(attributes.find("axis") != attributes.end(), + "'value' Attribute is expected for IncrementOp. "); + int32_t axis = attributes.at("axis").dyn_cast().data(); + + IR_ENFORCE(attributes.find("use_stack") != attributes.end(), + "'value' Attribute is expected for IncrementOp. "); + bool use_stack = + attributes.at("use_stack").dyn_cast().data(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType x_type; + if (x_.type().isa()) { + x_type = x_.type().dyn_cast(); + (void)x_type; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + x_.type().dyn_cast(); + x_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)x_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x_type.dtype()), + {}, + x_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + paddle::dialect::IrTensor dense_out_index; + paddle::dialect::IrMetaTensor meta_out_index(&dense_out_index); + + phi::ArrayToTensorInferMeta(meta_x, + axis, + use_stack, + &meta_out, + &meta_out_index, + phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + pir::Type out_index_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out_index.dtype()), + dense_out_index.dims(), + dense_out_index.layout(), + dense_out_index.lod(), + dense_out_index.offset()); + argument_outputs.push_back(out_index_dense_tensor_type); + return argument_outputs; +} + const char *TensorToArrayOp::attributes_name[2] = {"axis", "use_stack"}; OpInfoTuple TensorToArrayOp::GetOpInfo() { @@ -2091,6 +3185,99 @@ void TensorToArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector TensorToArrayOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta TensorToArrayOp"; + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + pir::Value out_grad_ = input_values[1]; + + VLOG(4) << "Builder construction attributes"; + + IR_ENFORCE(attributes.find("axis") != attributes.end(), + "'value' Attribute is expected for IncrementOp. "); + int32_t axis = attributes.at("axis").dyn_cast().data(); + + IR_ENFORCE(attributes.find("use_stack") != attributes.end(), + "'value' Attribute is expected for IncrementOp. "); + bool use_stack = + attributes.at("use_stack").dyn_cast().data(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType x; + + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + + paddle::dialect::IrTensor dense_x( + paddle::dialect::TransToPhiDataType(x.dtype()), {}, x.data_layout(), {}); + + paddle::dialect::DenseTensorType out_grad; + if (out_grad_.type().isa()) { + out_grad = out_grad_.type().dyn_cast(); + (void)out_grad; + } else if (out_grad_.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_input = + out_grad_.type().dyn_cast(); + out_grad = + paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.dims(), + allocated_input.data_layout(), + allocated_input.lod(), + allocated_input.offset()); + (void)out_grad; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::IrTensor dense_out_grad( + paddle::dialect::TransToPhiDataType(out_grad.dtype()), + out_grad.dims(), + out_grad.data_layout(), + out_grad.lod(), + out_grad.offset()); + + VLOG(4) << "Builder construction meta_x, meta_out_grad"; + paddle::dialect::IrMetaTensor meta_out_grad(&dense_out_grad); + paddle::dialect::IrMetaTensor meta_x(&dense_x); + + paddle::dialect::IrTensor dense_x_grad; + paddle::dialect::IrMetaTensor meta_x_grad(&dense_x_grad); + + phi::TensorToArrayInferMeta( + meta_x, meta_out_grad, axis, use_stack, &meta_x_grad); + + std::vector argument_outputs; + pir::Type out_dense_tensor_array_type = + paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.layout()); + argument_outputs.push_back(out_dense_tensor_array_type); + return argument_outputs; +} + const char *SliceArrayOp::attributes_name[2] = {"starts", "ends"}; OpInfoTuple SliceArrayOp::GetOpInfo() { @@ -2171,6 +3358,81 @@ void SliceArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector SliceArrayOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta SliceArrayOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value input = input_values[0]; + + IR_ENFORCE(attributes.count("starts") > 0, "starts does not exist."); + IR_ENFORCE( + attributes.at("starts").isa(), + "Type of attribute: starts is not paddle::dialect::IntArrayAttribute."); + + IR_ENFORCE(attributes.count("ends") > 0, "ends does not exist."); + IR_ENFORCE( + attributes.at("ends").isa(), + "Type of attribute: ends is not paddle::dialect::IntArrayAttribute."); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type; + if (input.type().isa()) { + input_type = input.type().dyn_cast(); + (void)input_type; + } else if (input.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + input.type().dyn_cast(); + input_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)input_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::AllocatedDenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + phi::IntArray starts_list = + attributes.at("starts") + .dyn_cast() + .data(); + phi::IntArray ends_list = attributes.at("ends") + .dyn_cast() + .data(); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::SliceArrayInferMeta(meta_input, + starts_list, + ends_list, + &meta_out, + phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType SliceArrayOp::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, @@ -2292,8 +3554,110 @@ void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT std::move(phi::IntArray(std::vector(starts_size, -1))); starts_list.SetFromTensor(true); } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Only support VectorType or DenseTensorType")); + PADDLE_THROW(phi::errors::Unimplemented( + "Only support VectorType or DenseTensorType")); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::SliceArrayDenseInferMeta( + meta_input, starts_list, &meta_out, phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::SliceArrayDenseInferMeta); + fn(infer_meta); +} + +std::vector SliceArrayDenseOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta SliceArrayDenseOp"; + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value input = input_values[0]; + pir::Value starts = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type; + if (input.type().isa()) { + input_type = input.type().dyn_cast(); + (void)input_type; + } else if (input.type() + .isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + input.type().dyn_cast(); + input_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)input_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + phi::IntArray starts_list; + if (starts.dyn_cast() + .owner() + ->isa()) { + starts_list = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + starts.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value")))); + } else if (starts.type().isa()) { + size_t starts_size = starts.type().dyn_cast().size(); + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else if (starts.type().isa()) { + common::DDim starts_dim = + starts.type().dyn_cast().dims(); + size_t starts_size = common::product(starts_dim); + if (common::contain_unknown_dim(starts_dim)) { + starts_size = 1; + } + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else if (starts.type().isa()) { + common::DDim starts_dim = + starts.type() + .dyn_cast() + .dims(); + size_t starts_size = common::product(starts_dim); + if (common::contain_unknown_dim(starts_dim)) { + starts_size = 1; + } + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Only support VectorType or DenseTensorType " + "or AllocatedDenseTensorType")); } paddle::dialect::IrTensor dense_out; @@ -2311,13 +3675,7 @@ void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT dense_out.lod(), dense_out.offset()); argument_outputs.push_back(out_dense_tensor_type); - argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); - ::pir::PassStopGradientsDefaultly(argument); -} - -void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { - auto fn = PD_INFER_META(phi::SliceArrayDenseInferMeta); - fn(infer_meta); + return argument_outputs; } phi::DataType SliceArrayDenseOp::GetKernelTypeForVar( @@ -2389,6 +3747,57 @@ void AssignArray_Op::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector AssignArray_Op::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta AssignArray_Op"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType x_type; + if (x_.type().isa()) { + x_type = x_.type().dyn_cast(); + (void)x_type; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + x_.type().dyn_cast(); + x_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)x_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(x_type.dtype()), + {}, + x_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::UnchangedArrayInferMeta(meta_input, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType AssignArray_Op::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, @@ -2651,6 +4060,90 @@ void ExpandOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector ExpandOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ExpandOp"; + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + pir::Value shape_ = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_input = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.dims(), + allocated_input.data_layout(), + allocated_input.lod(), + allocated_input.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + phi::IntArray shape; + if (shape_.dyn_cast() + .owner() + ->isa()) { + shape = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + shape_.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value")))); + } else if (shape_.type().isa()) { + size_t shape_size = shape_.type().dyn_cast().size(); + // In ExpandInferMeta use -2 to represent the element in expand_shape is a + // var. + shape = std::move(phi::IntArray(std::vector(shape_size, -2))); + shape.SetFromTensor(true); + } else if (shape_.type().isa()) { + size_t shape_size = common::product( + shape_.type().dyn_cast().dims()); + // In ExpandInferMeta use -2 to represent the element in expand_shape is a + // var. + shape = std::move(phi::IntArray(std::vector(shape_size, -2))); + shape.SetFromTensor(true); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support VectorType or DenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_meta_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_meta_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ExpandInferMeta(meta_x, shape, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType ExpandOp::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, @@ -2898,6 +4391,66 @@ void IncrementOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector IncrementOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta IncrementOp"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for IncrementOp. "); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_input = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.dims(), + allocated_input.data_layout(), + allocated_input.lod(), + allocated_input.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType IncrementOp::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, @@ -3070,6 +4623,66 @@ void Increment_Op::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector Increment_Op::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta Increment_Op"; + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for Increment_Op. "); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_input = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.dims(), + allocated_input.data_layout(), + allocated_input.lod(), + allocated_input.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType Increment_Op::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, @@ -3136,6 +4749,128 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, namespace { +void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x, + const phi::MetaTensor &y, + phi::MetaTensor *out) { + PADDLE_ENFORCE_EQ( + x.dims().size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of x.dims() must be equal to 1.", x.dims().size())); + PADDLE_ENFORCE_EQ( + y.dims().size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of y.dims() must be equal to 1.", y.dims().size())); + out->set_dims({std::max(x.dims().at(0), y.dims().at(0))}); + // dtype need promote when meet input dtype with more precision + paddle::experimental::DataTypeSet dtype_set{x.dtype()}; + dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype()); + DataType promote_result = PromoteTypes(dtype_set); + if (promote_result == DataType::UNDEFINED) { + promote_result = x.dtype(); + } + out->set_dtype(promote_result); + out->set_layout(x.layout()); + out->share_lod(x); +} + +} // namespace + +void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta); + fn(infer_meta); +} + +std::vector ShapeBroadcastOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta ShapeBroadcastOp"; + IR_ENFORCE(input_values.size() == 2, + "Num of inputs is expected to be 2 but got %d.", + input_values.size()); + pir::Value x_ = input_values[0]; + pir::Value y_ = input_values[1]; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x; + if (x_.type().isa()) { + x = x_.type().dyn_cast(); + (void)x; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_x = + x_.type().dyn_cast(); + x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_x.dtype(), + allocated_x.dims(), + allocated_x.data_layout(), + allocated_x.lod(), + allocated_x.offset()); + (void)x; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + paddle::dialect::DenseTensorType y; + if (y_.type().isa()) { + y = y_.type().dyn_cast(); + (void)y; + } else if (y_.type().isa()) { + paddle::dialect::AllocatedDenseTensorType allocated_x = + y_.type().dyn_cast(); + y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), + allocated_x.dtype(), + allocated_x.dims(), + allocated_x.data_layout(), + allocated_x.lod(), + allocated_x.offset()); + (void)y; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType or " + "paddle::dialect::AllocatedDenseTensorType")); + } + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + + VLOG(4) << "Builder construction dense_y"; + paddle::dialect::IrTensor ir_tensor_y( + paddle::dialect::TransToPhiDataType(y.dtype()), + y.dims(), + y.data_layout(), + y.lod(), + y.offset()); + VLOG(4) << "Builder construction meta_y"; + paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + +namespace { + symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs, const symbol::DimExpr &rhs) { if (lhs.isa() && rhs.isa()) { @@ -3253,6 +4988,57 @@ void MemcpyD2hMultiIoOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +std::vector MemcpyD2hMultiIoOp::InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + IR_ENFORCE(input_values.size() == 1, + "Num of inputs is expected to be 1 but got %d.", + input_values.size()); + + pir::Value x_ = input_values[0]; + (void)x_; + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType x_type; + if (x_.type().isa()) { + x_type = x_.type().dyn_cast(); + (void)x_type; + } else if (x_.type().isa()) { + paddle::dialect::AllocatedDenseTensorArrayType allocated_input = + x_.type().dyn_cast(); + x_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + allocated_input.dtype(), + allocated_input.data_layout()); + (void)x_type; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorArrayType or " + "paddle::dialect::AllocatedDenseTensorArrayType")); + } + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(x_type.dtype()), + {}, + x_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::UnchangedArrayInferMeta(meta_input, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + return argument_outputs; +} + phi::DataType MemcpyD2hMultiIoOp::GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 00d2203ca747b..60766e45842cd 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -53,6 +53,10 @@ class AddNOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); + static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -81,6 +85,9 @@ class AddN_Op : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class AddNWithKernelOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class AddNArrayOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class FusedGemmEpilogueOp @@ -148,6 +161,9 @@ class FusedGemmEpilogueOp pir::OpResult reserve_space() { return result(1); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class FusedGemmEpilogueGradOp @@ -178,6 +194,9 @@ class FusedGemmEpilogueGradOp pir::OpResult bias_grad() { return result(2); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class SplitGradOp : public pir::Op { @@ -201,6 +220,9 @@ class SplitGradOp : public pir::Op { pir::Value axis() { return operand_source(1); } pir::OpResult x_grad() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class CreateArrayOp @@ -217,6 +239,9 @@ class CreateArrayOp void VerifySig(); pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class CreateArrayLikeOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class ArrayLengthOp @@ -253,6 +281,9 @@ class ArrayLengthOp pir::Value x() { return operand_source(0); } pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class ArrayReadOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -308,6 +342,9 @@ class ArrayWrite_Op : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -336,6 +373,9 @@ class ArrayToTensorOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -363,6 +403,9 @@ class TensorToArrayOp pir::Value out_grad() { return operand_source(1); } pir::OpResult x_grad() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class SliceArrayOp @@ -388,6 +431,9 @@ class SliceArrayOp pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class SliceArrayDenseOp @@ -417,6 +463,9 @@ class SliceArrayDenseOp pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class AssignArray_Op @@ -442,6 +491,9 @@ class AssignArray_Op pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class ExpandOp : public pir::Op InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -534,6 +589,9 @@ class IncrementOp pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -576,6 +634,9 @@ class Increment_Op pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -616,11 +677,15 @@ class MemcpyD2hMultiIoOp pir::OpResult out() { return result(0); } static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); }; class IR_API ShapeBroadcastOp : public pir::Op { + paddle::dialect::InferSymbolicShapeInterface, + paddle::dialect::InferMetaInterface> { public: using Op::Op; static const char *name() { return "pd_op.shape_broadcast"; } @@ -637,6 +702,11 @@ class IR_API ShapeBroadcastOp pir::Value y() { return operand_source(1); } pir::OpResult out() { return result(0); } + static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes); + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 7d96d5096ea31..9a2a9ea0957c1 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -18,6 +18,7 @@ #include "paddle/common/errors.h" #include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" @@ -281,7 +282,44 @@ std::set GetRegisterDataType(const std::string& op_name) { return data_type; } +std::string GetValueDataType(const pir::Type& type) { + if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast().dtype())); + } else if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast().dtype())); + } else if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast().dtype())); + } else if (type.isa()) { + auto vec_value = type.dyn_cast(); + if (vec_value.size() > 0) { + return GetValueDataType(vec_value[0]); + } else { + return ""; + } + } else if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast().dtype())); + } else if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast().dtype())); + } else if (type.isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + type.dyn_cast() + .dtype())); + } else { + PADDLE_THROW( + phi::errors::InvalidType("Currently, we can only get dtype for " + "DenseTensorType and SelectedRowsType.")); + } +} + std::string GetValueDataType(const pir::Value& value) { + if (value.impl() == nullptr) { + return ""; + } if (value.type().isa()) { return phi::DataTypeToString(dialect::TransToPhiDataType( value.type().dyn_cast().dtype())); @@ -291,6 +329,29 @@ std::string GetValueDataType(const pir::Value& value) { } else if (value.type().isa()) { return phi::DataTypeToString(dialect::TransToPhiDataType( value.type().dyn_cast().dtype())); + } else if (value.type().isa()) { + auto vec_value = value.type().dyn_cast(); + if (vec_value.size() > 0) { + return GetValueDataType(vec_value[0]); + } else { + return ""; + } + } else if (value.type().isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + value.type() + .dyn_cast() + .dtype())); + } else if (value.type().isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + value.type() + .dyn_cast() + .dtype())); + } else if (value.type() + .isa()) { + return phi::DataTypeToString(dialect::TransToPhiDataType( + value.type() + .dyn_cast() + .dtype())); } else { PADDLE_THROW( phi::errors::InvalidType("Currently, we can only get dtype for " diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 42a9758e9ca6a..3a88f053e9284 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -159,5 +159,8 @@ void CheckDataTypeOrValue(const phi::DataType& dtype, const pir::Value& value, const std::string& value_name, const std::string& op_name); + +std::string GetValueDataType(const pir::Value& value); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 66b935d1001c4..a905cabff7037 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -84,6 +84,22 @@ std::vector GetValueShape(const pir::Value& value) { } } +static const std::vector InferMetaByValue( + pir::Operation* op, + const std::vector& input_values, + const pir::AttributeMap& attribute_map) { + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op->name()); + auto infer_meta_interface = + op_info.GetInterfaceImpl(); + std::vector output_types; + if (infer_meta_interface) { + output_types = + infer_meta_interface->infer_meta_by_value_(input_values, attribute_map); + } + return output_types; +} + std::unordered_map Str2PhiDataType = { {"DataType::FLOAT16", phi::DataType::FLOAT16}, {"DataType::BFLOAT16", phi::DataType::BFLOAT16}, @@ -1502,11 +1518,11 @@ void HandleForSpecialOp( void PushBackOutputTypes(pir::IrContext* ctx, pir::Operation* op_item, + const pir::Type& origin_type, const phi::Place& out_place, const phi::KernelKey& kernel_key, - std::vector* op_output_types, - size_t index) { - auto result_type = op_item->result(index).type(); + std::vector* op_output_types) { + auto result_type = origin_type; if (!result_type) { op_output_types->push_back(result_type); } else if (result_type.isa() || @@ -1589,8 +1605,12 @@ void HandleForCustomOp( for (size_t i = 0; i < op_item->num_results(); ++i) { phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); - PushBackOutputTypes( - ctx, op_item, out_place, kernel_key, &op_output_types, i); + PushBackOutputTypes(ctx, + op_item, + op_item->result(i).type(), + out_place, + kernel_key, + &op_output_types); } // Prepare input @@ -1672,14 +1692,17 @@ void HandleForCustomOp( block->push_back(op); } -std::vector BuildOutputs(pir::Operation* op_item, - const std::string& kernel_fn_str, - const phi::KernelKey& kernel_key, - pir::IrContext* ctx) { +std::vector BuildOutputs( + pir::Operation* op_item, + const std::string& kernel_fn_str, + const phi::KernelKey& kernel_key, + const std::vector& new_vec_inputs, + pir::IrContext* ctx) { if (op_item->num_results() == 0) { return {}; } std::vector op_output_types; + pir::AttributeMap attribute_map = op_item->attributes(); auto phi_kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( kernel_fn_str, kernel_key); @@ -1700,16 +1723,73 @@ std::vector BuildOutputs(pir::Operation* op_item, op_item->name())); } - for (size_t i = 0; i < op_item->num_results(); ++i) { - phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); - if ((!UnchangeOutputOps.count(op_item->name())) && - (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { - out_place = phi::TransToPhiPlace(output_defs[i].backend); + bool is_input_type_changed = false; + for (size_t i = 0; i < op_item->num_operands(); ++i) { + if (GetValueDataType(op_item->operand(i).source()) != + GetValueDataType(new_vec_inputs[i])) { + is_input_type_changed = true; + break; } - PushBackOutputTypes( - ctx, op_item, out_place, kernel_key, &op_output_types, i); } + bool is_custom_set = false; + if (is_input_type_changed) { + std::vector input_values; + for (size_t i = 0; i < op_item->num_operands(); ++i) { + input_values.emplace_back(op_item->operand(i).source()); + } + std::vector output_types = + InferMetaByValue(op_item, input_values, attribute_map); + + if (output_types.size() != 0) { + PADDLE_ENFORCE_EQ( + output_types.size(), + op_item->num_results(), + phi::errors::PreconditionNotMet( + "output_types.size() is expected to be %d but got %d", + op_item->num_results(), + output_types.size())); + for (size_t i = 0; i < op_item->num_results(); ++i) { + if (output_types[i] != op_item->result(i).type()) { + is_custom_set = true; + break; + } + } + } + } + + if (!is_input_type_changed || is_custom_set) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); + if ((!UnchangeOutputOps.count(op_item->name())) && + (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { + out_place = phi::TransToPhiPlace(output_defs[i].backend); + } + PushBackOutputTypes(ctx, + op_item, + op_item->result(i).type(), + out_place, + kernel_key, + &op_output_types); + } + } else { + auto base_types = InferMetaByValue(op_item, new_vec_inputs, attribute_map); + PADDLE_ENFORCE_EQ(base_types.size(), + op_item->num_results(), + phi::errors::PreconditionNotMet( + "base_types.size() is expected to be %d but got %d", + op_item->num_results(), + base_types.size())); + for (size_t i = 0; i < op_item->num_results(); ++i) { + phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); + if ((!UnchangeOutputOps.count(op_item->name())) && + (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { + out_place = phi::TransToPhiPlace(output_defs[i].backend); + } + PushBackOutputTypes( + ctx, op_item, base_types[i], out_place, kernel_key, &op_output_types); + } + } return op_output_types; } @@ -2261,23 +2341,24 @@ void ProcessBlock( op_info_parser = GetOpYamlInfoParser(op_item_inner); } #endif - // build output type - auto op_output_types = BuildOutputs(op_item, kernel_name, kernel_key, ctx); // build input - auto vec_inputs = BuildInputs(op_item, - kernel_name, - kernel_key, - place, - op_info_parser.get(), - ctx, - map_op_pair, - map_value_pair, - new_block); + auto new_vec_inputs = BuildInputs(op_item, + kernel_name, + kernel_key, + place, + op_info_parser.get(), + ctx, + map_op_pair, + map_value_pair, + new_block); + // build output type + auto op_output_types = + BuildOutputs(op_item, kernel_name, kernel_key, new_vec_inputs, ctx); // build op pir::Operation* op = BuildKernelOp(kernel_name, kernel_key, - vec_inputs, + new_vec_inputs, op_output_types, op_item, new_block, diff --git a/test/cpp/pir/core/ir_infershape_test.cc b/test/cpp/pir/core/ir_infershape_test.cc index 65646d851ce60..e866825842bb0 100644 --- a/test/cpp/pir/core/ir_infershape_test.cc +++ b/test/cpp/pir/core/ir_infershape_test.cc @@ -52,6 +52,13 @@ class OperationTest auto fn = PD_INFER_META(phi::CreateInferMeta); fn(infer_meta); } + static std::vector InferMeta( + const std::vector &input_values, + const pir::AttributeMap &attributes) { + VLOG(4) << "Start infermeta OperationTest"; + std::vector argument_outputs; + return argument_outputs; + } }; IR_DECLARE_EXPLICIT_TEST_TYPE_ID(OperationTest) IR_DEFINE_EXPLICIT_TYPE_ID(OperationTest)