Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,12 @@ set(op_source_file_tmp ${op_source_file}.tmp)
set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc)
set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp)

add_custom_command(
OUTPUT ${op_yaml_file3} ${op_yaml_file4}
execute_process(
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path
${pd_op_forward_yaml_file} --output_path ${op_yaml_file3}
COMMENT "Generate pd_ops.parsed.yaml"
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path
${pd_op_backward_yaml_file} --output_path ${op_yaml_file4} --backward
COMMENT "Generate pd_ops_backward.parsed.yaml"
DEPENDS ${op_gen_parsed_yaml_file} ${pd_op_forward_yaml_file}
${pd_op_backward_yaml_file}
VERBATIM)
${pd_op_backward_yaml_file} --output_path ${op_yaml_file4} --backward)

add_custom_command(
OUTPUT ${op_header_file} ${op_source_file} ${op_vjp_source_file}
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/primitive/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml)
set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml)
set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml)
set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
set(fwd_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops.parsed.yaml
)
set(rev_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops_backward.parsed.yaml
)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
Expand All @@ -17,7 +23,8 @@ execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
${rev_legacy_path} --fwd_pd_op_path ${fwd_pd_op_path} --rev_pd_op_path
${rev_pd_op_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
--compat_path ${compat_path} --destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
Expand Down
64 changes: 51 additions & 13 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,15 @@
'add_n_grad',
]

BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"]
BACKENDS_BLACK_LIST = [
'copy_to',
'add_n_grad',
"allclose",
"isclose",
"send_v2",
"assert",
"embedding_grad_sparse",
]


PRIM_VJP = [
Expand Down Expand Up @@ -280,18 +288,19 @@ def process_backward_invoke_info(apis):

def process_optional_output_info(apis):
for api in apis:
if not api['is_fwd']:
continue
inputs_dict = to_named_dict(api['inputs'])
for output in api['outputs']:
if (
api.get("inplace", None)
and output['name'] in api['inplace']
and inputs_dict[api['inplace'][output['name']]]['optional']
):
output['optional'] = True
else:
if not api['is_fwd']:
output['optional'] = False
else:
if (
api.get("inplace", None)
and output['name'] in api['inplace']
and inputs_dict[api['inplace'][output['name']]]['optional']
):
output['optional'] = True
else:
output['optional'] = False


def gen(
Expand All @@ -301,6 +310,8 @@ def gen(
rev_path: pathlib.Path,
rev_legacy_path: pathlib.Path,
compat_path: pathlib.Path,
fwd_pd_op_path: pathlib.Path,
rev_pd_op_path: pathlib.Path,
templates_dir: pathlib.Path,
destination_dir: pathlib.Path,
):
Expand All @@ -316,23 +327,38 @@ def gen(
rev_legacy_path (pathlib.Path): The YAML file path of the legacy
backward API.
compat_path: (pathlib.Path): The YAML file path of the ops compat.
fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API.
rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API.
templates_dir (pathlib.Path): The directory of the templates.
destination_dir (pathlib.Path): The Directory of the generated file.

Returns:
None
"""
prims, fwds, legacy_fwds, revs, legacy_revs, compats = (
(
prims,
fwds,
legacy_fwds,
revs,
legacy_revs,
compats,
ir_fwds,
ir_revs,
) = (
load(prim_path),
load(fwd_path),
load(fwd_legacy_path),
load(rev_path),
load(rev_legacy_path),
load(compat_path),
load(fwd_pd_op_path),
load(rev_pd_op_path),
)
filter_compat_info(compats)
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds]
apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs]
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds + ir_fwds]
apis = apis + [
{**api, **{'is_fwd': False}} for api in revs + legacy_revs + ir_revs
]
apis = [
{**api, **{'is_prim': True}}
if api['name'] in prims
Expand Down Expand Up @@ -383,6 +409,16 @@ def gen(
type=str,
help='The parsed ops compat yaml file.',
)
parser.add_argument(
'--fwd_pd_op_path',
type=str,
help='The ir forward ops parsed yaml file.',
)
parser.add_argument(
'--rev_pd_op_path',
type=str,
help='The ir backward ops parsed yaml file.',
)
parser.add_argument(
'--templates_dir',
type=str,
Expand All @@ -402,6 +438,8 @@ def gen(
pathlib.Path(args.rev_path),
pathlib.Path(args.rev_legacy_path),
pathlib.Path(args.compat_path),
pathlib.Path(args.fwd_pd_op_path),
pathlib.Path(args.rev_pd_op_path),
pathlib.Path(args.templates_dir),
pathlib.Path(args.destination_dir),
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

{% for api in apis %}
{%- if api is only_composite_op -%}{#- render nothing -#}
{%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#}
{%- elif api.name not in backend_black_list -%}
{%- if 'invoke' not in api or 'invoke' in api and api.is_fwd -%}
{% if api.attrs is exist_mutable_attribute %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace backend {

{%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#}
{{common.sequence('', '', ', ', inputs)}}
{%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between
nputs and attrs -#}
{{common.sequence('', '', ', ', attrs)}}
{%- endmacro -%}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}}


{% for api in apis %}
{%- if api is only_composite_op -%}{#- render nothing -#}
{%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#}
{% elif api.name not in backend_black_list %}
{%- if 'invoke' not in api or 'invoke' in api and api.is_fwd-%}
{% set api_outputs = api.outputs | trip_intermediate %}
Expand Down