Skip to content

Commit 2a1f009

Browse files
authored
[NewExe] Support layout/dtype transform by adding transfer_layout/transfer_dtype op (#37299)
* Add transfer_layout/dtype op * clean useless codes * fix unused var * add optest in white.txt * split into data_transfer.cc * fix cmake * modify according to reviewer comments * replace cast_op with transfer_dtype_op
1 parent 684de4b commit 2a1f009

14 files changed

+882
-258
lines changed

paddle/fluid/framework/data_layout_transform.cc

+12-23
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,19 @@ std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
3737
}
3838
}
3939

40-
struct CastDataLayout {
41-
CastDataLayout(const platform::DeviceContext* ctx,
42-
const std::vector<int>& axis, const framework::Tensor& in,
43-
framework::Tensor* out)
44-
: in_(in), out_(out), ctx_(ctx), axis_(axis) {}
45-
const framework::Tensor in_;
46-
framework::Tensor* out_;
47-
const platform::DeviceContext* ctx_;
48-
const std::vector<int> axis_;
49-
50-
template <typename T>
51-
void apply() {
52-
auto place = ctx_->GetPlace();
53-
54-
if (platform::is_cpu_place(place)) {
55-
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
56-
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
57-
trans4(*context, in_, out_, axis_);
58-
} else {
59-
PADDLE_THROW(platform::errors::PreconditionNotMet(
60-
"Unsupported data layout cast from CPU to GPU."));
61-
}
40+
template <typename T>
41+
void CastDataLayout::apply() {
42+
auto place = ctx_->GetPlace();
43+
44+
if (platform::is_cpu_place(place)) {
45+
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
46+
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
47+
trans4(*context, in_, out_, axis_);
48+
} else {
49+
PADDLE_THROW(platform::errors::PreconditionNotMet(
50+
"Unsupported data layout cast from CPU to GPU."));
6251
}
63-
};
52+
}
6453

6554
void TransDataLayout(const OpKernelType& kernel_type_for_var,
6655
const OpKernelType& expected_kernel_type, const Tensor& in,

paddle/fluid/framework/data_layout_transform.h

+15
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@ class Tensor;
3636
namespace paddle {
3737
namespace framework {
3838

39+
// Functor that transposes a 4-D tensor between data layouts on the
// given device context, following the axis permutation in axis_.
// Dispatched over the tensor element type through apply<T>(); the
// out-of-line definition (data_layout_transform.cc) supports CPU only.
struct CastDataLayout {
  CastDataLayout(const platform::DeviceContext* ctx,
                 const std::vector<int>& axis, const framework::Tensor& in,
                 framework::Tensor* out)
      : in_(in), out_(out), ctx_(ctx), axis_(axis) {}

  // Source tensor, held by value. NOTE(review): presumably Tensor's copy
  // is shallow (shared buffer) so this is cheap -- TODO confirm.
  const framework::Tensor in_;
  // Destination tensor; written by apply<T>(). Non-owning.
  framework::Tensor* out_;
  // Device context used to run the transpose. Non-owning.
  const platform::DeviceContext* ctx_;
  // Axis permutation applied by the transpose.
  const std::vector<int> axis_;

  // Performs the transpose for element type T.
  template <typename T>
  void apply();
};
53+
3954
#ifdef PADDLE_WITH_MKLDNN
4055
using MKLDNNDataType = dnnl::memory::data_type;
4156

paddle/fluid/framework/new_executor/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_f
22
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
33
graph_to_program_pass variable_helper timer monitor nan_inf_utils)
44

5+
cc_library(data_transfer SRCS data_transfer.cc DEPS enforce scope glog)
56
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
67
cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope)
78
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper)
8-
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs)
9+
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs data_transfer)
910
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog new_executor_defs)
1011
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs)
1112
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/new_executor/data_transfer.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
namespace interpreter {
20+
21+
bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
22+
const OpKernelType& expected_kernel_key,
23+
const std::string& var_name,
24+
std::string* new_var_name,
25+
std::vector<OpFuncNode>* op_func_nodes,
26+
bool use_local_scope) {
27+
bool is_transferred = false;
28+
auto* src_var_name = &var_name;
29+
30+
Scope* local_scope = use_local_scope ? var_scope_->GetMutableLocalScope()
31+
: var_scope_->GetMutableScope();
32+
33+
// 1. layout transform
34+
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
35+
auto op = TransferLayout(
36+
*src_var_name, new_var_name, kernel_type_for_var.data_layout_,
37+
expected_kernel_key.data_layout_, var_scope_, local_scope);
38+
RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
39+
// update src_var_name
40+
src_var_name = new_var_name;
41+
is_transferred = true;
42+
}
43+
// 2. dype transform
44+
if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
45+
auto op = TransferDtype(
46+
*src_var_name, new_var_name, kernel_type_for_var.data_type_,
47+
expected_kernel_key.data_type_, var_scope_, local_scope);
48+
RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
49+
// update src_var_name
50+
src_var_name = new_var_name;
51+
is_transferred = true;
52+
}
53+
// 3. device transform
54+
if (need_device_transform(kernel_type_for_var, expected_kernel_key)) {
55+
auto src_place = kernel_type_for_var.place_;
56+
auto dst_place = expected_kernel_key.place_;
57+
auto op = TransferDevice(*src_var_name, new_var_name, src_place, dst_place,
58+
var_scope_, local_scope);
59+
RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
60+
is_transferred = true;
61+
}
62+
return is_transferred;
63+
}
64+
65+
void DataTranferHelper::RunAndConstructOpFuncNode(
66+
const std::shared_ptr<OperatorBase>& op, const std::string& var_name,
67+
const std::string& new_var_name,
68+
std::vector<OpFuncNode>* new_op_func_nodes) {
69+
auto& op_type = op->Type();
70+
71+
// 1. Construct RuntimeContext
72+
RuntimeContext runtime_context({}, {});
73+
runtime_context.inputs["X"] = {var_scope_->Var(var_name)};
74+
runtime_context.outputs["Out"] = {var_scope_->Var(new_var_name)};
75+
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
76+
77+
// 2. Execute infer shape and choose kernel
78+
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
79+
static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape(
80+
&infer_shape_ctx);
81+
auto kernels_iter = all_op_kernels.find(op_type);
82+
PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
83+
platform::errors::Unavailable(
84+
"There are no kernels which are registered in "
85+
"the %s operator.",
86+
op_type));
87+
OpKernelMap& kernels = kernels_iter->second;
88+
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
89+
auto* dev_ctx = pool.Get(place_);
90+
Scope scope;
91+
auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
92+
auto expected_kernel_key =
93+
dynamic_cast<const framework::OperatorWithKernel*>(op.get())
94+
->GetExpectedKernelType(exec_ctx);
95+
auto kernel_iter = kernels.find(expected_kernel_key);
96+
97+
// 3. Execute transfer op and construct OpFuncNode
98+
OpFuncNode new_op_func_node;
99+
new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)};
100+
new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)};
101+
new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
102+
new_op_func_node.kernel_func_(exec_ctx);
103+
// NOTE(Aurelius84): data_transform_op is expensive operation, so we tag them
104+
// as kQueueSync and execute them in thread pool.
105+
new_op_func_node.type_ = OpFuncType::kQueueSync;
106+
new_op_func_node.dev_ctx_ = dev_ctx;
107+
new_op_func_node.operator_base_ = op;
108+
VLOG(3) << "Run " << op_type << " done.";
109+
110+
new_op_func_nodes->emplace_back(std::move(new_op_func_node));
111+
}
112+
113+
// Creates (but does not run) a transfer_layout op converting `var_name`
// from in_layout to out_layout. A fresh destination variable named
// "<var_name>_layout_<N>" is created in `local_scope` and its name is
// returned through *new_var_name.
std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
                                             std::string* new_var_name,
                                             DataLayout in_layout,
                                             DataLayout out_layout,
                                             VariableScope* var_scope,
                                             framework::Scope* local_scope) {
  // 1. Generate new_var_name and Initialize it
  *new_var_name =
      var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1);
  // Dereference is required: passing the raw std::string* would select
  // Scope's name-generating Var(std::string*) overload, which overwrites
  // *new_var_name with an auto-generated name and discards the
  // "_layout_" name built above.
  auto* ptr = local_scope->Var(*new_var_name);

  auto var_type = var_scope->Var(var_name)->Type();
  InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
  VLOG(3) << "Create Variable " << *new_var_name
          << " locally, which pointer is " << ptr << " Variable Type "
          << var_type;
  // NOTE(review): this registers a null desc for the *source* variable;
  // presumably the freshly created *new_var_name is the variable that
  // should carry the null desc -- TODO confirm against VariableScope.
  var_scope->SetVarDesc(var_name, nullptr);

  // 2. Construct VariableNameMap
  VariableNameMap in_name_map = {{"X", {var_name}}};
  VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
  AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};

  // 3. Create transfer_op
  std::string op_type("transfer_layout");
  auto& op_info = OpInfoMap::Instance().Get(op_type);
  auto op = std::shared_ptr<OperatorBase>(
      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

  VLOG(3) << string::Sprintf("Insert %s(%s) with %s -> %s(%s).", op_type,
                             var_name, in_layout, *new_var_name, out_layout);
  return op;
}
145+
146+
// Creates (but does not run) a transfer_dtype op casting `var_name`
// from in_dtype to out_dtype. A fresh destination variable named
// "<var_name>_dtype_<N>" is created in `local_scope` and its name is
// returned through *new_var_name.
std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
                                            std::string* new_var_name,
                                            proto::VarType::Type in_dtype,
                                            proto::VarType::Type out_dtype,
                                            VariableScope* var_scope,
                                            framework::Scope* local_scope) {
  // 1. Generate new_var_name and Initialize it
  *new_var_name =
      var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
  // Dereference is required: passing the raw std::string* would select
  // Scope's name-generating Var(std::string*) overload, which overwrites
  // *new_var_name with an auto-generated name and discards the
  // "_dtype_" name built above.
  auto* ptr = local_scope->Var(*new_var_name);

  auto var_type = var_scope->Var(var_name)->Type();
  InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
  VLOG(3) << "Create Variable " << *new_var_name
          << " locally, which pointer is " << ptr << " Variable Type "
          << var_type;
  // NOTE(review): this registers a null desc for the *source* variable;
  // presumably the freshly created *new_var_name is the variable that
  // should carry the null desc -- TODO confirm against VariableScope.
  var_scope->SetVarDesc(var_name, nullptr);

  // 2. Construct VariableNameMap
  VariableNameMap in_name_map = {{"X", {var_name}}};
  VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
  AttributeMap attr_map;
  attr_map["in_dtype"] = static_cast<int>(in_dtype);
  attr_map["out_dtype"] = static_cast<int>(out_dtype);
  // NOTE(Aurelius84): In which case use_mkldnn = true?
  attr_map["use_mkldnn"] = false;

  // 3. Create transfer_op
  std::string op_type("transfer_dtype");
  auto& op_info = OpInfoMap::Instance().Get(op_type);
  auto op = std::shared_ptr<OperatorBase>(
      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

  VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", op_type,
                             var_name, DataTypeToString(in_dtype),
                             *new_var_name, DataTypeToString(out_dtype));
  return op;
}
183+
184+
// Creates (but does not run) a memcpy op (H2D or D2H, chosen by
// get_memcpy_type) moving `var_name` from src_place to dst_place. A
// fresh destination variable named "<var_name>_device_<N>" is created
// in `local_scope` and its name is returned through *new_var_name.
std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
                                             std::string* new_var_name,
                                             const platform::Place& src_place,
                                             const platform::Place& dst_place,
                                             VariableScope* var_scope,
                                             framework::Scope* local_scope) {
  // 1. Generate new_var_name and Initialize it
  *new_var_name =
      var_name + "_device_" + std::to_string(var_scope->VarSize() + 1);
  // Dereference is required: passing the raw std::string* would select
  // Scope's name-generating Var(std::string*) overload, which overwrites
  // *new_var_name with an auto-generated name and discards the
  // "_device_" name built above.
  auto* ptr = local_scope->Var(*new_var_name);

  auto var_type = var_scope->Var(var_name)->Type();
  InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
  VLOG(3) << "Create Variable " << *new_var_name
          << " locally, which pointer is " << ptr << " Variable Type "
          << var_type;
  // NOTE(review): this registers a null desc for the *source* variable;
  // presumably the freshly created *new_var_name is the variable that
  // should carry the null desc -- TODO confirm against VariableScope.
  var_scope->SetVarDesc(var_name, nullptr);

  // 2. Construct VariableNameMap
  VariableNameMap in_name_map = {{"X", {var_name}}};
  VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
  // 0 = CPU, 1 = GPU, -1 = unsupported destination.
  int dst_place_type = platform::is_cpu_place(dst_place)
                           ? 0
                           : platform::is_gpu_place(dst_place) ? 1 : -1;
  AttributeMap attr_map = {{"dst_place_type", dst_place_type}};

  // 3. Create transfer_op
  std::string op_type = get_memcpy_type(src_place, dst_place);
  auto& op_info = OpInfoMap::Instance().Get(op_type);
  auto op = std::shared_ptr<OperatorBase>(
      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

  VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", op_type,
                             var_name, src_place, *new_var_name, dst_place);
  return op;
}
219+
220+
// Walks every tensor input of op_func_node's operator and, where the
// input's kernel type differs from expected_kernel_key, inserts the
// needed transfer ops via DataTranferHelper::apply. Transferred inputs
// are rewired in-place (ins_map_temp, input_index, new_ins); untouched
// inputs are recorded in no_data_transform_index. Finally the
// OperatorBase is recreated so its input names match the rewiring.
void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                        const platform::Place& place,
                        VariableValueMap* ins_map_temp,
                        VariableScope* var_scope, OpFuncNode* op_func_node,
                        std::vector<OpFuncNode>* new_op_func_nodes,
                        bool use_local_scope) {
  auto op_base = op_func_node->operator_base_.get();
  PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
                                       "op_base is null, please pass a valid "
                                       "op_base in apply_data_transform."));

  // Mutable copy of the op's input name map; updated as vars are replaced.
  VariableNameMap new_ins(op_base->Inputs());
  // record the no need transform variable index.
  std::unordered_set<int> no_data_transform_index;

  DataTranferHelper data_transfer_helper(place, var_scope);
  for (auto& var_name_item : *ins_map_temp) {
    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
      auto var = var_name_item.second[i];
      // Only dense/selected-rows tensors can need data transform.
      if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
        continue;
      }
      auto& var_name = new_ins[var_name_item.first].at(i);
      auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      // Uninitialized tensors carry no data to transform.
      if (!tensor_in->IsInitialized()) {
        continue;
      }
      auto kernel_type_for_var =
          static_cast<const framework::OperatorWithKernel*>(op_base)
              ->GetKernelTypeForVar(var_name_item.first, *tensor_in,
                                    expected_kernel_key);
      // apply data transform
      std::string new_var_name;
      bool is_transferred = data_transfer_helper.apply(
          kernel_type_for_var, expected_kernel_key, var_name, &new_var_name,
          new_op_func_nodes, use_local_scope);

      if (is_transferred) {
        // update RuntimeContext.inputs and original op_func_node inputs
        op_func_node->input_index[var_name_item.first][i] =
            var_scope->VarId(new_var_name);
        var_name_item.second[i] = var_scope->Var(new_var_name);
        new_ins[var_name_item.first][i] = new_var_name;
        // NOTE(Aurelius84): avoid deepcopy twice if we already insert data
        // transfer op.
        if (op_base->Type() == "fetch_v2") {
          op_base->SetAttr("deepcopy", false);
        }
      } else {
        // record no need data transformer input var_id
        VLOG(3) << op_base->Type()
                << " found no data_transform var: " << var_name
                << " with id: " << var_scope->VarId(var_name);
        no_data_transform_index.emplace(var_scope->VarId(var_name));
      }
    }
  }

  // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
  // with instruction. (hot fix, it is not good design here)
  op_func_node->operator_base_ =
      std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
          op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
  op_func_node->no_data_transform_index = std::move(no_data_transform_index);
}
285+
286+
// Picks the memcpy op type for a cross-place copy: host-to-device when
// the destination is a GPU place, device-to-host when the source is a
// GPU place. Requires src_place != dst_place; any other combination
// (e.g. CPU -> CPU of different kinds, or unsupported devices) throws.
std::string get_memcpy_type(const platform::Place& src_place,
                            const platform::Place& dst_place) {
  PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false,
                    platform::errors::PreconditionNotMet(
                        "Required src_place shall be different with dst_place, "
                        "but received same place: %s",
                        src_place));
  if (platform::is_gpu_place(dst_place)) {
    return kMemcpyH2D;
  } else if (platform::is_gpu_place(src_place)) {
    return kMemcpyD2H;
  } else {
    // Fixed error-message typo ("typ :" -> "type:").
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "Not support Memcpy type: %s -> %s", src_place, dst_place));
  }
}
302+
303+
} // namespace interpreter
304+
} // namespace framework
305+
} // namespace paddle

0 commit comments

Comments
 (0)