@@ -39,30 +39,31 @@ def test_rendering(mock_sintel_data, tmp_path_factory):
39
39
40
40
41
41
@pytest .fixture (scope = "module" )
42
- def dataset (mock_sintel_data ):
43
- return MultiTaskSintel (
44
- tasks = ["flow" ],
45
- boxfilter = dict (extent = 1 , kernel_size = 13 ),
46
- vertical_splits = 3 ,
47
- n_frames = 4 ,
48
- center_crop_fraction = 0.7 ,
49
- dt = 1 / 24 ,
50
- augment = True ,
51
- random_temporal_crop = True ,
52
- all_frames = False ,
53
- resampling = True ,
54
- interpolate = True ,
55
- p_flip = 0.5 ,
56
- p_rot = 5 / 6 ,
57
- contrast_std = 0.2 ,
58
- brightness_std = 0.1 ,
59
- gaussian_white_noise = 0.08 ,
60
- gamma_std = None ,
61
- _init_cache = True ,
62
- unittest = True ,
63
- flip_axes = [0 , 1 , 2 , 3 ],
64
- sintel_path = mock_sintel_data ,
65
- )
42
+ def dataset (mock_sintel_data , tmp_path_factory ):
43
+ with set_root_context (tmp_path_factory .mktemp ("tmp" )):
44
+ return MultiTaskSintel (
45
+ tasks = ["flow" ],
46
+ boxfilter = dict (extent = 1 , kernel_size = 13 ),
47
+ vertical_splits = 3 ,
48
+ n_frames = 4 ,
49
+ center_crop_fraction = 0.7 ,
50
+ dt = 1 / 24 ,
51
+ augment = True ,
52
+ random_temporal_crop = True ,
53
+ all_frames = False ,
54
+ resampling = True ,
55
+ interpolate = True ,
56
+ p_flip = 0.5 ,
57
+ p_rot = 5 / 6 ,
58
+ contrast_std = 0.2 ,
59
+ brightness_std = 0.1 ,
60
+ gaussian_white_noise = 0.08 ,
61
+ gamma_std = None ,
62
+ _init_cache = True ,
63
+ unittest = True ,
64
+ flip_axes = [0 , 1 , 2 , 3 ],
65
+ sintel_path = mock_sintel_data ,
66
+ )
66
67
67
68
68
69
@pytest .fixture (
@@ -81,25 +82,26 @@ def tasks(request):
81
82
return request .param
82
83
83
84
84
- def test_init (tasks , mock_sintel_data ):
85
- dataset = MultiTaskSintel (
86
- tasks = tasks , n_frames = 4 , unittest = True , sintel_path = mock_sintel_data
87
- )
88
- assert hasattr (dataset , "tasks" )
89
- assert "lum" in dataset .data_keys
90
- assert hasattr (dataset , "config" )
91
- assert hasattr (dataset , "meta" )
92
- assert hasattr (dataset , "cached_sequences" )
93
- assert hasattr (dataset , "arg_df" )
94
- assert set (dataset [0 ].keys ()) == set (["lum" , * tasks ])
95
- dataset = MultiTaskSintel (
96
- tasks = tasks ,
97
- n_frames = 4 ,
98
- unittest = True ,
99
- _init_cache = False ,
100
- sintel_path = mock_sintel_data ,
101
- )
102
- assert not hasattr (dataset , "cached_sequences" )
85
+ def test_init (tasks , mock_sintel_data , tmp_path_factory ):
86
+ with set_root_context (tmp_path_factory .mktemp ("tmp" )):
87
+ dataset = MultiTaskSintel (
88
+ tasks = tasks , n_frames = 4 , unittest = True , sintel_path = mock_sintel_data
89
+ )
90
+ assert hasattr (dataset , "tasks" )
91
+ assert "lum" in dataset .data_keys
92
+ assert hasattr (dataset , "config" )
93
+ assert hasattr (dataset , "meta" )
94
+ assert hasattr (dataset , "cached_sequences" )
95
+ assert hasattr (dataset , "arg_df" )
96
+ assert set (dataset [0 ].keys ()) == set (["lum" , * tasks ])
97
+ dataset = MultiTaskSintel (
98
+ tasks = tasks ,
99
+ n_frames = 4 ,
100
+ unittest = True ,
101
+ _init_cache = False ,
102
+ sintel_path = mock_sintel_data ,
103
+ )
104
+ assert not hasattr (dataset , "cached_sequences" )
103
105
104
106
105
107
def test_init_augmentation (dataset ):
0 commit comments