Raise an exception when calling fit twice with spawn #18776

Merged · 6 commits · Oct 11, 2023
Changes from 2 commits
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -234,6 +234,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))


- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))


### Deprecated

- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
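For context, the pattern this changelog entry disallows looks roughly like the sketch below. It assumes the demo `BoringModel` that ships with Lightning and the `ddp_spawn` strategy; the `if __name__ == "__main__"` guard is required by the spawn start method.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":  # spawn requires the script to guard the main module
    model = BoringModel()
    trainer = Trainer(strategy="ddp_spawn", max_epochs=1)
    trainer.fit(model)
    trainer.fit(model)  # a second fit on the same Trainer now raises NotImplementedError
```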
11 changes: 11 additions & 0 deletions src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -81,6 +81,8 @@ def __init__(
        )
        self.procs: List[mp.Process] = []
        self._already_fit = False

    @property
    def is_interactive_compatible(self) -> bool:
        # The start method 'spawn' is not supported in interactive environments
@@ -106,6 +108,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
            _check_bad_cuda_fork()
        if self._start_method == "spawn":
            _check_missing_main_guard()
        if self._already_fit and trainer.state.fn == TrainerFn.FITTING:
            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
            raise NotImplementedError(
                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
                " supported. You can work around this limitation by creating a new Trainer instance and passing the"
                " `fit(ckpt_path=...)` argument."
            )

        # The default cluster environment in Lightning chooses a random free port number
        # This needs to be done in the main process here before starting processes to ensure each rank will connect
@@ -182,6 +191,8 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train

        trainer.state = worker_output.trainer_state

        self._already_fit |= trainer.state.fn == TrainerFn.FITTING

        # get the `callback_metrics` and set it to the trainer
        self.update_main_process_results(trainer, worker_output.extra)

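The error message above names a workaround; a minimal sketch of it follows, assuming the first run produced a checkpoint (the `ckpt_path` value is a placeholder, not a path this snippet guarantees to exist):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    model = BoringModel()
    trainer = Trainer(strategy="ddp_spawn", max_epochs=1)
    trainer.fit(model)

    # Instead of calling trainer.fit(model) again, create a fresh Trainer and
    # resume from a checkpoint written by the first run (path is illustrative).
    new_trainer = Trainer(strategy="ddp_spawn", max_epochs=2)
    new_trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
```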
8 changes: 8 additions & 0 deletions src/lightning/pytorch/strategies/launchers/xla.py
@@ -70,6 +70,14 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
            **kwargs: Optional keyword arguments to be passed to the given function.

        """
        if self._already_fit and trainer.state.fn == TrainerFn.FITTING:
            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
            raise NotImplementedError(
                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
                " supported. You can work around this by creating a new Trainer instance and passing the"
                " `fit(ckpt_path=...)` argument."
            )

        using_pjrt = _using_pjrt()
        # pjrt requires that the queue is serializable
        return_queue: Union[queue.Queue, mp.SimpleQueue] = (
28 changes: 12 additions & 16 deletions src/lightning/pytorch/trainer/trainer.py
@@ -539,6 +539,9 @@ def fit(
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(model, self.strategy)
        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True
        call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )
@@ -553,10 +556,6 @@ def _fit_impl(
    ) -> None:
        log.debug(f"{self.__class__.__name__}: trainer fit stage")

        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
@@ -640,6 +639,9 @@ def validate(
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(self.lightning_module, self.strategy)
        self.state.fn = TrainerFn.VALIDATING
        self.state.status = TrainerStatus.RUNNING
        self.validating = True
        return call._call_and_handle_interrupt(
            self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
        )
@@ -657,10 +659,6 @@ def _validate_impl(
        # --------------------
        log.debug(f"{self.__class__.__name__}: trainer validate stage")

        self.state.fn = TrainerFn.VALIDATING
        self.state.status = TrainerStatus.RUNNING
        self.validating = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(dataloaders, LightningDataModule):
            datamodule = dataloaders
@@ -749,6 +747,9 @@ def test(
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(self.lightning_module, self.strategy)
        self.state.fn = TrainerFn.TESTING
        self.state.status = TrainerStatus.RUNNING
        self.testing = True
        return call._call_and_handle_interrupt(
            self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
        )
@@ -766,10 +767,6 @@ def _test_impl(
        # --------------------
        log.debug(f"{self.__class__.__name__}: trainer test stage")

        self.state.fn = TrainerFn.TESTING
        self.state.status = TrainerStatus.RUNNING
        self.testing = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(dataloaders, LightningDataModule):
            datamodule = dataloaders
@@ -859,6 +856,9 @@ def predict(
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(self.lightning_module, self.strategy)
        self.state.fn = TrainerFn.PREDICTING
        self.state.status = TrainerStatus.RUNNING
        self.predicting = True
        return call._call_and_handle_interrupt(
            self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
        )
@@ -876,10 +876,6 @@ def _predict_impl(
        # --------------------
        log.debug(f"{self.__class__.__name__}: trainer predict stage")

        self.state.fn = TrainerFn.PREDICTING
        self.state.status = TrainerStatus.RUNNING
        self.predicting = True

        self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]

        # if a datamodule comes in as the second arg, then fix it for the user
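Note that moving the `state.fn`/`state.status`/stage-flag assignments from the `_*_impl` methods into the public entry points is what makes the launcher guards above work: `launch()` is invoked by `call._call_and_handle_interrupt` before the corresponding `_*_impl` body runs, so `trainer.state.fn` must already be set to the right `TrainerFn` by the time the guard inspects it.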
@@ -212,3 +212,20 @@ def test_check_for_missing_main_guard():
        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
        launcher.launch(function=Mock())


def test_fit_twice_raises():
    model = BoringModel()
    trainer = Trainer(
        limit_train_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        strategy="ddp_spawn",
        barebones=True,
    )
    trainer.fit(model)
    trainer.test(model)  # make sure testing in between doesn't impact the result
    trainer.fit_loop.max_epochs += 1
    with pytest.raises(NotImplementedError, match=r"twice.*is not supported"):
        trainer.fit(model)
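The test exercises the guard end to end: `barebones=True` keeps the run minimal, the interleaved `trainer.test(model)` checks that only a second FITTING call is blocked, and incrementing `fit_loop.max_epochs` ensures the second `fit()` would otherwise have work to do.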