1
1
from filelock import FileLock
2
2
import os
3
- import tempfile
4
3
5
4
import pytest
6
5
11
10
from torchvision import transforms , datasets
12
11
13
12
from ray .train import ScalingConfig
14
- from ray .air .constants import TRAINING_ITERATION
15
13
import ray .train as train
16
14
from ray .train .trainer import TrainingFailedError
17
15
@@ -37,19 +35,19 @@ def trainer_init_per_worker(config):
37
35
[transforms .ToTensor (), transforms .Normalize (mean , std )]
38
36
)
39
37
40
- data_directory = tempfile . mkdtemp ( prefix = "cifar_data " )
41
- with FileLock (os .path .join ( data_directory , " data.lock" )):
38
+ data_directory = os . path . expanduser ( "~/data " )
39
+ with FileLock (os .path .expanduser ( "~/ data.lock" )):
42
40
train_dataset = torch .utils .data .Subset (
43
41
datasets .CIFAR10 (
44
42
data_directory , train = True , download = True , transform = cifar10_transforms
45
43
),
46
- list (range (BATCH_SIZE * 10 )),
44
+ list (range (BATCH_SIZE )),
47
45
)
48
46
test_dataset = torch .utils .data .Subset (
49
47
datasets .CIFAR10 (
50
48
data_directory , train = False , download = True , transform = cifar10_transforms
51
49
),
52
- list (range (BATCH_SIZE * 10 )),
50
+ list (range (BATCH_SIZE )),
53
51
)
54
52
55
53
batch_size_per_worker = BATCH_SIZE // train .get_context ().get_world_size ()
@@ -88,22 +86,6 @@ def trainer_init_per_worker(config):
88
86
trainer_init_per_worker .__test__ = False
89
87
90
88
91
- def test_mosaic_cifar10 (ray_start_4_cpus ):
92
- from ray .train .examples .mosaic_cifar10_example import train_mosaic_cifar10
93
-
94
- result = train_mosaic_cifar10 (max_duration = "5ep" ).metrics_dataframe
95
-
96
- # check the max epoch value
97
- assert result ["epoch" ][result .index [- 1 ]] == 4
98
-
99
- # check train_iterations
100
- assert result [TRAINING_ITERATION ][result .index [- 1 ]] == 5
101
-
102
- # check metrics/train/Accuracy has increased
103
- acc = list (result ["metrics/train/Accuracy" ])
104
- assert acc [- 1 ] > acc [0 ]
105
-
106
-
107
89
def test_init_errors (ray_start_4_cpus ):
108
90
from ray .train .mosaic import MosaicTrainer
109
91
0 commit comments