Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)

cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
Expand Down
107 changes: 42 additions & 65 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,95 +13,72 @@
// limitations under the License.

#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
namespace details {

Tensor *GetTensorFromVar(Variable *in_var) {
if (in_var->IsType<LoDTensor>()) {
return in_var->GetMutable<LoDTensor>();
} else if (in_var->IsType<SelectedRows>()) {
return in_var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
return nullptr;
}

BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}

void BroadcastOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");
// the input and output may have dummy var.
VarHandle *in_var_handle;

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}

auto out_var_handles = DynamicCast<VarHandle>(outputs_);

PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated(*in_var_handle);

//
auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);
auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
->FindVar(in_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

for (auto *out : out_var_handles) {
if (*out == *in_var_handle) {
continue;
}

auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
"Places must be all on CPU or all on CUDA.");

if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>();
auto out_sr = out_var->GetMutable<framework::SelectedRows>();
if (&in_sr == out_sr) continue;
out_sr->set_height(in_sr.height());
out_sr->set_rows(in_sr.rows());
out_sr->mutable_value()->Resize(in_sr.value().dims());
out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto in_lod = in_var->Get<framework::LoDTensor>();
auto out_lod = out_var->GetMutable<framework::LoDTensor>();
if (&in_lod == out_lod) continue;
out_lod->set_lod(in_lod.lod());
out_lod->Resize(in_lod.dims());
out_lod->mutable_data(out_p, in_lod.type());
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims())
.mutable_data(out_p, in_tensor.type());

Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
auto dev_ctx = dev_ctxes_[out_p];
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
paddle::framework::TensorCopy(
in_tensor, out_p, *(dev_ctx),
&VariableVisitor::GetMutableTensor(out_var));
});
}
}

void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/broadcast_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {

protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);

private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};

} // namespace details
} // namespace framework
} // namespace paddle
40 changes: 40 additions & 0 deletions paddle/fluid/framework/details/container_cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include <vector>

namespace paddle {
namespace framework {
namespace details {

template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
static_assert(std::is_base_of<ElemType, ResultType>::value,
"ElementType must be a base class of ResultType");
std::vector<ResultType*> res;
for (auto* ptr : container) {
auto* derived = dynamic_cast<ResultType*>(ptr);
if (derived) {
res.emplace_back(derived);
}
}
return res;
}

} // namespace details
} // namespace framework
} // namespace paddle
83 changes: 43 additions & 40 deletions paddle/fluid/framework/details/gather_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#include "paddle/fluid/framework/details/gather_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
Expand All @@ -23,81 +25,68 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}

void GatherOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs_) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
// the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_);

PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
VarHandle *out_var_handle;
{
auto out_var_handles = DynamicCast<VarHandle>(outputs_);

PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
out_var_handle = out_var_handles.front();
}
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");

auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
auto in_0_handle = in_var_handles[0];
auto pre_in_var =
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
auto pre_place = in_0_handle->place_;

PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows.");

PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
"The place of input and output should be the same.");

// Wait input done, this Wait is asynchronous operation
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
WaitInputVarGenerated(in_var_handles);

std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors;
std::vector<platform::Place> in_places;

auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
// gather the inputs
for (auto *in : in_var_handles) {
auto in_handle = static_cast<VarHandle *>(in);
for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
in_places.push_back(in_p);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
auto in_var =
auto *in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();

PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
"The type of input is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
"The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent.");

auto in_sr_rows = in_sr.rows();
auto &in_sr_rows = in_sr.rows();
out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());

in_tensors.emplace_back(in_sr.value());
}

// write the output
auto &out_place = out_var_handles[0]->place_;
auto out_scope_idx = out_var_handles[0]->scope_idx_;
auto out_var =
local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
auto &out_place = out_var_handle->place_;
auto out_scope_idx = out_var_handle->scope_idx_;
auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);

auto out = out_var->GetMutable<framework::SelectedRows>();
out->set_height(pre_in.height());
Expand All @@ -110,13 +99,27 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out->mutable_value();

// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place,
*(dev_ctxes_[in_places[j]]), &sub_out);
s = e;
auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
&sub_out);
s = e;
}
});
}

void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/details/gather_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct GatherOpHandle : public OpHandleBase {

protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);

private:
const std::vector<Scope *> &local_scopes_;
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/details/var_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ struct VarHandle : public VarHandleBase {
size_t scope_idx_;
std::string name_;
platform::Place place_;

bool operator==(const VarHandle& o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_;
}
};

// Dummy Variable. It is used to represent dependencies between operators
Expand Down
Loading