Skip strategy=ddp_spawn, accelerator=cpu, python>=3.9 tests (#10550)
carmocca authored Nov 16, 2021
1 parent 60850ef commit 6dfcb6a
Showing 18 changed files with 39 additions and 27 deletions.
10 changes: 5 additions & 5 deletions tests/callbacks/test_early_stopping.py
@@ -381,16 +381,16 @@ def on_train_end(self) -> None:

_ES_CHECK = dict(check_on_train_epoch_end=True)
_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
-_NO_WIN = dict(marks=RunIf(skip_windows=True))
+_SPAWN_MARK = dict(marks=RunIf(skip_windows=True, skip_49370=True))


@pytest.mark.parametrize(
"callbacks, expected_stop_epoch, check_on_train_epoch_end, strategy, num_processes",
[
([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1),
([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1),
-pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_NO_WIN),
-pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_NO_WIN),
+pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
+pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1),
([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1),
pytest.param(
@@ -399,15 +399,15 @@ def on_train_end(self) -> None:
True,
"ddp_spawn",
2,
-**_NO_WIN,
+**_SPAWN_MARK,
),
pytest.param(
[EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)],
3,
True,
"ddp_spawn",
2,
-**_NO_WIN,
+**_SPAWN_MARK,
),
],
)
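The hunk above bundles the skip marks in a module-level dict and splats it into `pytest.param`, so only the `ddp_spawn` parametrizations pick up the skip condition. Below is a minimal, self-contained sketch of that pattern; it is illustrative only and uses a plain `skipif` mark in place of Lightning's `RunIf` helper.

```python
# Illustrative sketch (not from the commit): a marks dict splatted into
# pytest.param so that only the ddp_spawn cases carry the skip condition.
import sys

import pytest

_SPAWN_MARK = dict(marks=pytest.mark.skipif(sys.platform == "win32", reason="ddp_spawn unsupported here"))


@pytest.mark.parametrize(
    "strategy, num_processes",
    [
        (None, 1),  # always collected
        pytest.param("ddp_spawn", 2, **_SPAWN_MARK),  # skipped when the condition holds
    ],
)
def test_example(strategy, num_processes):
    assert num_processes >= 1
```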
2 changes: 1 addition & 1 deletion tests/callbacks/test_pruning.py
@@ -187,7 +187,7 @@ def test_pruning_callback_ddp_spawn(tmpdir):
train_with_pruning_callback(tmpdir, use_global_unstructured=True, strategy="ddp_spawn", gpus=2)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_pruning_callback_ddp_cpu(tmpdir):
train_with_pruning_callback(tmpdir, parameters_to_prune=True, strategy="ddp_spawn", num_processes=2)

2 changes: 1 addition & 1 deletion tests/callbacks/test_stochastic_weight_avg.py
@@ -148,7 +148,7 @@ def test_swa_callback_ddp_spawn(tmpdir):
train_with_swa(tmpdir, strategy="ddp_spawn", gpus=2)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_swa_callback_ddp_cpu(tmpdir):
train_with_swa(tmpdir, strategy="ddp_spawn", num_processes=2)

2 changes: 1 addition & 1 deletion tests/checkpointing/test_model_checkpoint.py
@@ -385,7 +385,7 @@ def on_train_end(self, trainer, pl_module):
assert torch.save.call_count == 0


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
model = LogInTwoMethods()
2 changes: 1 addition & 1 deletion tests/checkpointing/test_torch_saving.py
@@ -34,7 +34,7 @@ def test_model_torch_save(tmpdir):
trainer = torch.load(temp_path)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_model_torch_save_ddp_cpu(tmpdir):
"""Test to ensure torch save does not fail for model and trainer using cpu ddp."""
model = BoringModel()
2 changes: 1 addition & 1 deletion tests/deprecated_api/test_remove_1-7.py
@@ -245,7 +245,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
return super().get_from_queue(queue)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_v1_7_0_deprecate_add_get_queue(tmpdir):
model = BoringCallbackDDPSpawnModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, strategy="ddp_spawn")
11 changes: 11 additions & 0 deletions tests/helpers/runif.py
@@ -70,6 +70,7 @@ def __new__(
fairscale_fully_sharded: bool = False,
deepspeed: bool = False,
rich: bool = False,
+skip_49370: bool = False,
**kwargs,
):
"""
@@ -91,6 +92,7 @@ def __new__(
fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test
deepspeed: if `deepspeed` module is required to run the test
rich: if `rich` module is required to run the test
+skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
kwargs: native pytest.mark.skipif keyword arguments
"""
conditions = []
@@ -165,6 +167,15 @@ def __new__(
conditions.append(not _RICH_AVAILABLE)
reasons.append("Rich")

+if skip_49370:
+    # strategy=ddp_spawn, accelerator=cpu, python>=3.9, torch<1.8 does not work
+    py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+    ge_3_9 = Version(py_version) >= Version("3.9")
+    torch_version = get_distribution("torch").version
+    old_torch = Version(torch_version) < Version("1.8")
+    conditions.append(ge_3_9 and old_torch)
+    reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
+
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
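This is the core of the commit: `RunIf(skip_49370=True)` adds a skip condition that fires only when Python >= 3.9 is paired with torch < 1.8. The following standalone sketch mirrors that condition outside of `RunIf`; it is an assumption-laden illustration, not the library code, and it imports torch directly instead of querying `pkg_resources.get_distribution`, stripping local version tags such as "+cu111".

```python
# Sketch of the skip condition added above, expressed as a plain skipif mark.
import sys

import pytest
import torch
from packaging.version import Version


def _impacted_by_49370() -> bool:
    # Python >= 3.9 combined with torch < 1.8 hits pytorch/pytorch#49370.
    py = Version(f"{sys.version_info.major}.{sys.version_info.minor}")
    torch_version = Version(torch.__version__.split("+")[0])
    return py >= Version("3.9") and torch_version < Version("1.8")


@pytest.mark.skipif(_impacted_by_49370(), reason="Impacted by pytorch/pytorch#49370")
def test_runs_unless_affected():
    assert True
```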
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
@@ -321,8 +321,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
assert pl_module.logger.experiment.something(foo="bar") is None


+@RunIf(skip_windows=True, skip_49370=True)
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
-@RunIf(skip_windows=True)
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
_patch_comet_atexit(monkeypatch)
2 changes: 1 addition & 1 deletion tests/models/test_cpu.py
@@ -122,7 +122,7 @@ def validation_step(self, *args, **kwargs):
model.unfreeze()


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_multi_cpu_model_ddp(tmpdir):
"""Make sure DDP works."""
tutils.set_random_main_port()
6 changes: 3 additions & 3 deletions tests/models/test_horovod.py
@@ -66,7 +66,7 @@ def _run_horovod(trainer_options, on_gpu=False):
assert exit_code == 0


-@RunIf(skip_windows=True, horovod=True)
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
def test_horovod_cpu(tmpdir):
"""Test Horovod running multi-process on CPU."""
trainer_options = dict(
@@ -82,7 +82,7 @@ def test_horovod_cpu(tmpdir):
_run_horovod(trainer_options)


-@RunIf(skip_windows=True, horovod=True)
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
def test_horovod_cpu_clip_grad_by_value(tmpdir):
"""Test Horovod running multi-process on CPU."""
trainer_options = dict(
@@ -99,7 +99,7 @@ def test_horovod_cpu_clip_grad_by_value(tmpdir):
_run_horovod(trainer_options)


-@RunIf(skip_windows=True, horovod=True)
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
def test_horovod_cpu_implicit(tmpdir):
"""Test Horovod without specifying a backend, inferring from env set by `horovodrun`."""
trainer_options = dict(
6 changes: 3 additions & 3 deletions tests/plugins/test_ddp_spawn_plugin.py
@@ -46,7 +46,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
return super().get_from_queue(queue)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_ddp_cpu():
"""Tests if device is set correctly when training for DDPSpawnPlugin."""
trainer = Trainer(num_processes=2, fast_dev_run=True)
@@ -91,7 +91,7 @@ def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQu
return super().get_from_queue(trainer, queue)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_ddp_spawn_add_get_queue(tmpdir):
"""Tests add_to_queue/get_from_queue with DDPSpawnPlugin."""

@@ -128,7 +128,7 @@ def on_predict_start(self) -> None:
assert isinstance(self.trainer.model, LightningModule)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_ddp_spawn_configure_ddp(tmpdir):
"""Tests with ddp spawn plugin."""
trainer = Trainer(default_root_dir=tmpdir, num_processes=2, strategy="ddp_spawn", fast_dev_run=True)
3 changes: 2 additions & 1 deletion tests/profiler/test_profiler.py
@@ -161,7 +161,7 @@ def test_simple_profiler_with_nonexisting_dirpath(tmpdir):
assert nonexisting_tmpdir.join("fit-profiler.txt").exists()


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_simple_profiler_distributed_files(tmpdir):
"""Ensure the proper files are saved in distributed."""
profiler = SimpleProfiler(dirpath=tmpdir, filename="profiler")
@@ -226,6 +226,7 @@ def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, ex
np.testing.assert_allclose(recored_total_duration, expected_total_duration, rtol=0.2)


+@pytest.mark.flaky(reruns=3)
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
"""ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
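The profiler overhead test also gains `@pytest.mark.flaky(reruns=3)`, which retries a failed run before reporting a failure. A quick illustration follows; it assumes the pytest-rerunfailures plugin (which provides this marker) is installed, and the test body is a made-up stand-in for a timing-sensitive check.

```python
# Illustration of the flaky marker: the test is re-run up to 3 times on failure
# (requires the pytest-rerunfailures plugin).
import random

import pytest


@pytest.mark.flaky(reruns=3)
def test_occasionally_slow():
    # Stand-in for a timing-sensitive assertion that may fail on a loaded CI machine.
    assert random.random() > 0.01
```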
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_distributed_logging.py
@@ -59,7 +59,7 @@ def on_train_end(self):
assert self.log_name.format(rank=self.local_rank) in self.logger.logs, "Expected rank to be logged"


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_all_rank_logging_ddp_cpu(tmpdir):
"""Check that all ranks can be logged from."""
model = TestModel()
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -395,7 +395,7 @@ def validation_step(self, batch, batch_idx):
return super().validation_step(batch, batch_idx)


@pytest.mark.parametrize("devices", [1, pytest.param(2, marks=RunIf(skip_windows=True))])
@pytest.mark.parametrize("devices", [1, pytest.param(2, marks=RunIf(skip_windows=True, skip_49370=True))])
def test_logging_sync_dist_true(tmpdir, devices):
"""Tests to ensure that the sync_dist flag works (should just return the original value)"""
fake_result = 1
2 changes: 1 addition & 1 deletion tests/trainer/properties/test_get_model.py
@@ -37,7 +37,7 @@ def test_get_model(tmpdir):
trainer.fit(model)


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_get_model_ddp_cpu(tmpdir):
"""Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu."""

2 changes: 1 addition & 1 deletion tests/trainer/test_data_loading.py
@@ -133,7 +133,7 @@ def _get_warning_msg():
assert warn_str in msg


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
@pytest.mark.parametrize("num_workers", [0, 1])
def test_dataloader_warnings(tmpdir, num_workers):
trainer = Trainer(default_root_dir=tmpdir, strategy="ddp_spawn", num_processes=2, fast_dev_run=4)
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
@@ -1809,7 +1809,7 @@ def on_predict_start(self) -> None:


@pytest.mark.parametrize(
"strategy,num_processes", [(None, 1), pytest.param("ddp_spawn", 2, marks=RunIf(skip_windows=True))]
"strategy,num_processes", [(None, 1), pytest.param("ddp_spawn", 2, marks=RunIf(skip_windows=True, skip_49370=True))]
)
def test_model_in_correct_mode_during_stages(tmpdir, strategy, num_processes):
model = TrainerStagesModel()
@@ -1830,7 +1830,7 @@ def validation_epoch_end(self, outputs) -> None:
pass


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
def test_fit_test_synchronization(tmpdir):
"""Test that the trainer synchronizes processes before returning control back to the caller."""
tutils.set_random_main_port()
4 changes: 2 additions & 2 deletions tests/utilities/test_all_gather_grad.py
@@ -41,8 +41,8 @@ def _test_all_gather_ddp(rank, world_size):
assert torch.allclose(grad2, tensor2.grad)


-@RunIf(skip_windows=True)
-def test_all_gather_ddp():
+@RunIf(skip_windows=True, skip_49370=True)
+def test_all_gather_ddp_spawn():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)

