@@ -200,10 +200,9 @@ void Copy(const Context& dev_ctx,
200200 paddle::memory::Copy (
201201 dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
202202#endif
203- }
204203#ifdef PADDLE_WITH_XPU
205- else if (paddle::platform::is_xpu_place (src_place) && // NOLINT
206- paddle::platform::is_cpu_place (dst_place)) {
204+ } else if (paddle::platform::is_xpu_place (src_place) && // NOLINT
205+ paddle::platform::is_cpu_place (dst_place)) {
207206 paddle::memory::Copy (dst_place, dst_ptr, src_place, src_ptr, size);
208207 } else if (paddle::platform::is_cpu_place (src_place) &&
209208 paddle::platform::is_xpu_place (dst_place)) {
@@ -216,11 +215,40 @@ void Copy(const Context& dev_ctx,
216215 return ;
217216 }
218217 paddle::memory::Copy (dst_place, dst_ptr, src_place, src_ptr, size);
218+ #endif
219+ #ifdef PADDLE_WITH_CUSTOM_DEVICE
220+ } else if (paddle::platform::is_custom_place (src_place) && // NOLINT
221+ paddle::platform::is_cpu_place (dst_place)) {
222+ auto stream =
223+ blocking
224+ ? nullptr
225+ : reinterpret_cast <const paddle::platform::CustomDeviceContext&>(
226+ dev_ctx)
227+ .stream ();
228+ paddle::memory::Copy (dst_place, dst_ptr, src_place, src_ptr, size, stream);
229+ } else if (paddle::platform::is_cpu_place (src_place) && // NOLINT
230+ paddle::platform::is_custom_place (dst_place)) {
231+ auto stream =
232+ blocking
233+ ? nullptr
234+ : reinterpret_cast <const paddle::platform::CustomDeviceContext&>(
235+ dev_ctx)
236+ .stream ();
237+ paddle::memory::Copy (dst_place, dst_ptr, src_place, src_ptr, size, stream);
238+ } else if (paddle::platform::is_custom_place (src_place) && // NOLINT
239+ paddle::platform::is_custom_place (dst_place)) {
240+ auto stream =
241+ blocking
242+ ? nullptr
243+ : reinterpret_cast <const paddle::platform::CustomDeviceContext&>(
244+ dev_ctx)
245+ .stream ();
246+ paddle::memory::Copy (dst_place, dst_ptr, src_place, src_ptr, size, stream);
247+ #endif
219248 } else {
220249 PADDLE_THROW (phi::errors::Unimplemented (
221250 " Copy from %s to %s is not supported." , src_place, dst_place));
222251 }
223- #endif
224252}
225253
226254template <typename Context>
@@ -363,4 +391,11 @@ template void Copy(const XPUContext& dev_ctx,
363391 DenseTensor* dst);
364392#endif
365393
394+ #ifdef PADDLE_WITH_CUSTOM_DEVICE
395+ template void Copy (const CustomContext& dev_ctx,
396+ const DenseTensor& src,
397+ Place dst_place,
398+ bool blocking,
399+ DenseTensor* dst);
400+ #endif
366401} // namespace phi
0 commit comments