torch_test.py
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from histoprep import SlideReader
from histoprep.utils import (
    SlideReaderDataset,
    TileImageDataset,
)

from ._utils import (
    SLIDE_PATH_CZI,
    SLIDE_PATH_JPEG,
    SLIDE_PATH_SVS,
    TMP_DIRECTORY,
    clean_temporary_directory,
)


def test_posix_paths() -> None:
    """Check that `TileImageDataset` accepts `pathlib.Path` objects as tile paths."""
    clean_temporary_directory()
    reader = SlideReader(SLIDE_PATH_JPEG)
    metadata = reader.save_regions(TMP_DIRECTORY, reader.get_tile_coordinates(None, 96))
    dataset = TileImageDataset(
        paths=[Path(x) for x in metadata["path"].to_list()],
        labels=metadata[list("xywh")].to_numpy(),
        transform=lambda x: x[..., 0],
    )
    # A single batch is enough to verify that path-based loading does not raise.
    next(iter(DataLoader(dataset, batch_size=32)))


def test_reader_dataset_loader_pil() -> None:
    """Load tiles from a JPEG slide through `SlideReaderDataset` and `DataLoader`."""
    reader = SlideReader(SLIDE_PATH_JPEG)
    __, tissue_mask = reader.get_tissue_mask()
    coords = reader.get_tile_coordinates(tissue_mask, 512, max_background=0.01)
    dataset = SlideReaderDataset(reader, coords, level=1, transform=lambda z: z)
    assert isinstance(dataset, Dataset)
    loader = DataLoader(dataset, batch_size=4, num_workers=10, drop_last=True)
    for i, (batch_images, batch_coords) in enumerate(loader):
        assert batch_images.shape == (4, 256, 256, 3)
        assert isinstance(batch_images, torch.Tensor)
        assert batch_coords.shape == (4, 4)
        assert isinstance(batch_coords, torch.Tensor)
        if i > 20:
            break


def test_reader_dataset_loader_openslide() -> None:
    """Load tiles from an SVS slide through `SlideReaderDataset` and `DataLoader`."""
    reader = SlideReader(SLIDE_PATH_SVS)
    __, tissue_mask = reader.get_tissue_mask()
    coords = reader.get_tile_coordinates(tissue_mask, 512, max_background=0.01)
    dataset = SlideReaderDataset(reader, coords, level=1, transform=lambda z: z)
    assert isinstance(dataset, Dataset)
    loader = DataLoader(dataset, batch_size=4, num_workers=10, drop_last=True)
    for i, (batch_images, batch_coords) in enumerate(loader):
        assert batch_images.shape == (4, 256, 256, 3)
        assert isinstance(batch_images, torch.Tensor)
        assert batch_coords.shape == (4, 4)
        assert isinstance(batch_coords, torch.Tensor)
        if i > 20:
            break


def test_reader_dataset_loader_czi() -> None:
    """Load tiles from a CZI slide through `SlideReaderDataset` and `DataLoader`."""
    reader = SlideReader(SLIDE_PATH_CZI)
    __, tissue_mask = reader.get_tissue_mask()
    coords = reader.get_tile_coordinates(tissue_mask, 512, max_background=0.01)
    # CZI reading fails if multiple workers read through the same reader instance.
    # That should not happen here, because each `DataLoader` worker process gets
    # its own copy of the `Dataset` (see the sketch after this test).
    dataset = SlideReaderDataset(reader, coords, level=1, transform=lambda z: z)
    assert isinstance(dataset, Dataset)
    loader = DataLoader(dataset, batch_size=4, num_workers=10, drop_last=True)
    for i, (batch_images, batch_coords) in enumerate(loader):
        assert batch_images.shape == (4, 256, 256, 3)
        assert isinstance(batch_images, torch.Tensor)
        assert batch_coords.shape == (4, 4)
        assert isinstance(batch_coords, torch.Tensor)
        if i > 20:
            break
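

# Hypothetical sketch, not part of the original test suite: illustrates why the
# shared CZI reader above is safe with multiple workers. With `num_workers > 0`,
# every `DataLoader` worker runs in its own process with its own copy of the
# `Dataset`, so tiles are never read through an instance shared with the parent.
class _PidDataset(Dataset):
    """Illustrative helper dataset that reports which process served each item."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        # Local import keeps the sketch self-contained.
        import os

        return os.getpid()


def test_dataloader_workers_use_own_process() -> None:
    import os

    loader = DataLoader(_PidDataset(), batch_size=4, num_workers=2)
    worker_pids = {pid for batch in loader for pid in batch.tolist()}
    # Every item was produced in a worker process, never in the parent process,
    # so each worker necessarily read from its own copy of the dataset.
    assert os.getpid() not in worker_pids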


def test_tile_dataset_loader() -> None:
    """Check batch shapes when loading saved tiles with `TileImageDataset`."""
    clean_temporary_directory()
    reader = SlideReader(SLIDE_PATH_JPEG)
    metadata = reader.save_regions(TMP_DIRECTORY, reader.get_tile_coordinates(None, 96))
    dataset = TileImageDataset(
        metadata["path"].to_numpy(),
        labels=metadata[list("xywh")].to_numpy(),
        transform=lambda x: x[..., 0],
    )
    batch_images, batch_paths, batch_coords = next(
        iter(DataLoader(dataset, batch_size=32))
    )
    clean_temporary_directory()
    assert batch_images.shape == (32, 96, 96)
    assert len(batch_paths) == 32
    assert batch_coords.shape == (32, 4)


def test_tile_dataset_cache() -> None:
    """Check that the in-memory tile cache is filled and matches the loaded batch."""
    clean_temporary_directory()
    reader = SlideReader(SLIDE_PATH_JPEG)
    metadata = reader.save_regions(TMP_DIRECTORY, reader.get_tile_coordinates(None, 96))
    dataset = TileImageDataset(
        metadata["path"].to_numpy(),
        labels=metadata[list("xywh")].to_numpy(),
        transform=lambda x: x[..., 0],
        use_cache=True,
        tile_shape=(96, 96, 3),
    )
    batch_images, batch_paths, batch_coords = next(
        iter(DataLoader(dataset, batch_size=32))
    )
    clean_temporary_directory()
    assert batch_images.shape == (32, 96, 96)
    assert len(batch_paths) == 32
    assert batch_coords.shape == (32, 4)
    # The cache stores full (96, 96, 3) tiles, so compare channel 0 to match the
    # channel-dropping transform applied to the loaded batch.
    assert dataset._cached_indices == set(range(32))
    assert np.equal(dataset._cache_array[0][..., 0], batch_images[0].numpy()).all()