diff --git a/csrc/xpu/adam/multi_tensor_apply.dp.hpp b/csrc/xpu/adam/multi_tensor_apply.dp.hpp index 04674020a6c5..01964e1a4d38 100644 --- a/csrc/xpu/adam/multi_tensor_apply.dp.hpp +++ b/csrc/xpu/adam/multi_tensor_apply.dp.hpp @@ -10,6 +10,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85 #include #include +#include #include #include #include "compat.h" @@ -22,10 +23,8 @@ namespace at { namespace cuda { sycl::queue* getCurrentCUDAStream() { - auto device_type = c10::DeviceType::XPU; - c10::impl::VirtualGuardImpl impl(device_type); - c10::Stream c10_stream = impl.getStream(c10::Device(device_type)); - auto& queue = xpu::get_queue_from_stream(c10_stream); + c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream(); + auto& queue = stream.queue(); return &queue; }