Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
replace cast_op with transfer_dtype_op
  • Loading branch information
Aurelius84 committed Nov 19, 2021
commit fe377b7c2fc813b77f4f8e840ce5fc335200a29b
24 changes: 16 additions & 8 deletions paddle/fluid/framework/new_executor/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
// 1. layout transform
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferLayout(src_var_name, new_var_name,
kernel_type_for_var.data_layout_,
expected_kernel_key.data_layout_, var_scope_);
RunAndConstructOpFuncNode(op, src_var_name, new_var_name,
new_op_func_nodes);
Expand All @@ -39,6 +40,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
// 2. dtype transform
if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferDtype(src_var_name, new_var_name,
kernel_type_for_var.data_type_,
expected_kernel_key.data_type_, var_scope_);
RunAndConstructOpFuncNode(op, src_var_name, new_var_name,
new_op_func_nodes);
Expand Down Expand Up @@ -109,7 +111,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
std::shared_ptr<OperatorBase> TransferLayout(
const std::string& var_name,
std::string& new_var_name, // NOLINT
DataLayout layout, VariableScope* var_scope) {
DataLayout in_layout, DataLayout out_layout, VariableScope* var_scope) {
// 1. Generate new_var_name
new_var_name =
var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1);
Expand All @@ -118,23 +120,24 @@ std::shared_ptr<OperatorBase> TransferLayout(
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
VariableNameMap out_name_map = {{"Out", {new_var_name}}};
AttributeMap attr_map = {{"dst_layout", static_cast<int>(layout)}};
AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};

// 3. Create transfer_op
std::string op_type("transfer_layout");
auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>(
op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

VLOG(3) << string::Sprintf("Insert %s with %s -> %s(%s).", op_type, var_name,
new_var_name, layout);
VLOG(3) << string::Sprintf("Insert %s(%s) with %s -> %s(%s).", op_type,
var_name, in_layout, new_var_name, out_layout);
return op;
}

std::shared_ptr<OperatorBase> TransferDtype(
const std::string& var_name,
std::string& new_var_name, // NOLINT
proto::VarType::Type dtype, VariableScope* var_scope) {
proto::VarType::Type in_dtype, proto::VarType::Type out_dtype,
VariableScope* var_scope) {
// 1. Generate new_var_name
new_var_name =
var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
Expand All @@ -143,16 +146,21 @@ std::shared_ptr<OperatorBase> TransferDtype(
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
VariableNameMap out_name_map = {{"Out", {new_var_name}}};
AttributeMap attr_map = {{"dst_dtype", static_cast<int>(dtype)}};
AttributeMap attr_map;
attr_map["in_dtype"] = static_cast<int>(in_dtype);
attr_map["out_dtype"] = static_cast<int>(out_dtype);
// NOTE(Aurelius84): In which case is use_mkldnn = true?
attr_map["use_mkldnn"] = false;

// 3. Create transfer_op
std::string op_type("transfer_dtype");
auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>(
op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

VLOG(3) << string::Sprintf("Insert %s with %s -> %s(%s).", op_type, var_name,
new_var_name, DataTypeToString(dtype));
VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", op_type,
var_name, DataTypeToString(in_dtype), new_var_name,
DataTypeToString(out_dtype));
return op;
}

Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/new_executor/data_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ inline bool need_layout_transform(const OpKernelType& kernel_type_for_var,
std::shared_ptr<OperatorBase> TransferLayout(
const std::string& var_name,
std::string& new_var_name, // NOLINT
DataLayout layout, VariableScope* var_scope);
DataLayout in_layout, DataLayout out_layout, VariableScope* var_scope);

std::shared_ptr<OperatorBase> TransferDtype(
const std::string& var_name,
std::string& new_var_name, // NOLINT
proto::VarType::Type dtype, VariableScope* var_scope);
proto::VarType::Type in_dtype, proto::VarType::Type out_dtype,
VariableScope* var_scope);

std::shared_ptr<OperatorBase> TransferDevice(
const std::string& var_name,
Expand Down
33 changes: 20 additions & 13 deletions paddle/fluid/operators/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,23 @@ class CastOp : public framework::OperatorWithKernel {

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int16_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
#define REGISTER_CAST_CPU_BASE(op_name, ...) \
REGISTER_OPERATOR(op_name, ops::CastOp, \
ops::CastOpGradMaker<paddle::framework::OpDesc>, \
ops::CastOpGradMaker<paddle::imperative::OpBase>, \
ops::CastOpProtoMaker); \
REGISTER_OP_CPU_KERNEL( \
op_name, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>, \
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>, \
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int16_t>, \
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>, \
ops::CastOpKernel<CPU, paddle::platform::float16>, \
ops::CastOpKernel<CPU, paddle::platform::bfloat16>, \
ops::CastOpKernel<CPU, paddle::platform::complex<float>>, \
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);

REGISTER_CAST_CPU_BASE(cast)
// [ why register transfer_dtype_op alias with cast_op? ]
// In case of InterpreterCore, if we reuse cast_op, we cannot distinguish
// which cast_op is inserted by new executor when we do profiling.
REGISTER_CAST_CPU_BASE(transfer_dtype)
3 changes: 3 additions & 0 deletions paddle/fluid/operators/cast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ namespace plat = paddle::platform;

#if !defined(PADDLE_WITH_HIP)
REGISTER_CAST_CUDA_BASE(cast, ops::CastCUDAOpKernel<plat::bfloat16>)
// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastCUDAOpKernel<plat::bfloat16>)
#else
REGISTER_CAST_CUDA_BASE(cast)
REGISTER_CAST_CUDA_BASE(transfer_dtype)
#endif
124 changes: 0 additions & 124 deletions paddle/fluid/operators/transfer_dtype_op.cc

This file was deleted.

65 changes: 0 additions & 65 deletions paddle/fluid/operators/transfer_dtype_op.h

This file was deleted.