We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2832e7b · commit 5265631 · Copy full SHA for 5265631
csrc/cache_kernels.cu
@@ -34,7 +34,7 @@ void swap_blocks(
34
char *dst_ptr = static_cast<char*>(dst.data_ptr());
35
36
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
37
- const at::cuda::OptionalCUDAGuard device_guard(src_device);
+ const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
38
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
39
// NOTE(woosuk): This can be slow if the number of blocks is large.
40
for (const auto& pair : block_mapping) {
0 commit comments