@@ -1,22 +1,27 @@
 import io
+import pathlib
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Union
+from unicodedata import category
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
 from torchvision.prototype import features
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     OnlineResource,
     GDriveResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
 )
 from torchvision.prototype.features import Label
 
+from .._api import register_dataset, register_info
+
+
+NAME = "pcam"
+
 
 class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
     def __init__(
@@ -40,15 +45,25 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
 
 
-class PCAM(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "pcam",
-            homepage="https://github.com/basveeling/pcam",
-            categories=2,
-            valid_options=dict(split=("train", "test", "val")),
-            dependencies=["h5py"],
-        )
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=["0", "1"])
+
+
+@register_dataset(NAME)
+class PCAM(Dataset2):
+    # TODO write proper docstring
+    """PCAM Dataset
+
+    homepage="https://github.com/basveeling/pcam"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))
 
     _RESOURCES = {
         "train": (
@@ -89,23 +104,21 @@ def _make_info(self) -> DatasetInfo:
         ),
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [  # = [images resource, targets resource]
             GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
-            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
+            for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
         ]
 
     def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
         image, target = data  # They're both numpy arrays at this point
 
         return {
             "image": features.Image(image.transpose(2, 0, 1)),
-            "label": Label(target.item()),
+            "label": Label(target.item(), categories=self._categories),
         }
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
 
         images_dp, targets_dp = resource_dps
 
@@ -116,3 +129,6 @@ def _make_datapipe(
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self):
+        return 262_144 if self._split == "train" else 32_768
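With the class now registered under NAME = "pcam", construction goes through the prototype datasets registry instead of DatasetConfig options. A minimal usage sketch follows, assuming the prototype datasets module exposes a load() entry point that forwards keyword arguments to the registered dataset's __init__; the split names, category list, lengths, and image layout below come from the diff itself.

from torchvision.prototype import datasets

# Assumed entry point: load() looks up the "pcam" registry entry and
# forwards keyword arguments to PCAM.__init__ (split, skip_integrity_check, ...).
dataset = datasets.load("pcam", split="val")

# __len__ is now hardcoded per split, so len() works without iterating the datapipe.
assert len(dataset) == 32_768  # 262_144 for split="train"

for sample in dataset:
    # _prepare_sample transposes HWC -> CHW; PCAM patches are 96x96 RGB,
    # so the image should be 3x96x96 and the label carries categories ["0", "1"].
    print(sample["image"].shape, sample["label"])
    break

Registering _info() separately presumably lets the registry report the categories without instantiating (and therefore downloading) the dataset, while __init__ reuses the same dict so that Label.categories stays consistent with it.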