Skip to content

Commit fa39ad1

Browse files
committed
split torch to submodule
1 parent 044c834 commit fa39ad1

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

histoprep/utils/_torch.py histoprep/utils/torch.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,23 @@
1212
import numpy as np
1313
from PIL import Image
1414

15-
from histoprep._reader import SlideReader
15+
from histoprep.reader import SlideReader
1616

1717
try:
18-
from torch.utils.data import Dataset, IterableDataset
18+
from torch.utils.data import Dataset
1919

2020
HAS_PYTORCH = True
2121
except ImportError:
2222
HAS_PYTORCH = False
2323
Dataset = object
24-
IterableDataset = object
2524

2625
ERROR_PYTORCH = "Could not import torch, make sure it has been installed!"
2726
ERROR_LENGTH_MISMATCH = "Path length ({}) does not match label length ({})."
2827
ERROR_TILE_SHAPE = "Tile shape must be defined to create a cache array."
2928

3029

3130
class SlideReaderDataset(Dataset):
32-
"""Torch dataset yielding tile images from reader (requires `PyTorch`).
33-
34-
Args:
35-
reader: `SlideReader` instance.
36-
coordinates: Iterator of xywh-coordinates.
37-
level: Slide level for reading tile image. Defaults to 0.
38-
transform: Transform function for tile images. Defaults to None.
39-
"""
31+
"""Torch dataset yielding tile images from reader (requires `PyTorch`)."""
4032

4133
def __init__(
4234
self,
@@ -45,6 +37,17 @@ def __init__(
4537
level: int = 0,
4638
transform: Callable[[np.ndarray], Any] | None = None,
4739
) -> None:
40+
"""Initialize SlideReaderDataset.
41+
42+
Args:
43+
reader: `SlideReader` instance.
44+
coordinates: Iterator of xywh-coordinates.
45+
level: Slide level for reading tile image. Defaults to 0.
46+
transform: Transform function for tile images. Defaults to None.
47+
48+
Raises:
49+
ImportError: Could not import `PyTorch`.
50+
"""
4851
if not HAS_PYTORCH:
4952
raise ImportError(ERROR_PYTORCH)
5053
super().__init__()
@@ -65,16 +68,7 @@ def __getitem__(self, index: int) -> tuple[np.ndarray | Any, np.ndarray]:
6568

6669

6770
class TileImageDataset(Dataset):
68-
"""Torch dataset yielding tile images from paths (requires `PyTorch`).
69-
70-
Args:
71-
paths: Paths to tile images.
72-
labels: Indexable list of labels for each path. Defaults to None.
73-
transform: Transform function for tile images.. Defaults to None.
74-
use_cache: Cache each image to shared array, requires that each tile has the
75-
same shape. Defaults to False.
76-
tile_shape: Tile shape for creating a shared cache array. Defaults to None.
77-
"""
71+
"""Torch dataset yielding tile images from paths (requires `PyTorch`)."""
7872

7973
def __init__(
8074
self,
@@ -85,6 +79,21 @@ def __init__(
8579
use_cache: bool = False,
8680
tile_shape: tuple[int, ...] | None = None,
8781
) -> None:
82+
"""Torch dataset yielding tile images from paths (requires `PyTorch`).
83+
84+
Args:
85+
paths: Paths to tile images.
86+
labels: Indexable list of labels for each path. Defaults to None.
87+
transform: Transform function for tile images. Defaults to None.
88+
use_cache: Cache each image to shared array, requires that each tile has the
89+
same shape. Defaults to False.
90+
tile_shape: Tile shape for creating a shared cache array. Defaults to None.
91+
92+
Raises:
93+
ImportError: Could not import `PyTorch`.
94+
ValueError: Label and path lengths differ.
95+
ValueError: Tile shape is undefined but `use_cache=True`.
96+
"""
8897
super().__init__()
8998
if not HAS_PYTORCH:
9099
raise ImportError(ERROR_PYTORCH)

0 commit comments

Comments
 (0)