Skip to content

Commit d9fef09

Browse files
committed
refactor for new dataset design
1 parent a349539 commit d9fef09

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

pathml/datasets/peso.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
import shutil
44
import sys
55
import numpy as np
6-
import torch
7-
import torch.utils.data as data
6+
from torch.utils.data import Dataset
87
from warnings import warn
98
from pathlib import Path
109
import cv2
1110

12-
from pathml.datasets.base import BaseSlideDataset, BaseDataModule
11+
from pathml.datasets.base_data_module import BaseDataModule
1312
from pathml.datasets.utils import download_from_url
1413
from pathml.preprocessing.transforms import TissueDetectionHE
1514
from pathml.preprocessing.pipeline import Pipeline
16-
from pathml.core.slide_classes import HESlide
15+
from pathml.core.slide_data import HESlide
1716
from pathml.core.masks import Masks
1817

1918
import cProfile
@@ -52,8 +51,8 @@ def _download_peso(self, download_dir):
5251
url = f'https://zenodo.org/record/1485967/files/'
5352
for file in files:
5453
print(f"downloading {file}")
55-
download_from_url(f"{url}{file}", f"{download_dir}")
56-
for root, _, files in os.walk(download_dir):
54+
download_from_url(f"{url}{file}", f"{download_dir}")
55+
for root, _, files in os.walk(download_dir):
5756
for file in files:
5857
print(f"unzipping {file}")
5958
if zipfile.is_zipfile(f"{root}/{file}"):
@@ -75,7 +74,7 @@ def _download_peso(self, download_dir):
7574
profile.enable()
7675

7776
name = '_'.join(file.split('_')[:-1])
78-
maskpath = Path(name+'_HE_training_mask.tif')
77+
maskpath = Path(name+'_HE_training_mask.tif')
7978
mask = HESlide(filepath = str(Path(download_dir)/Path('peso_training_masks')/maskpath), name = name)
8079
shape1, shape2 = mask.slide.get_image_shape()
8180
shape = (shape2, shape1)
@@ -93,7 +92,7 @@ def _download_peso(self, download_dir):
9392
])
9493
# TODO: choose tile size
9594
wsi.run(pipeline, tile_size=256)
96-
95+
9796
profile.disable()
9897
ps = pstats.Stats(profile)
9998
ps.print_stats()
@@ -111,7 +110,7 @@ def train_dataloader(self):
111110
"""
112111
Dataloader for training set.
113112
"""
114-
return data.DataLoader(
113+
return DataLoader(
115114
dataset = self._get_dataset(fold_ix = self.split),
116115
batch_size = self.batch_size,
117116
shuffle = self.shuffle
@@ -126,7 +125,7 @@ def valid_dataloader(self):
126125
fold_ix = 2
127126
else:
128127
fold_ix = 1
129-
return data.DataLoader(
128+
return DataLoader(
130129
self._get_dataset(fold_ix = fold_ix),
131130
batch_size = self.batch_size,
132131
shuffle = self.shuffle
@@ -141,20 +140,20 @@ def test_dataloader(self):
141140
fold_ix = 3
142141
else:
143142
fold_ix = 1
144-
return data.DataLoader(
143+
return DataLoader(
145144
self._get_dataset(fold_ix = fold_ix),
146145
batch_size = self.batch_size,
147146
shuffle = self.shuffle
148147
)
149148

150-
class PesoDataset(BaseSlideDataset):
149+
class PesoDataset(Dataset):
151150
"""
152151
Dataset object for Peso dataset.
153152
Raw data downloads:
154153
IHC color deconvolution (n=62) with p63, ck8/18 stainings
155-
training masks (n=62) generated by segmenting IHC with UNet
154+
training masks (n=62) generated by segmenting IHC with UNet
156155
training masks corrected (n=25) with manual annotations
157-
wsis (n=62) wsis at 0.48 \mu/pix
156+
wsis (n=62) at 0.48 \mu/pix
158157
159158
testing data:
160159
testset regions collection of xml files with outlines of test regions
@@ -178,14 +177,14 @@ def __init__(self,
178177
assert data_dir.isdir(), f"Error: data not found at {data_dir}"
179178

180179
if not any(fname.endswith('.h5') for fname in os.listdir(self.data_dir / Path('h5'))):
181-
raise Exception('must download dataset from pathml.datasets')
182-
180+
raise Exception('must download dataset from pathml.datasets')
181+
183182
getitemdict = {}
184183
items = 0
185184
for h5 in os.listdir(self.data_dir / Path('h5')):
186185
wsi = read(self.data_dir / Path('h5') / Path(file))
187186
for i in range(len(wsi.tile)):
188-
getitemdict[items] = (wsi.name, i)
187+
getitemdict[items] = (wsi.name, i)
189188
items = items + 1
190189
self.getitemdict = getitemdict
191190
self.wsi = None
@@ -197,5 +196,5 @@ def __getitem__(self, ix):
197196
wsiname, index = self.getitemdict[ix]
198197
if self.wsi is None or self.wsi.name != wsiname:
199198
self.wsi = read(self.data_dir / Path('h5') / Path(wsiname + '.h5'))
200-
tile = self.wsi.tiles[index]
199+
tile = self.wsi.tiles[index]
201200
return tile

0 commit comments

Comments
 (0)