Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 86190b6

Browse files
committed
fix(deferinit): support torch2.2 deferinit with defer_device
1 parent bcaaffb commit 86190b6

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/python/torchdistx/_C/fake.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
#include <ATen/Context.h>
1010
#include <ATen/Tensor.h>
11+
#include <torch/torch.h>
12+
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR>=3
13+
#include <torch/csrc/utils/device_lazy_init.h>
14+
#else
1115
#include <torch/csrc/utils/cuda_lazy_init.h>
16+
#endif
1217
#include <torch/csrc/utils/pybind.h>
1318
#include <torchdistx/fake.h>
1419

@@ -22,7 +27,12 @@ void pyEnterFakeMode(bool fake_cuda) {
2227
// subsystem which would fail and prevent us from instantiating CUDA devices.
2328
if (fake_cuda) {
2429
if (!at::hasCUDA()) {
30+
31+
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR>=3
32+
torch::utils::set_requires_device_init(at::kCUDA, false);
33+
#else
2534
torch::utils::set_requires_cuda_init(false);
35+
#endif
2636
}
2737
}
2838
}
@@ -31,7 +41,11 @@ void pyLeaveFakeMode() {
3141
leaveFakeMode();
3242

3343
if (!isFakeModeActive() && !at::hasCUDA()) {
44+
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR>=3
45+
torch::utils::set_requires_device_init(at::kCUDA, true);
46+
#else
3447
torch::utils::set_requires_cuda_init(true);
48+
#endif
3549
}
3650
}
3751

0 commit comments

Comments
 (0)