Skip to content

Commit 2ba36e7

Browse files
deependujha and Borda authored
fix: enable parallel test execution with pytest-xdist in CI workflow (#620)
* fix: enable parallel test execution with pytest-xdist in CI workflow * temporary fix to handle parallelly running tests in ci * update * update * update * update * 7 pm * pytest-xdist ==3.4.0 * fix tmp path on windows * add fixture for unique HF URL to support parallel test runs * update * update * increase timeout of 60s to 90s * bump pytest & pytest-xdist * rerun failing tests twice * refactor: update pytest command and adjust fixture scopes for better test isolation * update * update * update * update * update * update * update * update * update * let's just wait * Update tests/streaming/test_dataloader.py * update * update * Apply suggestions from code review * Update src/litdata/streaming/resolver.py * update * update * Update .github/workflows/ci-testing.yml * update * let's try running all tests in parallel * update * update * let's run tests in groups * update * tests pass * update --------- Co-authored-by: Jirka B <j.borovec+github@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent 5ac0863 commit 2ba36e7

File tree

6 files changed

+26
-18
lines changed

6 files changed

+26
-18
lines changed

.github/workflows/ci-testing.yml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,16 @@ jobs:
4545
uv pip install -e ".[extras]" -r requirements/test.txt -U -q
4646
uv pip list
4747
48-
- name: Tests
49-
working-directory: tests
50-
run: pytest . -v --cov=litdata --durations=100
48+
- name: Run fast tests in parallel
49+
run: |
50+
pytest \
51+
tests/streaming tests/utilities \
52+
tests/test_cli.py tests/test_debugger.py \
53+
-n 2 --cov=litdata --cov-append --cov-report= --durations=120
54+
55+
- name: Run processing tests sequentially
56+
run: |
57+
pytest tests/processing tests/raw --cov=litdata --cov-append --cov-report= --durations=90
5158
5259
- name: Statistics
5360
continue-on-error: true

requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pytest-cov ==6.2.1
99
pytest-timeout ==2.4.0
1010
pytest-rerunfailures ==15.1
1111
pytest-random-order ==1.1.1
12+
pytest-xdist >=3.8.0
1213
pandas
1314
pyarrow >=20.0.0
1415
polars >1.0.0

tests/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import shutil
3+
import signal
34
import sys
45
import threading
56
from collections import OrderedDict
@@ -16,7 +17,7 @@
1617
from litdata.utilities.dataset_utilities import get_default_cache_dir
1718

1819

19-
@pytest.fixture(autouse=True)
20+
@pytest.fixture(autouse=True, scope="session")
2021
def teardown_process_group():
2122
"""Ensures distributed process group gets closed before the next test runs."""
2223
yield
@@ -25,9 +26,8 @@ def teardown_process_group():
2526

2627

2728
@pytest.fixture(autouse=True)
28-
def set_env():
29-
# Set environment variable before each test to configure BaseWorker's maximum wait time
30-
os.environ["DATA_OPTIMIZER_TIMEOUT"] = "20"
29+
def disable_signals(monkeypatch):
30+
monkeypatch.setattr(signal, "signal", lambda *args, **kwargs: None)
3131

3232

3333
@pytest.fixture
@@ -132,7 +132,7 @@ def lightning_sdk_mock(monkeypatch):
132132
return lightning_sdk
133133

134134

135-
@pytest.fixture(autouse=True)
135+
@pytest.fixture(autouse=True, scope="session")
136136
def _thread_police():
137137
"""Attempts stopping left-over threads to avoid test interactions.
138138

tests/streaming/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
319319
assert count >= 25, "There should be at least 25 batches in the third epoch"
320320

321321

322-
@pytest.mark.timeout(60)
322+
@pytest.mark.timeout(90)
323323
def test_resume_dataloader_with_new_dataset(tmpdir):
324324
dataset_1_path = tmpdir.join("dataset_1")
325325
dataset_2_path = tmpdir.join("dataset_2")

tests/streaming/test_dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression
311311
pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
312312
],
313313
)
314-
@pytest.mark.timeout(60)
314+
@pytest.mark.timeout(90)
315315
def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compression):
316316
seed_everything(42)
317317

@@ -364,7 +364,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compr
364364
),
365365
],
366366
)
367-
@pytest.mark.timeout(60)
367+
@pytest.mark.timeout(90)
368368
def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, compression):
369369
seed_everything(42)
370370

@@ -412,7 +412,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, comp
412412
pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
413413
],
414414
)
415-
@pytest.mark.timeout(60)
415+
@pytest.mark.timeout(90)
416416
def test_streaming_dataset_distributed_full_shuffle_even_multi_nodes(drop_last, tmpdir, compression):
417417
seed_everything(42)
418418

@@ -685,7 +685,7 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir):
685685
assert result == expected
686686

687687

688-
@pytest.mark.timeout(60)
688+
@pytest.mark.timeout(90)
689689
def test_dataset_for_text_tokens_with_large_block_size_multiple_workers(tmpdir):
690690
# test to reproduce ERROR: Unexpected segmentation fault encountered in worker
691691
seed_everything(42)
@@ -1077,7 +1077,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
10771077
assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
10781078

10791079

1080-
@pytest.mark.timeout(60)
1080+
@pytest.mark.timeout(90)
10811081
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
10821082
def test_dataset_valid_state(tmpdir, monkeypatch):
10831083
seed_everything(42)
@@ -1213,7 +1213,7 @@ def fn(remote_chunkpath: str, local_chunkpath: str):
12131213
dataset._validate_state_dict()
12141214

12151215

1216-
@pytest.mark.timeout(60)
1216+
@pytest.mark.timeout(90)
12171217
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
12181218
def test_dataset_valid_state_override(tmpdir, monkeypatch):
12191219
seed_everything(42)

tests/utilities/test_env.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33

44
def test_distributed_env_from_env(monkeypatch):
5-
monkeypatch.setenv("WORLD_SIZE", 2)
6-
monkeypatch.setenv("GLOBAL_RANK", 1)
7-
monkeypatch.setenv("NNODES", 2)
5+
monkeypatch.setenv("WORLD_SIZE", "2")
6+
monkeypatch.setenv("GLOBAL_RANK", "1")
7+
monkeypatch.setenv("NNODES", "2")
88

99
dist_env = _DistributedEnv.detect()
1010
assert dist_env.world_size == 2

0 commit comments

Comments (0)