Skip to content

Commit

Permalink
Support special test parametrizations (#10569)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored and awaelchli committed Nov 24, 2021
1 parent f96d769 commit 9a17cf9
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 117 deletions.
12 changes: 2 additions & 10 deletions tests/accelerators/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,8 @@ def setup(self, stage: Optional[str] = None) -> None:


@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
def test_ddp_wrapper_16(tmpdir):
_test_ddp_wrapper(tmpdir, precision=16)


@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
def test_ddp_wrapper_32(tmpdir):
_test_ddp_wrapper(tmpdir, precision=32)


def _test_ddp_wrapper(tmpdir, precision):
@pytest.mark.parametrize("precision", (16, 32))
def test_ddp_wrapper(tmpdir, precision):
"""Test parameters to ignore are carried over for DDP."""

class WeirdModule(torch.nn.Module):
Expand Down
25 changes: 8 additions & 17 deletions tests/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,27 +161,18 @@ def test_pruning_callback(


@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_0(tmpdir):
@pytest.mark.parametrize("parameters_to_prune", (False, True))
@pytest.mark.parametrize("use_global_unstructured", (False, True))
def test_pruning_callback_ddp(tmpdir, parameters_to_prune, use_global_unstructured):
train_with_pruning_callback(
tmpdir, parameters_to_prune=False, use_global_unstructured=False, strategy="ddp", gpus=2
tmpdir,
parameters_to_prune=parameters_to_prune,
use_global_unstructured=use_global_unstructured,
strategy="ddp",
gpus=2,
)


@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_1(tmpdir):
train_with_pruning_callback(tmpdir, parameters_to_prune=False, use_global_unstructured=True, strategy="ddp", gpus=2)


@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_2(tmpdir):
train_with_pruning_callback(tmpdir, parameters_to_prune=True, use_global_unstructured=False, strategy="ddp", gpus=2)


@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_3(tmpdir):
train_with_pruning_callback(tmpdir, parameters_to_prune=True, use_global_unstructured=True, strategy="ddp", gpus=2)


@RunIf(min_gpus=2, skip_windows=True)
def test_pruning_callback_ddp_spawn(tmpdir):
train_with_pruning_callback(tmpdir, use_global_unstructured=True, strategy="ddp_spawn", gpus=2)
Expand Down
19 changes: 5 additions & 14 deletions tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,20 +522,11 @@ def test_tqdm_progress_bar_can_be_pickled():


@RunIf(min_gpus=2, special=True)
def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2
)


@RunIf(min_gpus=2, special=True)
def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5
)


def _test_progress_bar_max_val_check_interval(
@pytest.mark.parametrize(
["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
[(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
)
def test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
):
world_size = 2
Expand Down
13 changes: 2 additions & 11 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,8 @@ def training_step(self, batch, batch_idx):

@mock.patch("torch.save")
@RunIf(special=True, min_gpus=2)
def test_top_k_ddp_0(save_mock, tmpdir):
_top_k_ddp(save_mock, tmpdir, k=1, epochs=1, val_check_interval=1.0, expected=1)


@mock.patch("torch.save")
@RunIf(special=True, min_gpus=2)
def test_top_k_ddp_1(save_mock, tmpdir):
_top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=4)


def _top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
@pytest.mark.parametrize(["k", "epochs", "val_check_interval", "expected"], [(1, 1, 1.0, 1), (2, 2, 0.3, 4)])
def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
local_rank = int(os.getenv("LOCAL_RANK"))
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,16 @@ def single_process_pg():
torch.distributed.destroy_process_group()
os.environ.clear()
os.environ.update(orig_environ)


def pytest_collection_modifyitems(items):
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") != "1":
return
# filter out non-special tests
items[:] = [
item
for item in items
for marker in item.own_markers
# has `@RunIf(special=True)`
if marker.name == "skipif" and marker.kwargs.get("special")
]
2 changes: 2 additions & 0 deletions tests/helpers/runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def __new__(
env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0")
conditions.append(env_flag != "1")
reasons.append("Special execution")
# used in tests/conftest.py::pytest_collection_modifyitems
kwargs["special"] = True

if fairscale:
conditions.append(not _FAIRSCALE_AVAILABLE)
Expand Down
12 changes: 3 additions & 9 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,16 +423,10 @@ def _predict_batch(trainer, model, batches):


@RunIf(deepspeed=True, min_gpus=1, special=True)
def test_trainer_model_hook_system_fit_deepspeed_automatic_optimization(tmpdir):
_run_trainer_model_hook_system_fit(
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=True
)


@RunIf(deepspeed=True, min_gpus=1, special=True)
def test_trainer_model_hook_system_fit_deepspeed_manual_optimization(tmpdir):
@pytest.mark.parametrize("automatic_optimization", (True, False))
def test_trainer_model_hook_system_fit_deepspeed(tmpdir, automatic_optimization):
_run_trainer_model_hook_system_fit(
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=False
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=automatic_optimization
)


Expand Down
74 changes: 34 additions & 40 deletions tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,49 @@ set -e
# this environment variable allows special tests to run
export PL_RUNNING_SPECIAL_TESTS=1
# python arguments
defaults='-m coverage run --source pytorch_lightning --append -m pytest --durations=0 --capture=no --disable-warnings'
defaults='-m coverage run --source pytorch_lightning --append -m pytest --capture=no'

# find tests marked as `@RunIf(special=True)`
grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True')
# file paths
files=$(echo "$grep_output" | cut -f1 -d:)
files_arr=($files)
# line numbers
linenos=$(echo "$grep_output" | cut -f2 -d:)
linenos_arr=($linenos)
# find tests marked as `@RunIf(special=True)`. done manually instead of with pytest because it is faster
grep_output=$(grep --recursive --word-regexp 'tests' 'benchmarks' --regexp 'special=True' --include '*.py' --exclude 'tests/conftest.py')

# file paths, remove duplicates
files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq)

# get the list of parametrizations. we need to call them separately. the last two lines are removed.
# note: if there's a syntax error, this will fail with some garbled output
if [[ "$OSTYPE" == "darwin"* ]]; then
parametrizations=$(pytest $files --collect-only --quiet | tail -r | sed -e '1,3d' | tail -r)
else
parametrizations=$(pytest $files --collect-only --quiet | head -n -2)
fi
parametrizations_arr=($parametrizations)

# tests to skip - space separated
blocklist='test_pytorch_profiler_nested_emit_nvtx'
blocklist='tests/profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx'
report=''

for i in "${!files_arr[@]}"; do
file=${files_arr[$i]}
lineno=${linenos_arr[$i]}

# get code from `@RunIf(special=True)` line to EOF
test_code=$(tail -n +"$lineno" "$file")
for i in "${!parametrizations_arr[@]}"; do
parametrization=${parametrizations_arr[$i]}

# read line by line
while read -r line; do
# if it's a test
if [[ $line == def\ test_* ]]; then
# get the name
test_name=$(echo $line | cut -c 5- | cut -f1 -d\()
# check blocklist
if echo $blocklist | grep -F "${parametrization}"; then
report+="Skipped\t$parametrization\n"
continue
fi

# check blocklist
if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then
report+="Skipped\t$file:$lineno::$test_name\n"
break
fi
# SPECIAL_PATTERN allows filtering the tests to run when debugging.
# use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those
# test with `foo_bar` in their name
if [[ $parametrization != *$SPECIAL_PATTERN* ]]; then
report+="Skipped\t$parametrization\n"
continue
fi

# SPECIAL_PATTERN allows filtering the tests to run when debugging.
# use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those
# test with `foo_bar` in their name
if [[ $line != *$SPECIAL_PATTERN* ]]; then
report+="Skipped\t$file:$lineno::$test_name\n"
break
fi
# run the test
echo "Running ${parametrization}"
python ${defaults} "${parametrization}"

# run the test
report+="Ran\t$file:$lineno::$test_name\n"
python ${defaults} "${file}::${test_name}"
break
fi
done < <(echo "$test_code")
report+="Ran\t$parametrization\n"
done

if nvcc --version; then
Expand Down
29 changes: 13 additions & 16 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,29 +1453,26 @@ def test_trainer_predict_cpu(tmpdir, datamodule, enable_progress_bar):


@RunIf(min_gpus=2, special=True)
@pytest.mark.parametrize("num_gpus", [1, 2])
def test_trainer_predict_dp(tmpdir, num_gpus):
predict(tmpdir, strategy="dp", accelerator="gpu", devices=num_gpus)


@RunIf(min_gpus=2, special=True, fairscale=True)
def test_trainer_predict_ddp(tmpdir):
predict(tmpdir, strategy="ddp", accelerator="gpu", devices=2)


@RunIf(min_gpus=2, skip_windows=True, special=True)
def test_trainer_predict_ddp_spawn(tmpdir):
predict(tmpdir, strategy="dp", accelerator="gpu", devices=2)
@pytest.mark.parametrize(
"kwargs",
[
{"strategy": "dp", "devices": 1},
{"strategy": "dp", "devices": 2},
{"strategy": "ddp", "devices": 2},
],
)
def test_trainer_predict_special(tmpdir, kwargs):
predict(tmpdir, accelerator="gpu", **kwargs)


@RunIf(min_gpus=1, special=True)
@RunIf(min_gpus=1)
def test_trainer_predict_1_gpu(tmpdir):
predict(tmpdir, accelerator="gpu", devices=1)


@RunIf(skip_windows=True)
def test_trainer_predict_ddp_cpu(tmpdir):
predict(tmpdir, strategy="ddp_spawn", accelerator="cpu", devices=2)
def test_trainer_predict_ddp_spawn(tmpdir):
predict(tmpdir, strategy="ddp_spawn", accelerator="auto", devices=2)


@pytest.mark.parametrize("dataset_cls", [RandomDataset, RandomIterableDatasetWithLen, RandomIterableDataset])
Expand Down

0 comments on commit 9a17cf9

Please sign in to comment.