Skip to content

Commit 0562409

Browse files
authored
[train] Fix MosaicTrainer example and unit test (ray-project#38970)
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
1 parent ed7186a commit 0562409

File tree

2 files changed

+6
-25
lines changed

2 files changed

+6
-25
lines changed

python/ray/train/examples/mosaic_cifar10_example.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import argparse
22
from filelock import FileLock
33
import os
4-
import tempfile
54

65
import torch
76
import torch.utils.data
@@ -31,8 +30,8 @@ def trainer_init_per_worker(config):
3130
[transforms.ToTensor(), transforms.Normalize(mean, std)]
3231
)
3332

34-
data_directory = tempfile.mkdtemp(prefix="cifar_data")
35-
with FileLock(os.path.join(data_directory, "data.lock")):
33+
data_directory = os.path.expanduser("~/data")
34+
with FileLock(os.path.expanduser("~/data.lock")):
3635
train_dataset = torch.utils.data.Subset(
3736
datasets.CIFAR10(
3837
data_directory, train=True, download=True, transform=cifar10_transforms

python/ray/train/tests/test_mosaic_trainer.py

+4 -22
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from filelock import FileLock
22
import os
3-
import tempfile
43

54
import pytest
65

@@ -11,7 +10,6 @@
1110
from torchvision import transforms, datasets
1211

1312
from ray.train import ScalingConfig
14-
from ray.air.constants import TRAINING_ITERATION
1513
import ray.train as train
1614
from ray.train.trainer import TrainingFailedError
1715

@@ -37,19 +35,19 @@ def trainer_init_per_worker(config):
3735
[transforms.ToTensor(), transforms.Normalize(mean, std)]
3836
)
3937

40-
data_directory = tempfile.mkdtemp(prefix="cifar_data")
41-
with FileLock(os.path.join(data_directory, "data.lock")):
38+
data_directory = os.path.expanduser("~/data")
39+
with FileLock(os.path.expanduser("~/data.lock")):
4240
train_dataset = torch.utils.data.Subset(
4341
datasets.CIFAR10(
4442
data_directory, train=True, download=True, transform=cifar10_transforms
4543
),
46-
list(range(BATCH_SIZE * 10)),
44+
list(range(BATCH_SIZE)),
4745
)
4846
test_dataset = torch.utils.data.Subset(
4947
datasets.CIFAR10(
5048
data_directory, train=False, download=True, transform=cifar10_transforms
5149
),
52-
list(range(BATCH_SIZE * 10)),
50+
list(range(BATCH_SIZE)),
5351
)
5452

5553
batch_size_per_worker = BATCH_SIZE // train.get_context().get_world_size()
@@ -88,22 +86,6 @@ def trainer_init_per_worker(config):
8886
trainer_init_per_worker.__test__ = False
8987

9088

91-
def test_mosaic_cifar10(ray_start_4_cpus):
92-
from ray.train.examples.mosaic_cifar10_example import train_mosaic_cifar10
93-
94-
result = train_mosaic_cifar10(max_duration="5ep").metrics_dataframe
95-
96-
# check the max epoch value
97-
assert result["epoch"][result.index[-1]] == 4
98-
99-
# check train_iterations
100-
assert result[TRAINING_ITERATION][result.index[-1]] == 5
101-
102-
# check metrics/train/Accuracy has increased
103-
acc = list(result["metrics/train/Accuracy"])
104-
assert acc[-1] > acc[0]
105-
106-
10789
def test_init_errors(ray_start_4_cpus):
10890
from ray.train.mosaic import MosaicTrainer
10991

0 commit comments

Comments
 (0)