Skip to content

Commit

Permalink
[PIR] Improve Dtype Transfer (add infermeta by value) (PaddlePaddle#6…
Browse files Browse the repository at this point in the history
…0677)

* update

* fix null value

* adapt manual op

* update
  • Loading branch information
chen2016013 authored Jan 15, 2024
1 parent 1728013 commit de37c94
Show file tree
Hide file tree
Showing 11 changed files with 2,878 additions and 66 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
Expand Down
48 changes: 48 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_infermeta_gen import (
gen_infermeta_by_invoke_func_str,
gen_infermeta_func_str,
)
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
Expand Down Expand Up @@ -142,6 +146,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#include "{h_file}"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h"
Expand Down Expand Up @@ -1712,6 +1717,48 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
op_info, op_class_name, op_info_items
)

op_infer_meta_from_type_str = ""
if op_infer_meta_map is not None:
muta_attr_is_input = (
True
if len(op_mutable_attribute_name_list) > 0
else False
)
op_infer_meta_from_type_str = gen_infermeta_func_str(
op_class_name,
op_input_name_list,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_output_optional_list,
op_infer_meta_map,
op_inplace_map,
op_attribute_name_list,
op_attribute_type_list,
op_attribute_build_arg_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
muta_attr_is_input,
attr_args_is_map=True,
)

if (op_invoke_map is not None) and (
op_invoke_map['func'] in op_info_items
):
op_invoke_class_name = (
to_pascal_case(op_invoke_map['func']) + "Op"
)
op_infer_meta_from_type_str = (
gen_infermeta_by_invoke_func_str(
op_class_name, op_invoke_class_name
)
)

# =================================== #
# gen Vjp func str #
# =================================== #
Expand Down Expand Up @@ -1753,6 +1800,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):

ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_meta_str)
ops_defined_list.append(op_infer_meta_from_type_str)
ops_defined_list.append(op_get_kernel_type_for_var_str)
ops_defined_list.append(parse_kernel_key_define_str)
ops_defined_list.append(infer_symbolic_shape_define_str)
Expand Down
Loading

0 comments on commit de37c94

Please sign in to comment.