Commit 50662eb

awaelchli and carmocca authored

Fixes around Strategy.set_world_ranks (Lightning-AI#16966)

* don't call set_world_ranks in xla strategy
* update
* fabric and other strategies
* CHANGELOG
* Typos
* Reuse test

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

1 parent 17548d5 · commit 50662eb

13 files changed (+50 -49 lines)


src/lightning/fabric/CHANGELOG.md
Lines changed: 6 additions & 0 deletions

@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -

@@ -39,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
 
 
+- Fixed issue where Fabric would not initialize the global rank, world size, and rank-zero-only rank after initialization and before launch ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
+
 ## [2.0.1.post0] - 2023-04-11
 
 No changes

src/lightning/fabric/connector.py
Lines changed: 2 additions & 2 deletions

@@ -507,8 +507,8 @@ def _lazy_init_strategy(self) -> None:
         self.strategy.parallel_devices = self._parallel_devices
         if hasattr(self.strategy, "num_nodes"):
             self.strategy._num_nodes = self._num_nodes_flag
-        if hasattr(self.strategy, "set_world_ranks"):
-            self.strategy.set_world_ranks()
+        if hasattr(self.strategy, "_set_world_ranks"):
+            self.strategy._set_world_ranks()
         self.strategy._configure_launcher()
 
         if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
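Note: only cluster-based strategies define `_set_world_ranks`, so the connector dispatches via `hasattr`; the leading underscore in the new name also marks the hook as internal to Fabric. A minimal, self-contained sketch of that duck-typed pattern (the toy strategy classes below are illustrative, not Lightning's real ones):

    class SingleDeviceStrategy:
        pass  # defines no rank hook; nothing to initialize

    class DistributedStrategy:
        def __init__(self) -> None:
            self.global_rank = 0

        def _set_world_ranks(self) -> None:
            # real strategies derive this from the cluster environment
            self.global_rank = 0

    def lazy_init(strategy) -> None:
        # strategies without the hook are simply skipped
        if hasattr(strategy, "_set_world_ranks"):
            strategy._set_world_ranks()

    lazy_init(SingleDeviceStrategy())  # no-op
    lazy_init(DistributedStrategy())   # ranks initialized eagerly, before launch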

src/lightning/fabric/plugins/environments/lsf.py
Lines changed: 3 additions & 3 deletions

@@ -88,7 +88,7 @@ def world_size(self) -> int:
         if world_size is None:
             raise ValueError(
                 "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(world_size)

@@ -101,7 +101,7 @@ def global_rank(self) -> int:
         if global_rank is None:
             raise ValueError(
                 "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(global_rank)

@@ -114,7 +114,7 @@ def local_rank(self) -> int:
         if local_rank is None:
             raise ValueError(
                 "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(local_rank)
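Note: all three hunks fix the same typo. Adjacent Python string literals are concatenated at compile time with no separator inserted between them, so the original error messages ran two sentences together. A quick illustration:

    broken = (
        "Environment variable `JSM_NAMESPACE_SIZE` not found."
        "Make sure you run your executable with `jsrun`."
    )
    fixed = (
        "Environment variable `JSM_NAMESPACE_SIZE` not found."
        " Make sure you run your executable with `jsrun`."
    )
    print(broken)  # ...not found.Make sure...  (sentences fused)
    print(fixed)   # ...not found. Make sure... (readable)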

src/lightning/fabric/strategies/ddp.py
Lines changed: 6 additions & 6 deletions

@@ -177,7 +177,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
 
     def _setup_distributed(self) -> None:
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

@@ -186,11 +185,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
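Note: a worked example of the rank arithmetic in `_set_world_ranks`, runnable on its own (the node and process counts are made up for illustration):

    num_nodes, num_processes = 2, 4
    world_size = num_nodes * num_processes  # 8
    for node_rank in range(num_nodes):
        for local_rank in range(num_processes):
            global_rank = node_rank * num_processes + local_rank
            print(f"node {node_rank} local {local_rank} -> global {global_rank}")
    # global ranks 0-3 live on node 0, ranks 4-7 on node 1

Reading the rank back through `self.global_rank` rather than trusting the setter keeps the behavior correct even for environments whose setter ignores the value, as the new comment explains.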

src/lightning/fabric/strategies/deepspeed.py
Lines changed: 1 addition & 2 deletions

@@ -33,7 +33,7 @@
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH

@@ -580,7 +580,6 @@ def _setup_distributed(self) -> None:
         )
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
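Note: the `rank_zero_only.rank = self.global_rank` assignment (and with it the `rank_zero_only` import) is redundant here because `_set_world_ranks` now performs it. A simplified sketch of what that attribute controls, loosely modeled on the `lightning_utilities` implementation (details may differ):

    import functools

    def rank_zero_only(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # strategies write the current process rank onto this attribute
            if rank_zero_only.rank == 0:
                return fn(*args, **kwargs)
            return None
        return wrapped

    rank_zero_only.rank = 0  # default until a strategy overwrites it

    @rank_zero_only
    def log_once(msg: str) -> None:
        print(msg)

    log_once("printed on rank 0 only")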

src/lightning/fabric/strategies/fsdp.py
Lines changed: 6 additions & 6 deletions

@@ -320,7 +320,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def _setup_distributed(self) -> None:
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

@@ -329,11 +328,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
 
 def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
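Note: the same change as in `ddp.py`. On the "setter is a no-op" comment: a hypothetical sketch of why the getter is the safer source of truth, loosely modeled on scheduler-managed environments such as SLURM (not Lightning's actual class):

    import os

    class SchedulerManagedEnvironment:
        def set_global_rank(self, rank: int) -> None:
            pass  # no-op: the scheduler already fixed this process's rank

        def global_rank(self) -> int:
            return int(os.environ.get("SLURM_PROCID", "0"))

    env = SchedulerManagedEnvironment()
    env.set_global_rank(5)    # silently ignored
    print(env.global_rank())  # still whatever the scheduler assigned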

src/lightning/fabric/strategies/xla.py
Lines changed: 0 additions & 6 deletions

@@ -93,7 +93,6 @@ def _configure_launcher(self) -> None:
 
     def setup_environment(self) -> None:
         self._launched = True
-        self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
         super().setup_environment()
 
@@ -203,8 +202,3 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)
-
-    def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        rank_zero_only.rank = self.cluster_environment.global_rank()
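Note: per the CHANGELOG entry above, querying the XLA runtime for a rank in the main process would initialize the PJRT computation client before processes are launched, so rank setup now happens only in `setup_environment`, after `_launched` is set. A minimal sketch of that launch-guard pattern (names are illustrative, not the real `XLAStrategy`):

    class LaunchGuardedStrategy:
        def __init__(self) -> None:
            self._launched = False

        def _query_runtime_rank(self) -> int:
            # stand-in for an XLA ordinal query, which would initialize
            # the PJRT client on first use
            raise RuntimeError("runtime touched before launch!")

        @property
        def global_rank(self) -> int:
            return self._query_runtime_rank() if self._launched else 0

        def setup_environment(self) -> None:
            self._launched = True  # only now is reading real ranks safe

    strategy = LaunchGuardedStrategy()
    print(strategy.global_rank)  # 0: safe in the main process, runtime untouched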

src/lightning/pytorch/CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -

src/lightning/pytorch/strategies/ddp.py
Lines changed: 6 additions & 6 deletions

@@ -183,7 +183,6 @@ def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

@@ -192,11 +191,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
         log.debug(f"{self.__class__.__name__}: registering ddp hooks")

src/lightning/pytorch/strategies/deepspeed.py
Lines changed: 1 addition & 2 deletions

@@ -43,7 +43,7 @@
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 log = logging.getLogger(__name__)

@@ -326,7 +326,6 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
     def setup_distributed(self) -> None:
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
