@@ -13,23 +13,13 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
-
-Tensor *GetTensorFromVar(Variable *in_var) {
-  if (in_var->IsType<LoDTensor>()) {
-    return in_var->GetMutable<LoDTensor>();
-  } else if (in_var->IsType<SelectedRows>()) {
-    return in_var->GetMutable<SelectedRows>()->mutable_value();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-  return nullptr;
-}
-
 BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}
 
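The local GetTensorFromVar helper deleted above is superseded by VariableVisitor::GetMutableTensor from the newly included variable_visitor.h. Judging only from the deleted branches, the visitor presumably covers something like the following sketch (the committed implementation lives in variable_visitor.cc and may differ):

    // Sketch, not the committed code: a type-dispatching accessor that
    // mirrors the deleted helper but returns a reference, so there is no
    // nullable-pointer path and no dead "return nullptr;" after the throw.
    Tensor &VariableVisitor::GetMutableTensor(Variable *in_var) {
      if (in_var->IsType<LoDTensor>()) {
        return *in_var->GetMutable<LoDTensor>();
      } else if (in_var->IsType<SelectedRows>()) {
        return *in_var->GetMutable<SelectedRows>()->mutable_value();
      }
      PADDLE_THROW("Var should be LoDTensor or SelectedRows");
    }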
@@ -36,11 +26,17 @@
 void BroadcastOpHandle::RunImpl() {
   // the input and output may have dummy var.
-  std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
-  std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
+  VarHandle *in_var_handle;
+
+  {
+    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
+                      "The number of input should be one.");
+    in_var_handle = in_var_handles[0];
+  }
+
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
 
-  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
-                    "The number of input should be one.");
   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
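GetValidVarHandles, deleted at the bottom of this diff, was a hand-rolled filter private to BroadcastOpHandle; DynamicCast<VarHandle> from the newly included container_cast.h generalizes it. A plausible shape, inferred from the deleted loop (an assumption; the actual template is in container_cast.h):

    // Sketch: keep only the container elements that dynamic_cast to ResultType.
    template <typename ResultType, typename ElemType>
    std::vector<ResultType *> DynamicCast(
        const std::vector<ElemType *> &container) {
      std::vector<ResultType *> res;
      for (auto *in : container) {
        auto *casted = dynamic_cast<ResultType *>(in);
        if (casted) {
          res.push_back(casted);
        }
      }
      return res;
    }

This is also what the "dummy var" comment refers to: dummy handles fail the dynamic_cast to VarHandle and are silently dropped.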
@@ -47,10 +43,9 @@
-  // Wait input done, this Wait is asynchronous operationplatform::Place
+  // Wait input done, this Wait is asynchronous operation platform::Place
   // &in_place;
-  WaitEvents(out_var_handles, in_var_handle);
+  WaitInputVarGenerated(*in_var_handle);
 
-  auto in_place = in_var_handle[0]->place_;
-  auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  auto in_var =
-      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
-  Tensor *in_tensor = GetTensorFromVar(in_var);
+  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
+                     ->FindVar(in_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(in_var);
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
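For orientation, these are the VarHandle members the new lookup relies on. A simplified view, not the real declaration (that lives in var_handle.h):

    // Simplified: only the fields BroadcastOpHandle touches.
    struct VarHandle : VarHandleBase {
      std::string name_;            // variable name, looked up in each scope
      size_t scope_idx_;            // index into local_scopes_
      platform::Place place_;       // device that holds this copy
      OpHandleBase *generated_op_;  // op that produces the variable, if any
    };

The added PADDLE_ENFORCE_NOT_NULL catches a name that is missing from the scope, which the old code would only have surfaced as a crash inside GetTensorFromVar.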
@@ -57,7 +52,11 @@
   for (auto *out : out_var_handles) {
+    if (*out == *in_var_handle) {
+      continue;
+    }
+
     auto &out_p = out->place_;
-    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
 
-    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");
 
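The single *out == *in_var_handle guard replaces the two per-type self-copy checks deleted just below (&in_sr == out_sr and &in_lod == out_lod), and it now runs before the output variable is even looked up. It relies on VarHandle equality meaning "same variable in the same scope"; presumably something like this sketch (an assumption; the actual operator is declared in var_handle.h):

    // Assumed shape of VarHandle equality.
    bool operator==(const VarHandle &a, const VarHandle &b) {
      return a.name_ == b.name_ && a.scope_idx_ == b.scope_idx_;
    }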
@@ -64,19 +63,4 @@
-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto &in_sr = in_var->Get<framework::SelectedRows>();
-      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
-      if (&in_sr == out_sr) continue;
-      out_sr->set_height(in_sr.height());
-      out_sr->set_rows(in_sr.rows());
-      out_sr->mutable_value()->Resize(in_sr.value().dims());
-      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
-    } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->Get<framework::LoDTensor>();
-      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
-      if (&in_lod == out_lod) continue;
-      out_lod->set_lod(in_lod.lod());
-      out_lod->Resize(in_lod.dims());
-      out_lod->mutable_data(out_p, in_lod.type());
-    } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
-    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+                                                            in_tensor.type());
 
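The deleted branching is precisely the behavior VariableVisitor::ShareDimsAndLoD must now provide, with the allocation (mutable_data) kept at the call site, where the target place is known. Reconstructed from the deleted lines as a sketch (the committed implementation is in variable_visitor.cc and may differ):

    // Sketch: copy shape/metadata from src to trg for both supported types;
    // allocation is left to the caller, which knows the target place.
    void VariableVisitor::ShareDimsAndLoD(const Variable &src, Variable *trg) {
      if (src.IsType<SelectedRows>()) {
        auto &in_sr = src.Get<SelectedRows>();
        auto *out_sr = trg->GetMutable<SelectedRows>();
        out_sr->set_height(in_sr.height());
        out_sr->set_rows(in_sr.rows());
        out_sr->mutable_value()->Resize(in_sr.value().dims());
      } else if (src.IsType<LoDTensor>()) {
        auto &in_lod = src.Get<LoDTensor>();
        auto *out_lod = trg->GetMutable<LoDTensor>();
        out_lod->set_lod(in_lod.lod());
        out_lod->Resize(in_lod.dims());
      } else {
        PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
      }
    }

Incidentally, the refactor removes a latent bug in the deleted code: auto in_lod = in_var->Get<framework::LoDTensor>() made a copy rather than binding a reference, so the follow-up &in_lod == out_lod self-copy check could never be true.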
@@ -83,8 +67,9 @@
     auto dev_ctx = dev_ctxes_[out_p];
     RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
-      Tensor *out_tensor = GetTensorFromVar(out_var);
-      paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor);
+      paddle::framework::TensorCopy(
+          in_tensor, out_p, *(dev_ctx),
+          &VariableVisitor::GetMutableTensor(out_var));
     });
   }
 }
 
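One subtlety in the capture list: in_tensor is now a Tensor&, and when a reference is captured by value the closure copies the referent, not the reference. So [in_tensor, ...] stores the closure's own Tensor object (my understanding is that a Tensor copy here is shallow and shares the buffer holder, so the asynchronous copy still reads the broadcast source). Plain C++ shows the mechanics:

    #include <cassert>

    // Capturing a reference by value copies the referent into the closure.
    int main() {
      int x = 1;
      int &rx = x;
      auto f = [rx] { return rx; };  // the closure stores its own int
      x = 2;
      assert(f() == 1);  // still the value from capture time
      return 0;
    }

That copy is what keeps the tensor metadata valid inside RunAndRecordEvent even after the enclosing scope has unwound.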
@@ -91,24 +76,7 @@
-void BroadcastOpHandle::WaitEvents(
-    const std::vector<VarHandle *> &out_var_handles,
-    const std::vector<VarHandle *> &in_var_handle) {
-  if (in_var_handle[0]->generated_op_) {
-    for (auto *out : out_var_handles) {
-      auto &out_p = out->place_;
-      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
-    }
-  }
-}
-
-std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
-    const std::vector<VarHandleBase *> &inputs) {
-  std::vector<VarHandle *> in_var_handle;
-  for (auto *in : inputs) {
-    auto *out_handle = dynamic_cast<VarHandle *>(in);
-    if (out_handle) {
-      in_var_handle.push_back(out_handle);
-    }
+void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  for (auto &pair : dev_ctxes_) {
+    in_var.generated_op_->Wait(pair.second);
   }
-  return in_var_handle;
 }
 
 std::string BroadcastOpHandle::Name() const { return "broadcast"; }
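Two behavioral changes in the wait path deserve a closer look than the diff gives them. The deleted WaitEvents waited only on the device contexts of the output places, and skipped waiting entirely when the input had no generating op; the new WaitInputVarGenerated waits on every context in dev_ctxes_ and dereferences generated_op_ unconditionally. If inputs without a generating op can still occur (the deleted null check suggests they once did), a defensive variant would look like this sketch (not the committed code):

    void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
      // An input produced by no upstream op needs no wait.
      if (in_var.generated_op_ == nullptr) {
        return;
      }
      for (auto &pair : dev_ctxes_) {
        in_var.generated_op_->Wait(pair.second);
      }
    }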