Skip to content

Commit

Permalink
update xpu fusedadam opbuilder for pytorch 2.3 (#5702)
Browse files Browse the repository at this point in the history
update the way to get queue for FusedAdam OpBuilder.

---------

Signed-off-by: baodii <di.bao@intel.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 1, 2024
1 parent df58a78 commit e392296
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions csrc/xpu/adam/multi_tensor_apply.dp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <c10/xpu/XPUStream.h>
#include <ipex.h>
#include <sycl/sycl.hpp>
#include "compat.h"
Expand All @@ -22,10 +23,8 @@ namespace at {
namespace cuda {
sycl::queue* getCurrentCUDAStream()
{
auto device_type = c10::DeviceType::XPU;
c10::impl::VirtualGuardImpl impl(device_type);
c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
auto& queue = xpu::get_queue_from_stream(c10_stream);
c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream();
auto& queue = stream.queue();
return &queue;
}

Expand Down

0 comments on commit e392296

Please sign in to comment.