Commit f75f3bc

Simplify _get_rank() utility function (#19220)
1 parent 564be3b commit f75f3bc

3 files changed: +5 additions, -20 deletions
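
Summary: `_get_rank()` loses its optional `strategy` parameter, which short-circuited the environment-variable lookup by returning `strategy.global_rank`. Its only such caller, `EarlyStopping._log_info`, now requires a non-optional `Trainer` and reads `trainer.global_rank` directly (treating single-process runs as rank `None`), and the corresponding test swaps the real `Trainer()` for a lightweight `Mock`.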

src/lightning/fabric/utilities/rank_zero.py

Lines changed: 1 addition & 6 deletions

@@ -30,17 +30,12 @@
 )
 from typing_extensions import ParamSpec
 
-import lightning.fabric
 from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10
 
 rank_zero_module.log = logging.getLogger(__name__)
 
 
-def _get_rank(
-    strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
-) -> Optional[int]:
-    if strategy is not None:
-        return strategy.global_rank
+def _get_rank() -> Optional[int]:
     # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
     # therefore LOCAL_RANK needs to be checked first
     rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
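
The retained context lines show that the simplified helper now relies solely on the environment. A minimal sketch of how the remaining body presumably reads, reconstructed from the hunk's trailing context rather than quoted from the source:

import os
from typing import Optional


def _get_rank() -> Optional[int]:
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    # No launcher set a rank for this process
    return None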

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 2 additions & 7 deletions

@@ -26,7 +26,6 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
@@ -265,12 +264,8 @@ def _improvement_message(self, current: Tensor) -> str:
         return msg
 
     @staticmethod
-    def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        rank = _get_rank(
-            strategy=(trainer.strategy if trainer is not None else None),  # type: ignore[arg-type]
-        )
-        if trainer is not None and trainer.world_size <= 1:
-            rank = None
+    def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
+        rank = trainer.global_rank if trainer.world_size > 1 else None
         message = rank_prefixed_message(message, rank)
         if rank is None or not log_rank_zero_only or rank == 0:
             log.info(message)
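
The new one-liner keeps the behavior the callback cares about: on a single device the rank is `None` and the message is logged unprefixed; in distributed runs only rank 0 logs when `log_rank_zero_only=True`. A small usage sketch (assuming `rank_prefixed_message` returns the message unchanged for a `None` rank, as its use here implies):

from unittest.mock import Mock, patch

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Single-device run: world_size <= 1 makes rank None, so the message is
# logged as-is regardless of log_rank_zero_only.
trainer = Mock(global_rank=0, world_size=1)
with patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
    EarlyStopping._log_info(trainer, "metric improved", log_rank_zero_only=True)
log_mock.assert_called_once_with("metric improved")

# Distributed run: with log_rank_zero_only=True, non-zero ranks stay silent.
trainer = Mock(global_rank=1, world_size=2)
with patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
    EarlyStopping._log_info(trainer, "metric improved", log_rank_zero_only=True)
log_mock.assert_not_called()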

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 2 additions & 7 deletions

@@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
     es_mock.assert_called_once_with(torch.tensor(0))
 
 
-@pytest.mark.parametrize("trainer", [Trainer(), None])
 @pytest.mark.parametrize(
     ("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
     [
@@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
         (True, 2, 1, None),
     ],
 )
-def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
+def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
     """Checks if log.info() gets called with expected message when used within EarlyStopping."""
     # set the global_rank and world_size if trainer is not None
     # or else always expect the simple logging message
-    if trainer:
-        trainer.strategy.global_rank = global_rank
-        trainer.strategy.world_size = world_size
-    else:
-        expected_log = "bar"
+    trainer = Mock(global_rank=global_rank, world_size=world_size)
 
     with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
         EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)
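
The test no longer needs to construct a real `Trainer()` and mutate its strategy: keyword arguments passed to `unittest.mock.Mock` become plain attributes, which covers the only two fields the simplified `_log_info` reads. A quick illustration:

from unittest.mock import Mock

# Keyword arguments configure attributes on the resulting mock object.
trainer = Mock(global_rank=1, world_size=2)
assert trainer.global_rank == 1
assert trainer.world_size == 2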
