Skip to content

Commit

Permalink
make mechanism of caching grad op symbol shape to work (PaddlePaddle#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored Jul 1, 2024
1 parent df690f1 commit 90b2cc2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def gen_cpp_file_code(self, cpp_file_path):
continue
if op_info_item.backward_name not in self.op_info_maps:
continue
if op_info_item.kernel_map is None:
continue

grad_op_item = self.op_info_maps[op_info_item.backward_name]
if grad_op_item.infer_meta_map is None:
continue

for op_phi_name in op_info_item.op_phi_name:
create_grad_op_shape_info_code = ""
Expand Down Expand Up @@ -143,20 +143,6 @@ def gen_cpp_file_code(self, cpp_file_path):
index=index,
)
)
elif input_name in op_info_item.mutable_attribute_name_list:
# mutable attribute
index = len(
op_info_item.input_name_list
) + op_info_item.mutable_attribute_name_list.index(
input_name
)
create_grad_op_shape_info_code += (
GET_INPUT_SHAPE_CODE_TEMPLATE.format(
input_name=input_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)
elif input_name.endswith("_grad"):
# output grad
origin_out_name = input_name[:-5]
Expand All @@ -171,13 +157,35 @@ def gen_cpp_file_code(self, cpp_file_path):
)
)
else:
raise (
raise ValueError(
f"Not found input name {input_name} for backward op {op_info_item.backward_name}."
)
# mutable attribute
for (
mutable_attribute_name
) in grad_op_item.mutable_attribute_name_list:
assert (
mutable_attribute_name
in op_info_item.mutable_attribute_name_list
), f"{mutable_attribute_name} is not found in {op_info_item.backward_name}'s mutable_attribute name list."
index = len(
op_info_item.input_name_list
) + op_info_item.mutable_attribute_name_list.index(
mutable_attribute_name
)
create_grad_op_shape_info_code += (
GET_INPUT_SHAPE_CODE_TEMPLATE.format(
input_name=mutable_attribute_name,
name_suffix=SHAPE_VAR_NAME_SUFFIX,
index=index,
)
)

create_grad_op_output_shape_code = ""
for output_name in grad_op_item.output_name_list:
assert output_name.endswith("_grad")
if not output_name.endswith("_grad"):
create_grad_op_output_shape_code = ""
break
origin_input_name = output_name[:-5]
if (
origin_input_name
Expand Down Expand Up @@ -216,7 +224,10 @@ def gen_cpp_file_code(self, cpp_file_path):
input_shape_list=", ".join(
[
input_name + SHAPE_VAR_NAME_SUFFIX
for input_name in grad_op_item.input_name_list
for input_name in (
grad_op_item.input_name_list
+ grad_op_item.mutable_attribute_name_list
)
]
),
create_grad_op_output_shape_code=create_grad_op_output_shape_code,
Expand All @@ -228,6 +239,21 @@ def gen_cpp_file_code(self, cpp_file_path):
),
)

if len(op_info_item.kernel_map['func']) == 1:
continue
for kernel_func_name in op_info_item.kernel_map['func']:
is_inplace_version = op_phi_name.endswith('_')
op_origin_name = (
op_phi_name[:-1] if is_inplace_version else op_phi_name
)
if kernel_func_name == op_origin_name:
continue
inplace_suffix = '_' if is_inplace_version else ''
body_code += UNIMPLEMENTED_CODE_TEMPLATE.format(
op_name=to_pascal_case(kernel_func_name)
+ inplace_suffix
)

directory_path = os.path.dirname(cpp_file_path)
if not os.path.exists(directory_path):
os.makedirs(directory_path, exist_ok=True)
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
Expand Down Expand Up @@ -1301,6 +1302,14 @@ def AutoCodeGen(
op_traits = op_info.traits_list
op_interfaces = op_info.interfaces_list
op_interfaces += ["paddle::dialect::OpYamlInfoInterface"]
if (
dialect_name == "pd_op"
and op_info.backward_name
and not op_info.is_sparse_op
):
op_interfaces += [
"paddle::dialect::CacheGradOpSymbolicShapeInterface"
]
exclusive_interface_str = gen_exclusive_interface_str(
op_info, op_info_items
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ class CacheGradOpSymbolicShapeInterface
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*cache_grad_op_symbolic_shape)(
explicit Concept(void (*cache_grad_op_symbolic_shape)(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context))
: cache_grad_op_symbolic_shape(cache_grad_op_symbolic_shape) {}
bool (*cache_grad_op_symbolic_shape)(
void (*cache_grad_op_symbolic_shape)(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context);
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool CacheGradOpSymbolicShape(
static inline void CacheGradOpSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return op->dyn_cast<ConcreteOp>().CacheGradOpSymbolicShape(infer_context);
}
Expand All @@ -50,7 +50,7 @@ class CacheGradOpSymbolicShapeInterface
: pir::OpInterfaceBase<CacheGradOpSymbolicShapeInterface>(op),
impl_(impl) {}

bool CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

private:
Concept *impl_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

namespace pir {

bool CacheGradOpSymbolicShapeInterface::CacheGradOpSymbolicShape(
void CacheGradOpSymbolicShapeInterface::CacheGradOpSymbolicShape(
pir::InferSymbolicShapeContext *infer_context) {
return impl_->cache_grad_op_symbolic_shape(operation(), infer_context);
impl_->cache_grad_op_symbolic_shape(operation(), infer_context);
}

} // namespace pir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ void InferSymExprForOp(Operation* op,
}
} else {
// risk set
LOG(WARNING) << "Not found symbolic shape cache for " << op->name()
<< "[id:" << op->id() << "]";
for (uint32_t i = 0; i < op->num_results(); ++i) {
infer_context->SetSymbolForValueByStaticShape(op->result(i));
}
Expand Down Expand Up @@ -349,12 +351,8 @@ void CacheBackwardOpSymbolicShape(Operation* op,
op->dyn_cast<pir::CacheGradOpSymbolicShapeInterface>();
if (cache_grad_op_symbolic_shape_interface) {
VLOG(3) << "CacheBackwardOpSymbolicShape for: " << op->name();
PADDLE_ENFORCE_EQ(
cache_grad_op_symbolic_shape_interface.CacheGradOpSymbolicShape(
infer_context),
true,
common::errors::Fatal("CacheBackwardOpSymbolicShape for %s failed.",
op->name()));
cache_grad_op_symbolic_shape_interface.CacheGradOpSymbolicShape(
infer_context);
}
}

Expand Down

0 comments on commit 90b2cc2

Please sign in to comment.