Skip to content

Commit d6e9c19

Browse files
EtyltomMoral
andauthored
ENH download the BSD500 dataset from deepinv (#9)
Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>
1 parent 5f9fd8d commit d6e9c19

File tree

7 files changed

+22
-175
lines changed

7 files changed

+22
-175
lines changed

benchmark_utils/image_dataset.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

config.yml

Lines changed: 0 additions & 5 deletions
This file was deleted.

datasets/bsd500_bsd20.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

datasets/bsd500_cbsd68.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from benchopt import BaseDataset, safe_import_context, config
1+
from benchopt import BaseDataset, safe_import_context
2+
from benchopt.config import get_data_path
23

34
with safe_import_context() as import_ctx:
45
import deepinv as dinv
56
import torch
67
from torchvision import transforms
78
from datasets import load_dataset
8-
from benchmark_utils.image_dataset import ImageDataset
99
from benchmark_utils.hugging_face_torch_dataset import (
1010
HuggingFaceTorchDataset
1111
)
@@ -77,8 +77,9 @@ def get_data(self):
7777
transforms.ToTensor()
7878
])
7979

80-
train_dataset = ImageDataset(
81-
config.get_data_path("BSD500") / "train", transform=transform
80+
path = get_data_path("BSD500")
81+
train_dataset = dinv.datasets.BSDS500(
82+
path, download=True, transform=transform
8283
)
8384

8485
dataset_cbsd68 = load_dataset("deepinv/CBSD68")
@@ -90,9 +91,7 @@ def get_data(self):
9091
train_dataset=train_dataset,
9192
test_dataset=test_dataset,
9293
physics=physics,
93-
save_dir=config.get_data_path(
94-
key="generated_datasets"
95-
) / "bsd500_cbsd68",
94+
save_dir=get_data_path("bsd500_cbsd68"),
9695
dataset_filename=self.task,
9796
device=device
9897
)

datasets/bsd500_imnet100.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from benchopt import BaseDataset, safe_import_context, config
1+
from benchopt import BaseDataset, safe_import_context
2+
from benchopt.config import get_data_path
23

34
with safe_import_context() as import_ctx:
45
import deepinv as dinv
56
import torch
67
from torchvision import transforms
7-
from benchmark_utils.image_dataset import ImageDataset
88
from benchmark_utils.hugging_face_torch_dataset import (
99
HuggingFaceTorchDataset
1010
)
@@ -77,9 +77,9 @@ def get_data(self):
7777
transforms.ToTensor()
7878
])
7979

80-
train_dataset = ImageDataset(
81-
config.get_data_path("BSD500") / "train",
82-
transform=transform
80+
path = get_data_path("BSD500")
81+
train_dataset = dinv.datasets.BSDS500(
82+
path, download=True, transform=transform
8383
)
8484

8585
dataset_miniImnet100 = load_dataset("mterris/miniImnet100")
@@ -93,9 +93,7 @@ def get_data(self):
9393
train_dataset=train_dataset,
9494
test_dataset=test_dataset,
9595
physics=physics,
96-
save_dir=config.get_data_path(
97-
key="generated_datasets"
98-
) / "bsd500_imnet100",
96+
save_dir=get_data_path("bsd500_imnet100"),
9997
dataset_filename=self.task,
10098
device=device
10199
)

datasets/cbsd68_set3c.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from benchopt import BaseDataset, safe_import_context, config
1+
from benchopt import BaseDataset, safe_import_context
2+
from benchopt.config import get_data_path
23

34
with safe_import_context() as import_ctx:
45
import deepinv as dinv
@@ -91,9 +92,7 @@ def get_data(self):
9192
train_dataset=train_dataset,
9293
test_dataset=test_dataset,
9394
physics=physics,
94-
save_dir=config.get_data_path(
95-
key="generated_datasets"
96-
) / "sbsd68_set3c",
95+
save_dir=get_data_path("cbsd68_set3c"),
9796
dataset_filename=self.task,
9897
device=device
9998
)

test_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,10 @@ def check_test_solver_install(solver_class):
1111
detecting the situation.
1212
"""
1313
pass
14+
15+
16+
def check_test_dataset_get_data(benchmark, dataset_class):
17+
if sys.platform == "darwin":
18+
pytest.skip(
19+
"Skipping test_dataset_get_data on MacOS."
20+
)

0 commit comments

Comments
 (0)