Skip to content

Commit 82bf020

Browse files
authored
Refactor dataset preparation fixture to avoid redundancy and limit test parametrization to reduce time (#632)
* refactor combined_dataset fixture for reusable dataset preparation * remove repeated tests
1 parent 1a3c1c1 commit 82bf020

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

tests/conftest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,24 @@ def mosaic_mds_index_data():
5252
}
5353

5454

55-
@pytest.fixture
56-
def combined_dataset(tmpdir_factory):
57-
tmpdir = tmpdir_factory.mktemp("data")
55+
@pytest.fixture(scope="session")
56+
def prepare_combined_dataset(tmpdir_factory):
57+
tmpdir = tmpdir_factory.mktemp("combined_dataset")
5858
datasets = [str(tmpdir.join(f"dataset_{i}")) for i in range(2)]
5959
for dataset in datasets:
6060
cache = Cache(input_dir=dataset, chunk_bytes="64MB")
6161
for i in range(50):
6262
cache[i] = i
6363
cache.done()
6464
cache.merge()
65+
return datasets
6566

66-
dataset_1 = StreamingDataset(datasets[0], shuffle=True)
67-
dataset_2 = StreamingDataset(datasets[1], shuffle=True)
67+
68+
@pytest.fixture
69+
def combined_dataset(prepare_combined_dataset):
70+
dataset_1_path, dataset_2_path = prepare_combined_dataset
71+
dataset_1 = StreamingDataset(dataset_1_path)
72+
dataset_2 = StreamingDataset(dataset_2_path)
6873
return CombinedStreamingDataset(datasets=[dataset_1, dataset_2])
6974

7075

tests/streaming/test_combined.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,6 @@ def test_combined_dataset_dataloader_states_without_any_iterations(combined_data
537537
@pytest.mark.timeout(120)
538538
@pytest.mark.parametrize("num_workers", [0, 2, 4])
539539
def test_combined_dataset_dataloader_states_complete_iterations(combined_dataset, num_workers):
540-
print(f"Testing with num_workers={num_workers}")
541540
dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
542541
assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"
543542

@@ -559,16 +558,13 @@ def test_combined_dataset_dataloader_states_complete_iterations(combined_dataset
559558

560559

561560
@pytest.mark.timeout(300)
562-
@pytest.mark.parametrize(("num_workers", "break_at"), [(0, 10), (0, 15), (2, 10), (2, 15), (4, 10), (4, 15)])
561+
@pytest.mark.parametrize(("num_workers", "break_at"), [(0, 10), (0, 15), (2, 15), (4, 15)])
563562
def test_combined_dataset_dataloader_states_partial_iterations(combined_dataset, num_workers, break_at):
564-
print(f"Testing with num_workers={num_workers}, break_at={break_at}")
565-
566563
# Verify dataloader state after partial last iteration
567564
dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
568565

569566
total_batches = len(dataloader)
570567
assert total_batches == 25, "Dataloader length should be 25 (100 items / batch size 4)"
571-
572568
assert not dataloader.restore, "Dataloader should not be in restore state initially."
573569

574570
# Partial iteration up to 'break_at'

0 commit comments

Comments (0)