12
12
import numpy as np
13
13
from PIL import Image
14
14
15
- from histoprep ._reader import SlideReader
15
+ from histoprep .reader import SlideReader
16
16
17
17
try :
18
- from torch .utils .data import Dataset , IterableDataset
18
+ from torch .utils .data import Dataset
19
19
20
20
HAS_PYTORCH = True
21
21
except ImportError :
22
22
HAS_PYTORCH = False
23
23
Dataset = object
24
- IterableDataset = object
25
24
26
25
ERROR_PYTORCH = "Could not import torch, make sure it has been installed!"
27
26
ERROR_LENGTH_MISMATCH = "Path length ({}) does not match label length ({})."
28
27
ERROR_TILE_SHAPE = "Tile shape must be defined to create a cache array."
29
28
30
29
31
30
class SlideReaderDataset (Dataset ):
32
- """Torch dataset yielding tile images from reader (requires `PyTorch`).
33
-
34
- Args:
35
- reader: `SlideReader` instance.
36
- coordinates: Iterator of xywh-coordinates.
37
- level: Slide level for reading tile image. Defaults to 0.
38
- transform: Transform function for tile images. Defaults to None.
39
- """
31
+ """Torch dataset yielding tile images from reader (requires `PyTorch`)."""
40
32
41
33
def __init__ (
42
34
self ,
@@ -45,6 +37,17 @@ def __init__(
45
37
level : int = 0 ,
46
38
transform : Callable [[np .ndarray ], Any ] | None = None ,
47
39
) -> None :
40
+ """Initialize SlideReaderDataset.
41
+
42
+ Args:
43
+ reader: `SlideReader` instance.
44
+ coordinates: Iterator of xywh-coordinates.
45
+ level: Slide level for reading tile image. Defaults to 0.
46
+ transform: Transform function for tile images. Defaults to None.
47
+
48
+ Raises:
49
+ ImportError: Could not import `PyTorch`.
50
+ """
48
51
if not HAS_PYTORCH :
49
52
raise ImportError (ERROR_PYTORCH )
50
53
super ().__init__ ()
@@ -65,16 +68,7 @@ def __getitem__(self, index: int) -> tuple[np.ndarray | Any, np.ndarray]:
65
68
66
69
67
70
class TileImageDataset (Dataset ):
68
- """Torch dataset yielding tile images from paths (requires `PyTorch`).
69
-
70
- Args:
71
- paths: Paths to tile images.
72
- labels: Indexable list of labels for each path. Defaults to None.
73
- transform: Transform function for tile images.. Defaults to None.
74
- use_cache: Cache each image to shared array, requires that each tile has the
75
- same shape. Defaults to False.
76
- tile_shape: Tile shape for creating a shared cache array. Defaults to None.
77
- """
71
+ """Torch dataset yielding tile images from paths (requires `PyTorch`)."""
78
72
79
73
def __init__ (
80
74
self ,
@@ -85,6 +79,21 @@ def __init__(
85
79
use_cache : bool = False ,
86
80
tile_shape : tuple [int , ...] | None = None ,
87
81
) -> None :
82
+ """Torch dataset yielding tile images from paths (requires `PyTorch`).
83
+
84
+ Args:
85
+ paths: Paths to tile images.
86
+ labels: Indexable list of labels for each path. Defaults to None.
87
+ transform: Transform function for tile images. Defaults to None.
88
+ use_cache: Cache each image to shared array, requires that each tile has the
89
+ same shape. Defaults to False.
90
+ tile_shape: Tile shape for creating a shared cache array. Defaults to None.
91
+
92
+ Raises:
93
+ ImportError: Could not import `PyTorch`.
94
+ ValueError: Label and path lengths differ.
95
+ ValueError: Tile shape is undefined but `use_cache=True`.
96
+ """
88
97
super ().__init__ ()
89
98
if not HAS_PYTORCH :
90
99
raise ImportError (ERROR_PYTORCH )
0 commit comments