|
| 1 | +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/fluid/framework/new_executor/data_transfer.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace framework { |
| 19 | +namespace interpreter { |
| 20 | + |
| 21 | +bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, |
| 22 | + const OpKernelType& expected_kernel_key, |
| 23 | + const std::string& var_name, |
| 24 | + std::string* new_var_name, |
| 25 | + std::vector<OpFuncNode>* op_func_nodes, |
| 26 | + bool use_local_scope) { |
| 27 | + bool is_transferred = false; |
| 28 | + auto* src_var_name = &var_name; |
| 29 | + |
| 30 | + Scope* local_scope = use_local_scope ? var_scope_->GetMutableLocalScope() |
| 31 | + : var_scope_->GetMutableScope(); |
| 32 | + |
| 33 | + // 1. layout transform |
| 34 | + if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) { |
| 35 | + auto op = TransferLayout( |
| 36 | + *src_var_name, new_var_name, kernel_type_for_var.data_layout_, |
| 37 | + expected_kernel_key.data_layout_, var_scope_, local_scope); |
| 38 | + RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes); |
| 39 | + // update src_var_name |
| 40 | + src_var_name = new_var_name; |
| 41 | + is_transferred = true; |
| 42 | + } |
| 43 | + // 2. dype transform |
| 44 | + if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) { |
| 45 | + auto op = TransferDtype( |
| 46 | + *src_var_name, new_var_name, kernel_type_for_var.data_type_, |
| 47 | + expected_kernel_key.data_type_, var_scope_, local_scope); |
| 48 | + RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes); |
| 49 | + // update src_var_name |
| 50 | + src_var_name = new_var_name; |
| 51 | + is_transferred = true; |
| 52 | + } |
| 53 | + // 3. device transform |
| 54 | + if (need_device_transform(kernel_type_for_var, expected_kernel_key)) { |
| 55 | + auto src_place = kernel_type_for_var.place_; |
| 56 | + auto dst_place = expected_kernel_key.place_; |
| 57 | + auto op = TransferDevice(*src_var_name, new_var_name, src_place, dst_place, |
| 58 | + var_scope_, local_scope); |
| 59 | + RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes); |
| 60 | + is_transferred = true; |
| 61 | + } |
| 62 | + return is_transferred; |
| 63 | +} |
| 64 | + |
| 65 | +void DataTranferHelper::RunAndConstructOpFuncNode( |
| 66 | + const std::shared_ptr<OperatorBase>& op, const std::string& var_name, |
| 67 | + const std::string& new_var_name, |
| 68 | + std::vector<OpFuncNode>* new_op_func_nodes) { |
| 69 | + auto& op_type = op->Type(); |
| 70 | + |
| 71 | + // 1. Construct RuntimeContext |
| 72 | + RuntimeContext runtime_context({}, {}); |
| 73 | + runtime_context.inputs["X"] = {var_scope_->Var(var_name)}; |
| 74 | + runtime_context.outputs["Out"] = {var_scope_->Var(new_var_name)}; |
| 75 | + InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); |
| 76 | + |
| 77 | + // 2. Execute infer shape and choose kernel |
| 78 | + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); |
| 79 | + static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape( |
| 80 | + &infer_shape_ctx); |
| 81 | + auto kernels_iter = all_op_kernels.find(op_type); |
| 82 | + PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(), |
| 83 | + platform::errors::Unavailable( |
| 84 | + "There are no kernels which are registered in " |
| 85 | + "the %s operator.", |
| 86 | + op_type)); |
| 87 | + OpKernelMap& kernels = kernels_iter->second; |
| 88 | + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); |
| 89 | + auto* dev_ctx = pool.Get(place_); |
| 90 | + Scope scope; |
| 91 | + auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context); |
| 92 | + auto expected_kernel_key = |
| 93 | + dynamic_cast<const framework::OperatorWithKernel*>(op.get()) |
| 94 | + ->GetExpectedKernelType(exec_ctx); |
| 95 | + auto kernel_iter = kernels.find(expected_kernel_key); |
| 96 | + |
| 97 | + // 3. Execute transfer op and construct OpFuncNode |
| 98 | + OpFuncNode new_op_func_node; |
| 99 | + new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)}; |
| 100 | + new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)}; |
| 101 | + new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); |
| 102 | + new_op_func_node.kernel_func_(exec_ctx); |
| 103 | + // NOTE(Aurelius84): data_transform_op is expensive operation, so we tag them |
| 104 | + // as kQueueSync and execute them in thread pool. |
| 105 | + new_op_func_node.type_ = OpFuncType::kQueueSync; |
| 106 | + new_op_func_node.dev_ctx_ = dev_ctx; |
| 107 | + new_op_func_node.operator_base_ = op; |
| 108 | + VLOG(3) << "Run " << op_type << " done."; |
| 109 | + |
| 110 | + new_op_func_nodes->emplace_back(std::move(new_op_func_node)); |
| 111 | +} |
| 112 | + |
// Builds (but does not run) a `transfer_layout` op converting `var_name` from
// `in_layout` to `out_layout`. A destination variable is created in
// `local_scope` and its name is written to `*new_var_name`.
std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
                                             std::string* new_var_name,
                                             DataLayout in_layout,
                                             DataLayout out_layout,
                                             VariableScope* var_scope,
                                             framework::Scope* local_scope) {
  // 1. Generate new_var_name and Initialize it
  *new_var_name =
      var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1);
  // NOTE(review): a std::string* is passed here — if Scope::Var(std::string*)
  // is the uid-generating overload, the freshly built *new_var_name gets
  // overwritten with a generated name; confirm the intent was
  // Var(*new_var_name).
  auto* ptr = local_scope->Var(new_var_name);

  // The new variable mirrors the source variable's type (LoDTensor, etc.).
  auto var_type = var_scope->Var(var_name)->Type();
  InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
  // NOTE(review): the variable being created is *new_var_name, yet this logs
  // the source var_name; also a separating space before "Variable Type" is
  // missing, so the pointer and text run together.
  VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
          << ptr << "Variable Type " << var_type;
  // NOTE(review): this clears the VarDesc of the *source* variable; presumably
  // it should register the new variable instead — confirm against
  // VariableScope::SetVarDesc semantics.
  var_scope->SetVarDesc(var_name, nullptr);

  // 2. Construct VariableNameMap
  VariableNameMap in_name_map = {{"X", {var_name}}};
  VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
  AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};

  // 3. Create transfer_op via the registered creator; the shared_ptr owns it.
  std::string op_type("transfer_layout");
  auto& op_info = OpInfoMap::Instance().Get(op_type);
  auto op = std::shared_ptr<OperatorBase>(
      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

  VLOG(3) << string::Sprintf("Insert %s(%s) with %s -> %s(%s).", op_type,
                             var_name, in_layout, *new_var_name, out_layout);
  return op;
}
| 145 | + |
| 146 | +std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name, |
| 147 | + std::string* new_var_name, |
| 148 | + proto::VarType::Type in_dtype, |
| 149 | + proto::VarType::Type out_dtype, |
| 150 | + VariableScope* var_scope, |
| 151 | + framework::Scope* local_scope) { |
| 152 | + // 1. Generate new_var_name and Initialize it |
| 153 | + *new_var_name = |
| 154 | + var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1); |
| 155 | + auto* ptr = local_scope->Var(new_var_name); |
| 156 | + |
| 157 | + auto var_type = var_scope->Var(var_name)->Type(); |
| 158 | + InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); |
| 159 | + VLOG(3) << "Create Variable " << var_name << " locally, which pointer is " |
| 160 | + << ptr << "Variable Type " << var_type; |
| 161 | + var_scope->SetVarDesc(var_name, nullptr); |
| 162 | + |
| 163 | + // 2. Construct VariableNameMap |
| 164 | + VariableNameMap in_name_map = {{"X", {var_name}}}; |
| 165 | + VariableNameMap out_name_map = {{"Out", {*new_var_name}}}; |
| 166 | + AttributeMap attr_map; |
| 167 | + attr_map["in_dtype"] = static_cast<int>(in_dtype); |
| 168 | + attr_map["out_dtype"] = static_cast<int>(out_dtype); |
| 169 | + // NOTE(Aurelius84): In whice case use_mkldnn = true? |
| 170 | + attr_map["use_mkldnn"] = false; |
| 171 | + |
| 172 | + // 3. Create transfer_op |
| 173 | + std::string op_type("transfer_dtype"); |
| 174 | + auto& op_info = OpInfoMap::Instance().Get(op_type); |
| 175 | + auto op = std::shared_ptr<OperatorBase>( |
| 176 | + op_info.Creator()(op_type, in_name_map, out_name_map, attr_map)); |
| 177 | + |
| 178 | + VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", op_type, |
| 179 | + var_name, DataTypeToString(in_dtype), |
| 180 | + *new_var_name, DataTypeToString(out_dtype)); |
| 181 | + return op; |
| 182 | +} |
| 183 | + |
| 184 | +std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name, |
| 185 | + std::string* new_var_name, |
| 186 | + const platform::Place& src_place, |
| 187 | + const platform::Place& dst_place, |
| 188 | + VariableScope* var_scope, |
| 189 | + framework::Scope* local_scope) { |
| 190 | + // 1. Generate new_var_name and Initialize it |
| 191 | + *new_var_name = |
| 192 | + var_name + "_device_" + std::to_string(var_scope->VarSize() + 1); |
| 193 | + auto* ptr = local_scope->Var(new_var_name); |
| 194 | + |
| 195 | + auto var_type = var_scope->Var(var_name)->Type(); |
| 196 | + InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); |
| 197 | + VLOG(3) << "Create Variable " << var_name << " locally, which pointer is " |
| 198 | + << ptr << "Variable Type " << var_type; |
| 199 | + var_scope->SetVarDesc(var_name, nullptr); |
| 200 | + |
| 201 | + // 2. Construct VariableNameMap |
| 202 | + VariableNameMap in_name_map = {{"X", {var_name}}}; |
| 203 | + VariableNameMap out_name_map = {{"Out", {*new_var_name}}}; |
| 204 | + int dst_place_type = platform::is_cpu_place(dst_place) |
| 205 | + ? 0 |
| 206 | + : platform::is_gpu_place(dst_place) ? 1 : -1; |
| 207 | + AttributeMap attr_map = {{"dst_place_type", dst_place_type}}; |
| 208 | + |
| 209 | + // 3. Create transfer_op |
| 210 | + std::string op_type = get_memcpy_type(src_place, dst_place); |
| 211 | + auto& op_info = OpInfoMap::Instance().Get(op_type); |
| 212 | + auto op = std::shared_ptr<OperatorBase>( |
| 213 | + op_info.Creator()(op_type, in_name_map, out_name_map, attr_map)); |
| 214 | + |
| 215 | + VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", op_type, |
| 216 | + var_name, src_place, *new_var_name, dst_place); |
| 217 | + return op; |
| 218 | +} |
| 219 | + |
// Scans every tensor input of `op_func_node` and, where the variable's actual
// kernel type mismatches `expected_kernel_key`, inserts (and immediately runs)
// transfer ops via DataTranferHelper. Inputs that need no transfer have their
// VarIds recorded in op_func_node->no_data_transform_index. Finally the
// node's OperatorBase is rebuilt so its input names match the transferred
// variables.
void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                        const platform::Place& place,
                        VariableValueMap* ins_map_temp,
                        VariableScope* var_scope, OpFuncNode* op_func_node,
                        std::vector<OpFuncNode>* new_op_func_nodes,
                        bool use_local_scope) {
  auto op_base = op_func_node->operator_base_.get();
  PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
                                       "op_base is null, please pass a valid "
                                       "op_base in apply_data_transform."));

  // Working copy of the op's input name map; entries are swapped to the
  // transferred variable names as transfers are inserted.
  VariableNameMap new_ins(op_base->Inputs());
  // record the no need transform variable index.
  std::unordered_set<int> no_data_transform_index;

  DataTranferHelper data_transfer_helper(place, var_scope);
  for (auto& var_name_item : *ins_map_temp) {
    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
      auto var = var_name_item.second[i];
      // Only dense / selected-rows tensors carry layout/dtype/place to adapt.
      if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
        continue;
      }
      auto& var_name = new_ins[var_name_item.first].at(i);
      auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      // Uninitialized tensors have no meaningful kernel type yet; skip.
      if (!tensor_in->IsInitialized()) {
        continue;
      }
      // Ask the operator what kernel type this particular input will use.
      auto kernel_type_for_var =
          static_cast<const framework::OperatorWithKernel*>(op_base)
              ->GetKernelTypeForVar(var_name_item.first, *tensor_in,
                                    expected_kernel_key);
      // apply data transform
      std::string new_var_name;
      bool is_transferred = data_transfer_helper.apply(
          kernel_type_for_var, expected_kernel_key, var_name, &new_var_name,
          new_op_func_nodes, use_local_scope);

      if (is_transferred) {
        // update RuntimeContext.inputs and original op_func_node inputs
        op_func_node->input_index[var_name_item.first][i] =
            var_scope->VarId(new_var_name);
        var_name_item.second[i] = var_scope->Var(new_var_name);
        new_ins[var_name_item.first][i] = new_var_name;
        // NOTE(Aurelius84): avoid deepcopy twice if we already insert data
        // transfer op.
        if (op_base->Type() == "fetch_v2") {
          op_base->SetAttr("deepcopy", false);
        }
      } else {
        // record no need data transformer input var_id
        VLOG(3) << op_base->Type()
                << " found no data_transform var: " << var_name
                << " with id: " << var_scope->VarId(var_name);
        no_data_transform_index.emplace(var_scope->VarId(var_name));
      }
    }
  }

  // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
  // with instruction. (hot fix, it is not good design here)
  op_func_node->operator_base_ =
      std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
          op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
  op_func_node->no_data_transform_index = std::move(no_data_transform_index);
}
| 285 | + |
| 286 | +std::string get_memcpy_type(const platform::Place& src_place, |
| 287 | + const platform::Place& dst_place) { |
| 288 | + PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false, |
| 289 | + platform::errors::PreconditionNotMet( |
| 290 | + "Required src_place shall be different with dst_place, " |
| 291 | + "but received same place: %s", |
| 292 | + src_place)); |
| 293 | + if (platform::is_gpu_place(dst_place)) { |
| 294 | + return kMemcpyH2D; |
| 295 | + } else if (platform::is_gpu_place(src_place)) { |
| 296 | + return kMemcpyD2H; |
| 297 | + } else { |
| 298 | + PADDLE_THROW(platform::errors::PreconditionNotMet( |
| 299 | + "Not support Memcpy typ : %s -> %s", src_place, dst_place)); |
| 300 | + } |
| 301 | +} |
| 302 | + |
| 303 | +} // namespace interpreter |
| 304 | +} // namespace framework |
| 305 | +} // namespace paddle |
0 commit comments