3
3
import shutil
4
4
import sys
5
5
import numpy as np
6
- import torch
7
- import torch .utils .data as data
6
+ from torch .utils .data import Dataset
8
7
from warnings import warn
9
8
from pathlib import Path
10
9
import cv2
11
10
12
- from pathml .datasets .base import BaseSlideDataset , BaseDataModule
11
+ from pathml .datasets .base_data_module import BaseDataModule
13
12
from pathml .datasets .utils import download_from_url
14
13
from pathml .preprocessing .transforms import TissueDetectionHE
15
14
from pathml .preprocessing .pipeline import Pipeline
16
- from pathml .core .slide_classes import HESlide
15
+ from pathml .core .slide_data import HESlide
17
16
from pathml .core .masks import Masks
18
17
19
18
import cProfile
@@ -52,8 +51,8 @@ def _download_peso(self, download_dir):
52
51
url = f'https://zenodo.org/record/1485967/files/'
53
52
for file in files :
54
53
print (f"downloading { file } " )
55
- download_from_url (f"{ url } { file } " , f"{ download_dir } " )
56
- for root , _ , files in os .walk (download_dir ):
54
+ download_from_url (f"{ url } { file } " , f"{ download_dir } " )
55
+ for root , _ , files in os .walk (download_dir ):
57
56
for file in files :
58
57
print (f"unzipping { file } " )
59
58
if zipfile .is_zipfile (f"{ root } /{ file } " ):
@@ -75,7 +74,7 @@ def _download_peso(self, download_dir):
75
74
profile .enable ()
76
75
77
76
name = '_' .join (file .split ('_' )[:- 1 ])
78
- maskpath = Path (name + '_HE_training_mask.tif' )
77
+ maskpath = Path (name + '_HE_training_mask.tif' )
79
78
mask = HESlide (filepath = str (Path (download_dir )/ Path ('peso_training_masks' )/ maskpath ), name = name )
80
79
shape1 , shape2 = mask .slide .get_image_shape ()
81
80
shape = (shape2 , shape1 )
@@ -93,7 +92,7 @@ def _download_peso(self, download_dir):
93
92
])
94
93
# TODO: choose tile size
95
94
wsi .run (pipeline , tile_size = 256 )
96
-
95
+
97
96
profile .disable ()
98
97
ps = pstats .Stats (profile )
99
98
ps .print_stats ()
@@ -111,7 +110,7 @@ def train_dataloader(self):
111
110
"""
112
111
Dataloader for training set.
113
112
"""
114
- return data . DataLoader (
113
+ return DataLoader (
115
114
dataset = self ._get_dataset (fold_ix = self .split ),
116
115
batch_size = self .batch_size ,
117
116
shuffle = self .shuffle
@@ -126,7 +125,7 @@ def valid_dataloader(self):
126
125
fold_ix = 2
127
126
else :
128
127
fold_ix = 1
129
- return data . DataLoader (
128
+ return DataLoader (
130
129
self ._get_dataset (fold_ix = fold_ix ),
131
130
batch_size = self .batch_size ,
132
131
shuffle = self .shuffle
@@ -141,20 +140,20 @@ def test_dataloader(self):
141
140
fold_ix = 3
142
141
else :
143
142
fold_ix = 1
144
- return data . DataLoader (
143
+ return DataLoader (
145
144
self ._get_dataset (fold_ix = fold_ix ),
146
145
batch_size = self .batch_size ,
147
146
shuffle = self .shuffle
148
147
)
149
148
150
- class PesoDataset (BaseSlideDataset ):
149
+ class PesoDataset (Dataset ):
151
150
"""
152
151
Dataset object for Peso dataset.
153
152
Raw data downloads:
154
153
IHC color deconvolution (n=62) with p63, ck8/18 stainings
155
- training masks (n=62) generated by segmenting IHC with UNet
154
+ training masks (n=62) generated by segmenting IHC with UNet
156
155
training masks corrected (n=25) with manual annotations
157
- wsis (n=62) wsis at 0.48 \mu/pix
156
+ wsis (n=62) at 0.48 \mu/pix
158
157
159
158
testing data:
160
159
testset regions collection of xml files with outlines of test regions
@@ -178,14 +177,14 @@ def __init__(self,
178
177
assert data_dir .isdir (), f"Error: data not found at { data_dir } "
179
178
180
179
if not any (fname .endswith ('.h5' ) for fname in os .listdir (self .data_dir / Path ('h5' ))):
181
- raise Exception ('must download dataset from pathml.datasets' )
182
-
180
+ raise Exception ('must download dataset from pathml.datasets' )
181
+
183
182
getitemdict = {}
184
183
items = 0
185
184
for h5 in os .listdir (self .data_dir / Path ('h5' )):
186
185
wsi = read (self .data_dir / Path ('h5' ) / Path (file ))
187
186
for i in range (len (wsi .tile )):
188
- getitemdict [items ] = (wsi .name , i )
187
+ getitemdict [items ] = (wsi .name , i )
189
188
items = items + 1
190
189
self .getitemdict = getitemdict
191
190
self .wsi = None
@@ -197,5 +196,5 @@ def __getitem__(self, ix):
197
196
wsiname , index = self .getitemdict [ix ]
198
197
if self .wsi is None or self .wsi .name != wsiname :
199
198
self .wsi = read (self .data_dir / Path ('h5' ) / Path (wsiname + '.h5' ))
200
- tile = self .wsi .tiles [index ]
199
+ tile = self .wsi .tiles [index ]
201
200
return tile
0 commit comments