Skip to content

Commit 2fb74ce

Browse files
committed
Device agnostic for DCP
1 parent 0f21fa8 commit 2fb74ce

File tree

5 files changed

+33
-29
lines changed

5 files changed

+33
-29
lines changed

test/distributed/checkpoint/_experimental/test_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def test_make_async_checkpointer(self) -> None:
123123
# Create async checkpointer using factory function with default parameters
124124
config: CheckpointerConfig = CheckpointerConfig()
125125
config.staging_config = CheckpointStagerConfig(
126-
use_cuda_non_blocking_copy=torch.cuda.is_available(),
126+
use_non_blocking_copy=torch.cuda.is_available(),
127127
use_pinned_memory=torch.cuda.is_available(),
128128
)
129129
checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info)

test/distributed/checkpoint/_experimental/test_staging.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_cuda_non_blocking_without_cuda(self) -> None:
7474
if torch.cuda.is_available():
7575
self.skipTest("CUDA is available, cannot test CUDA unavailable scenario")
7676

77-
options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True)
77+
options = CheckpointStagerConfig(use_non_blocking_copy=True)
7878
with self.assertRaises(AssertionError):
7979
DefaultStager(options)
8080

@@ -86,21 +86,21 @@ def test_different_option_combinations(self) -> None:
8686
use_pinned_memory=False,
8787
use_shared_memory=False,
8888
use_async_staging=False,
89-
use_cuda_non_blocking_copy=False,
89+
use_non_blocking_copy=False,
9090
),
9191
# Only pinned memory
9292
CheckpointStagerConfig(
9393
use_pinned_memory=True,
9494
use_shared_memory=False,
9595
use_async_staging=False,
96-
use_cuda_non_blocking_copy=False,
96+
use_non_blocking_copy=False,
9797
),
9898
# Only shared memory
9999
CheckpointStagerConfig(
100100
use_pinned_memory=False,
101101
use_shared_memory=True,
102102
use_async_staging=False,
103-
use_cuda_non_blocking_copy=False,
103+
use_non_blocking_copy=False,
104104
),
105105
]
106106

@@ -111,7 +111,7 @@ def test_different_option_combinations(self) -> None:
111111
use_pinned_memory=torch.cuda.is_available(),
112112
use_shared_memory=False,
113113
use_async_staging=True,
114-
use_cuda_non_blocking_copy=False,
114+
use_non_blocking_copy=False,
115115
)
116116
)
117117
# Only CUDA non-blocking copy
@@ -120,7 +120,7 @@ def test_different_option_combinations(self) -> None:
120120
use_pinned_memory=torch.cuda.is_available(),
121121
use_shared_memory=False,
122122
use_async_staging=False,
123-
use_cuda_non_blocking_copy=torch.cuda.is_available(),
123+
use_non_blocking_copy=torch.cuda.is_available(),
124124
)
125125
)
126126

@@ -185,7 +185,7 @@ def test_multiple_staging_operations(self) -> None:
185185
use_async_staging=False,
186186
use_pinned_memory=torch.cuda.is_available(),
187187
use_shared_memory=False,
188-
use_cuda_non_blocking_copy=torch.cuda.is_available(),
188+
use_non_blocking_copy=torch.cuda.is_available(),
189189
)
190190
stager = DefaultStager(options)
191191

test/distributed/checkpoint/e2e/test_e2e_save_and_load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_e2e_test(
279279
use_async_staging=zoc,
280280
use_shared_memory=use_shared_memory,
281281
use_pinned_memory=zoc,
282-
use_cuda_non_blocking_copy=zoc,
282+
use_non_blocking_copy=zoc,
283283
)
284284
stager = DefaultStager(staging_options)
285285
async_save_response_or_future = saver.async_save(

torch/distributed/checkpoint/_experimental/staging.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class CheckpointStagerConfig:
8282
use_async_staging (bool): Enable asynchronous staging using a
8383
background thread pool. Allows overlapping computation with
8484
staging operations. Requires an accelerator. Default: True
85-
use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory
85+
use_non_blocking_copy (bool): Use non-blocking device memory
8686
copies with stream synchronization. Improves performance by
8787
allowing CPU work to continue during GPU transfers. Default: True
8888
@@ -93,7 +93,7 @@ class CheckpointStagerConfig:
9393
use_pinned_memory: bool = True
9494
use_shared_memory: bool = True
9595
use_async_staging: bool = True
96-
use_cuda_non_blocking_copy: bool = True
96+
use_non_blocking_copy: bool = True
9797

9898

9999
class DefaultStager(CheckpointStager):
@@ -153,15 +153,17 @@ def __init__(
153153

154154
if self._config.use_async_staging:
155155
self._staging_executor = ThreadPoolExecutor(max_workers=1)
156-
if torch.cuda.is_available():
156+
if torch.accelerator.is_available():
157157
# Note: stream needs to be initialized on the main thread after default cuda
158158
# stream is setup/used to avoid the risk of accidentally reusing the main
159159
# compute stream or in other cases kernels actually launching from the
160160
# main thread.
161-
self._staging_stream = torch.cuda.Stream()
161+
self._staging_stream = torch.Stream()
162162

163-
if self._config.use_cuda_non_blocking_copy:
164-
assert torch.cuda.is_available(), "Non-blocking copy requires CUDA"
163+
if self._config.use_non_blocking_copy:
164+
assert torch.accelerator.is_available(), (
165+
"Non-blocking copy requires CUDA/XPU"
166+
)
165167

166168
def stage(
167169
self,
@@ -182,16 +184,16 @@ def stage(
182184

183185
def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT:
184186
state_dict = self._state_dict_stager.stage(
185-
state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs
187+
state_dict, non_blocking=self._config.use_non_blocking_copy, **kwargs
186188
)
187189

188-
if self._config.use_cuda_non_blocking_copy:
190+
if self._config.use_non_blocking_copy:
189191
assert self._staging_stream or not self._config.use_async_staging, (
190-
"Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized."
192+
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
191193
)
192194

193195
# waits for the enqueued copy operations to finish.
194-
self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize()
196+
self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize()
195197

196198
return state_dict
197199

torch/distributed/checkpoint/staging.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class StagingOptions:
110110
use_async_staging (bool): Enable asynchronous staging using a
111111
background thread pool. Allows overlapping computation with
112112
staging operations. Requires an accelerator. Default: True
113-
use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory
113+
use_non_blocking_copy (bool): Use non-blocking device memory
114114
copies with stream synchronization. Improves performance by
115115
allowing CPU work to continue during GPU transfers. Default: True
116116
@@ -121,7 +121,7 @@ class StagingOptions:
121121
use_pinned_memory: bool = True
122122
use_shared_memory: bool = True
123123
use_async_staging: bool = True
124-
use_cuda_non_blocking_copy: bool = True
124+
use_non_blocking_copy: bool = True
125125

126126

127127
class DefaultStager(AsyncStager):
@@ -177,15 +177,17 @@ def __init__(
177177
self._staging_stream = None
178178
if self._config.use_async_staging:
179179
self._staging_executor = ThreadPoolExecutor(max_workers=1)
180-
if torch.cuda.is_available():
180+
if torch.accelerator.is_available():
181181
# Note: stream needs to be initialized on the main thread after default cuda
182182
# stream is setup/used to avoid the risk of accidentally reusing the main
183183
# compute stream or in other cases kernels actually launching from the
184184
# main thread.
185-
self._staging_stream = torch.cuda.Stream()
185+
self._staging_stream = torch.Stream()
186186

187-
if self._config.use_cuda_non_blocking_copy:
188-
assert torch.cuda.is_available(), "Non-blocking copy requires CUDA"
187+
if self._config.use_non_blocking_copy:
188+
assert torch.accelerator.is_available(), (
189+
"Non-blocking copy requires CUDA/XPU"
190+
)
189191

190192
self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None
191193

@@ -216,20 +218,20 @@ def stage(
216218
return self._stage(state_dict, **kwargs)
217219

218220
def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE:
219-
if self._config.use_cuda_non_blocking_copy:
221+
if self._config.use_non_blocking_copy:
220222
assert self._staging_stream or not self._config.use_async_staging, (
221-
"Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized."
223+
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
222224
)
223225
with (
224226
self._staging_stream
225227
if self._staging_stream is not None
226228
else nullcontext()
227229
):
228230
state_dict = self._state_dict_stager.stage(
229-
state_dict, non_blocking=self._config.use_cuda_non_blocking_copy
231+
state_dict, non_blocking=self._config.use_non_blocking_copy
230232
)
231233
# waits for the enqueued copy operations to finish.
232-
self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize()
234+
self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize()
233235
else:
234236
state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False)
235237
return state_dict

0 commit comments

Comments
 (0)