
[CINN] Support auto code-gen of CacheGradOpSymbolicShapeInterface #65500


Merged · 2 commits · Jun 27, 2024
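This PR adds automatic code generation for CacheGradOpSymbolicShapeInterface. A new generator, cache_grad_op_symbol_shape_gen.py, parses the forward/backward op yaml files and, for each forward op whose grad op it can resolve, emits a CacheGradOpSymbolicShape method that records the grad op's input and output symbolic shapes in the InferSymbolicShapeContext cache; ops whose grad op cannot be resolved (or that are on the black list) get a stub that throws Unimplemented. As a minimal sketch only (not taken from the generated file), the method produced for a relu-style op might look like the following, assuming its grad op takes (out, out_grad) and returns x_grad:

void ReluOp::CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext* infer_context) {
  // Collect the grad op's input shapes from this forward op's values.
  const auto& out_shape = GetOutputShape(infer_context, this->operation(), 0);
  const auto& out_grad_shape = GetGradVarShapeFromOutput(infer_context, this->operation(), 0);
  pir::InferSymbolicShapeCacheKey op_shape_info(
      "pd_op.relu_grad", {out_shape, out_grad_shape}, this->operation()->attributes());
  // Collect the grad op's output shapes (the grad vars of the forward inputs).
  const auto& x_grad_shape = GetGradVarShapeFromInput(infer_context, this->operation(), 0);
  std::vector<symbol::ShapeOrDataDimExprs> output_shape_or_data{x_grad_shape};

  infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
                                              output_shape_or_data);
}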
23 changes: 21 additions & 2 deletions paddle/fluid/pir/dialect/CMakeLists.txt
@@ -37,6 +37,10 @@ set(op_info_file_tmp ${op_info_file}.tmp)
set(op_vjp_source_file ${PIR_DIALECT_BINARY_DIR}/pd_op_vjp.cc)
set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp)

set(cache_grad_op_symbol_shape_file
${PIR_DIALECT_BINARY_DIR}/pd_op_cache_grad_op_symbol_shape.cc)
set(cache_grad_op_symbol_shape_file_tmp ${cache_grad_op_symbol_shape_file}.tmp)

set(op_source_file ${PIR_DIALECT_BINARY_DIR}/pd_op.cc)
set(op_source_file_tmp ${op_source_file}.tmp)

@@ -214,10 +218,24 @@ execute_process(

set(generated_files_ops_api "${ops_api_source_file}")

set(cache_grad_op_symbol_shape_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/cache_grad_op_symbol_shape_gen.py
)
set(cache_grad_op_shape_yaml_files
${op_fwd_yaml},${op_bwd_yaml},${fused_op_fwd_yaml},${fused_op_bwd_yaml},${pir_op_fwd_yaml},${pir_op_bwd_yaml}
)

execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${cache_grad_op_symbol_shape_gen_file} --op_yaml_files
${cache_grad_op_shape_yaml_files} --op_compat_yaml_file
${op_compat_yaml_file} --cache_grad_op_symbol_shape_file
${cache_grad_op_symbol_shape_file_tmp})

set(generated_files_pir
${generated_files_pd_op} ${generated_files_onednn_pd_op}
${generated_files_pd_api} ${generated_files_python_c}
${generated_files_ops_api})
${generated_files_ops_api} ${cache_grad_op_symbol_shape_file})
foreach(generated_file ${generated_files_pir})
if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
@@ -267,7 +285,8 @@ set(op_dialect_srcs
${sparse_op_source_file}
${bwd_sparse_op_source_file}
${api_source_file}
${api_source_file})
${api_source_file}
${cache_grad_op_symbol_shape_file})

if(WITH_ONEDNN)
set(op_dialect_srcs
263 changes: 263 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/cache_grad_op_symbol_shape_gen.py
@@ -0,0 +1,263 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os

import yaml
from op_gen import (
OpCompatParser,
OpInfoParser,
to_pascal_case,
)

CPP_FILE_TEMPLATE = """
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/shape_analysis_utils.h"

namespace paddle {{
namespace dialect {{

{body}

}} // namespace dialect
}} // namespace paddle
"""

CACHE_GRAD_OP_SYMBOL_SHAPE_FUNC_CODE_TEMPLATE = """
void {op_name}Op::CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext* infer_context) {{
{create_grad_op_shape_info_code}
pir::InferSymbolicShapeCacheKey op_shape_info(
"{grad_op_name}", {{{input_shape_list}}}, this->operation()->attributes());
{create_grad_op_output_shape_code}
std::vector<symbol::ShapeOrDataDimExprs> output_shape_or_data{{{output_shape_list}}};

infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
output_shape_or_data);
}}
"""

UNIMPLEMENTED_CODE_TEMPLATE = """
void {op_name}Op::CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext* infer_context) {{
PADDLE_THROW(common::errors::Unimplemented("{op_name} CacheGradOpSymbolicShape is not implemented!"));
}}
"""

SHAPE_VAR_NAME_SUFFIX = "_shape"

GET_INPUT_SHAPE_CODE_TEMPLATE = """
const auto& {input_name}{name_suffix} = GetInputShape(infer_context, this->operation(), {index});"""

GET_OUTPUT_SHAPE_CODE_TEMPLATE = """
const auto& {output_name}{name_suffix} = GetOutputShape(infer_context, this->operation(), {index});"""

GET_OUT_GRAD_SHAPE_CODE_TEMPLATE = """
const auto& {output_grad_name}{name_suffix} = GetGradVarShapeFromOutput(infer_context, this->operation(), {index});"""

GET_INPUT_GRAD_SHAPE_CODE_TEMPLATE = """
const auto& {input_grad_name}{name_suffix} = GetGradVarShapeFromInput(infer_context, this->operation(), {index});"""


cache_grad_op_shape_black_list = {"fused_attention"}


class CacheGradOpSymbolShapeCodeGen:
def __init__(
self, op_yaml_files, op_compat_yaml_file, dialect_name="pd_op"
):
self.op_info_maps = self.parse_yaml(
op_yaml_files,
op_compat_yaml_file,
)
self.dialect_name = dialect_name

def parse_yaml(self, op_yaml_files, op_compat_yaml_file):
op_compat_parser = OpCompatParser(op_compat_yaml_file)

op_info_maps = {}
for yaml_file in op_yaml_files:
with open(yaml_file, "r") as f:
ops = yaml.safe_load(f)
for op in ops:
op_compat_item = op_compat_parser.get_compat(op['name'])
if (
op_compat_item is not None
and op_compat_item['op'] == "pow"
and 'scalar' in op_compat_item
):
op_compat_item = op_compat_item.pop('scalar')
op_info_maps[op["name"]] = OpInfoParser(
op, op_compat_item, yaml_file
)
return op_info_maps

def gen_cpp_file_code(self, cpp_file_path):
body_code = ""
for op_info_item in self.op_info_maps.values():
if op_info_item.backward_name is None:
continue
if op_info_item.backward_name not in self.op_info_maps:
continue

grad_op_item = self.op_info_maps[op_info_item.backward_name]
if grad_op_item.infer_meta_map is None:
continue

for op_phi_name in op_info_item.op_phi_name:
create_grad_op_shape_info_code = ""
for input_name in grad_op_item.input_name_list:
if input_name in grad_op_item.forward_input_name_list:
# forward input
index = grad_op_item.forward_input_name_list.index(
input_name
)
create_grad_op_shape_info_code += (
GET_INPUT_SHAPE_CODE_TEMPLATE.format(
input_name=input_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)
elif input_name in grad_op_item.forward_output_name_list:
# forward output
index = grad_op_item.forward_output_name_list.index(
input_name
)
create_grad_op_shape_info_code += (
GET_OUTPUT_SHAPE_CODE_TEMPLATE.format(
output_name=input_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)
elif input_name in op_info_item.mutable_attribute_name_list:
# mutable attribute
index = len(
op_info_item.input_name_list
) + op_info_item.mutable_attribute_name_list.index(
input_name
)
create_grad_op_shape_info_code += (
GET_INPUT_SHAPE_CODE_TEMPLATE.format(
input_name=input_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)
elif input_name.endswith("_grad"):
# output grad
origin_out_name = input_name[:-5]
index = grad_op_item.forward_output_name_list.index(
origin_out_name
)
create_grad_op_shape_info_code += (
GET_OUT_GRAD_SHAPE_CODE_TEMPLATE.format(
output_grad_name=input_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)
else:
raise ValueError(
f"Input name {input_name} not found for backward op {op_info_item.backward_name}."
)

create_grad_op_output_shape_code = ""
for output_name in grad_op_item.output_name_list:
assert output_name.endswith("_grad")
origin_input_name = output_name[:-5]
if (
origin_input_name
not in grad_op_item.forward_input_name_list
):
continue
index = grad_op_item.forward_input_name_list.index(
origin_input_name
)
create_grad_op_output_shape_code += (
GET_INPUT_GRAD_SHAPE_CODE_TEMPLATE.format(
input_grad_name=output_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)

if (
len(create_grad_op_output_shape_code) == 0
or op_phi_name in cache_grad_op_shape_black_list
):
logging.warning(
f"{op_phi_name}'s grad op has some exception, please check it in yaml file."
)
body_code += UNIMPLEMENTED_CODE_TEMPLATE.format(
op_name=to_pascal_case(op_phi_name),
)
continue

body_code += CACHE_GRAD_OP_SYMBOL_SHAPE_FUNC_CODE_TEMPLATE.format(
op_name=to_pascal_case(op_phi_name),
create_grad_op_shape_info_code=create_grad_op_shape_info_code,
grad_op_name=self.dialect_name
+ "."
+ grad_op_item.op_phi_name[0],
input_shape_list=", ".join(
[
input_name + SHAPE_VAR_NAME_SUFFIX
for input_name in grad_op_item.input_name_list
]
),
create_grad_op_output_shape_code=create_grad_op_output_shape_code,
output_shape_list=", ".join(
[
output_name + SHAPE_VAR_NAME_SUFFIX
for output_name in grad_op_item.output_name_list
]
),
)

directory_path = os.path.dirname(cpp_file_path)
if not os.path.exists(directory_path):
os.makedirs(directory_path, exist_ok=True)

with open(cpp_file_path, 'w') as f:
f.write(
CPP_FILE_TEMPLATE.format(
body=body_code,
)
)


def ParseArguments():
parser = argparse.ArgumentParser(
description='Generate Cache GradOp Symbol Shape Interface Files By Yaml'
)
parser.add_argument('--op_yaml_files', type=str)
parser.add_argument('--op_compat_yaml_file', type=str)
parser.add_argument('--cache_grad_op_symbol_shape_file', type=str)
return parser.parse_args()


if __name__ == '__main__':
args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file
cache_grad_op_symbol_shape_file = args.cache_grad_op_symbol_shape_file

code_gen = CacheGradOpSymbolShapeCodeGen(
op_yaml_files,
op_compat_yaml_file,
)
code_gen.gen_cpp_file_code(cache_grad_op_symbol_shape_file)
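
For reference, the CMake rule above invokes this generator roughly as: python cache_grad_op_symbol_shape_gen.py --op_yaml_files <comma-separated fwd/bwd yaml list> --op_compat_yaml_file <op_compat.yaml> --cache_grad_op_symbol_shape_file <pd_op_cache_grad_op_symbol_shape.cc.tmp> (paths illustrative). The .tmp output is then copied over the real source with copy_if_different, so an unchanged result does not trigger a rebuild.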
21 changes: 11 additions & 10 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -153,6 +153,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
{get_kernel_type_for_var_declare}
{parse_kernel_key_declare}
{infer_symbolic_shape_declare}
{cache_grad_op_symbolic_shape_declare}
{exclusive_interface}
}};
"""
@@ -178,6 +179,8 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
"""

cache_grad_op_symbolic_shape_template = " void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext* infer_context);"

# =====================================
# String Template for cc file code gen
# =====================================
@@ -1512,6 +1515,12 @@ def AutoCodeGen(
):
infer_symbolic_shape_str = infer_symbolic_shape_template

cache_grad_op_symbolic_shape_str = ""
if op_info.backward_name:
cache_grad_op_symbolic_shape_str = (
cache_grad_op_symbolic_shape_template
)

if op_infer_meta_map is not None:
(
build_args_with_muta_attr_not_input_for_declare,
@@ -1625,6 +1634,7 @@ def AutoCodeGen(
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
cache_grad_op_symbolic_shape_declare=cache_grad_op_symbolic_shape_str,
)
op_defined_str = ""
else:
@@ -1648,6 +1658,7 @@ def AutoCodeGen(
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
cache_grad_op_symbolic_shape_declare=cache_grad_op_symbolic_shape_str,
)
attribute_names_str = (
'"'
@@ -1911,16 +1922,6 @@ def AutoCodeGen(
gen_infer_symbolic_shape_str(op_class_name)
)

# generate op GetKernelKeyForVar function str
infer_symbolic_shape_define_str = ''
if (
"paddle::dialect::InferSymbolicShapeInterface"
in all_interface_list
):
infer_symbolic_shape_define_str = (
gen_infer_symbolic_shape_str(op_class_name)
)

# generate op GetKernelKeyForVar function str
op_get_kernel_type_for_var_str = ''
if dialect_name == "pd_op" or dialect_name == "onednn_op":
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -571,6 +571,7 @@ class ExpandOp : public pir::Op<ExpandOp,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IncrementOp
@@ -702,6 +703,7 @@ class AssignOut_Op
const std::vector<pir::Value> &input_values,
pir::AttributeMap *p_attributes);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
static std::vector<std::vector<pir::Value>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,