diff --git a/test/renate/data/test_data_module.py b/test/renate/data/test_data_module.py index aa6bddf2..0b76bc1e 100644 --- a/test/renate/data/test_data_module.py +++ b/test/renate/data/test_data_module.py @@ -91,15 +91,15 @@ def test_torchvision_data_module(tmpdir, dataset_name, num_tr, num_te, x_shape): @pytest.mark.parametrize( "dataset_name,chunk_id,num_tr,num_te", [ - ("CLEAR10", 0, 2986, 500), - ("CLEAR100", 0, 9945, 4984), + ("CLEAR10", 0, 3300, 550), + ("CLEAR100", 0, 9964, 5000), ], ) def test_clear_data_module(tmpdir, dataset_name, chunk_id, num_tr, num_te): """Test loading of CLEAR data.""" val_size = 0.2 data_module = CLEARDataModule( - tmpdir, dataset_name=dataset_name, chunk_id=chunk_id, val_size=val_size + tmpdir, dataset_name=dataset_name, time_step=chunk_id, val_size=val_size ) data_module.prepare_data() data_module.setup() @@ -131,7 +131,7 @@ def test_tiny_imagenet_data_module(tmpdir): assert isinstance(val_data, Dataset) assert len(test_data) == num_te assert isinstance(test_data, Dataset) - assert train_data[0][0].size() == test_data[0][0].size() == (3, 64, 64) + assert train_data[0][0].size == test_data[0][0].size == (64, 64) @pytest.mark.slow