-
Notifications
You must be signed in to change notification settings - Fork 557
/
office31.py
67 lines (59 loc) · 2.85 KB
/
office31.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class Office31(ImageList):
    """Office31 Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon,
            ``'D'``: dslr and ``'W'``: webcam.
        download (bool, optional): If true, downloads the dataset from the internet and puts it
            in root directory. If dataset is already downloaded, it is not downloaded again.
            If false, every entry of ``download_list`` is passed to ``check_exits`` instead.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::

            amazon/
                images/
                    backpack/
                        *.jpg
                        ...
            dslr/
            webcam/
            image_list/
                amazon.txt
                dslr.txt
                webcam.txt
    """
    # Each entry is (name under root, archive file name, download URL).
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/2c1dd9fbcaa9455aa4ad/?dl=1"),
        ("amazon", "amazon.tgz", "https://cloud.tsinghua.edu.cn/f/ec12dfcddade43ab8101/?dl=1"),
        ("dslr", "dslr.tgz", "https://cloud.tsinghua.edu.cn/f/a41d818ae2f34da7bb32/?dl=1"),
        ("webcam", "webcam.tgz", "https://cloud.tsinghua.edu.cn/f/8a41009a166e4131adcd/?dl=1"),
    ]
    # Maps a task (domain) code to its image-list file, relative to root.
    image_list = {
        "A": "image_list/amazon.txt",
        "D": "image_list/dslr.txt",
        "W": "image_list/webcam.txt"
    }
    CLASSES = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair', 'desk_lamp',
               'desktop_computer', 'file_cabinet', 'headphones', 'keyboard', 'laptop_computer', 'letter_tray',
               'mobile_phone', 'monitor', 'mouse', 'mug', 'paper_notebook', 'pen', 'phone', 'printer', 'projector',
               'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        """Prepare the files under ``root`` and initialise the underlying image list.

        Args:
            root (str): Root directory of dataset.
            task (str): One of the keys of ``image_list`` (``'A'``, ``'D'`` or ``'W'``).
            download (bool, optional): Whether to download the entries of ``download_list``
                or only check them with ``check_exits``.
            **kwargs: Forwarded unchanged to the ``ImageList`` constructor.

        Raises:
            AssertionError: If ``task`` is not a key of ``image_list``.
        """
        assert task in self.image_list, \
            "task must be one of {}, got {!r}".format(list(self.image_list.keys()), task)
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            for entry in self.download_list:
                download_data(root, *entry)
        else:
            # Entries are 3-tuples; only the first field (the name under root)
            # is needed here. The previous two-parameter lambda passed to
            # ``map`` received the whole tuple as one argument and raised
            # TypeError.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        super(Office31, self).__init__(root, Office31.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        """Return the list of available task (domain) codes."""
        return list(cls.image_list.keys())