@@ -53,25 +53,25 @@ void BroadcastOpHandle::RunImpl() {
5353
5454 Tensor &in_tensor = VariableVisitor::GetMutableTensor (in_var);
5555
56- // NOTE(zcd): the Place of input can get from in_tensor and in_var_handle ,
57- // maybe they are different, because the Place that getting from in_tensor is
58- // determined at runtime, the other is determined at building SSA graph stage.
59- // If they are different, DataTransform should be applied. Currently, it has
60- // not been done yet.
56+ // NOTE: The tensors' Place of input and output must be all on GPU or all on
57+ // CPU.
6158 for (auto *out_var_handle : out_var_handles) {
6259 if (*out_var_handle == *in_var_handle) {
6360 continue ;
6461 }
65- auto &out_p = out_var_handle->place_ ;
62+ auto t_out_p = out_var_handle->place_ ;
6663 auto *out_var = var_scopes.at (out_var_handle->scope_idx_ )
6764 ->FindVar (out_var_handle->name_ );
6865 PADDLE_ENFORCE_NOT_NULL (out_var);
69- PADDLE_ENFORCE_EQ (
70- out_p.which (), in_tensor.place ().which (),
71- " Currently, Places of input and output must be all on CPU "
72- " or all on GPU." );
66+ if (platform::is_gpu_place (in_tensor.place ())) {
67+ PADDLE_ENFORCE (
68+ platform::is_gpu_place (t_out_p),
69+ " Currently, Places of input and output must be all on GPU." );
70+ } else {
71+ t_out_p = platform::CPUPlace ();
72+ }
7373 VariableVisitor::ShareDimsAndLoD (*in_var, out_var);
74- VariableVisitor::GetMutableTensor (out_var).mutable_data (out_p ,
74+ VariableVisitor::GetMutableTensor (out_var).mutable_data (t_out_p ,
7575 in_tensor.type ());
7676 }
7777
@@ -80,15 +80,13 @@ void BroadcastOpHandle::RunImpl() {
8080 if (*out_var_handle == *in_var_handle) {
8181 continue ;
8282 }
83-
8483 auto &out_p = out_var_handle->place_ ;
85- auto dev_ctx = dev_ctxes_.at (out_p);
8684 auto *out_var = var_scopes.at (out_var_handle->scope_idx_ )
8785 ->FindVar (out_var_handle->name_ );
8886
89- RunAndRecordEvent (out_p, [in_tensor, out_var, dev_ctx, out_p ] {
87+ RunAndRecordEvent (out_p, [in_tensor, out_var] {
9088 paddle::framework::TensorCopy (
91- in_tensor, out_p, *dev_ctx ,
89+ in_tensor, platform::CPUPlace () ,
9290 &VariableVisitor::GetMutableTensor (out_var));
9391 });
9492 }
0 commit comments