Closed
Description
This is a very widely used dataset in metric learning research and fine-grained recognition ("Birds dataset"): http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
A few years ago I abused multiple inheritance from ImageFolder and CIFAR10 to add this dataset:
import os

from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
class CUB2011(ImageFolder):
    """Caltech-UCSD Birds-200-2011 ("CUB") served through ``ImageFolder``.

    The archive is extracted under ``root`` and the per-class image
    directories in ``root/CUB_200_2011/images`` are handed to
    ``ImageFolder``, which derives the class labels from the directory
    names.

    Args:
        root: Directory the archive is downloaded to and extracted in.
        transform: Optional transform applied to each loaded image.
        target_transform: Optional transform applied to each class index.
        download: If true, download and extract the archive unless the
            sentinel files below are already present and intact.
        **kwargs: Forwarded unchanged to ``ImageFolder.__init__``.

    Raises:
        RuntimeError: If the sentinel files are missing or fail their MD5
            check after the optional download step.
    """

    base_folder = 'CUB_200_2011/images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    # A handful of (relative path, md5) sentinel files. Only these files are
    # verified; they act as a cheap proxy for "the archive was extracted".
    train_list = [
        ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'],
        ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'],
    ]
    test_list = [
        ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'],
        ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe'],
    ]

    def __init__(self, root, transform=None, target_transform=None, download=False, **kwargs):
        self.root = root
        # The parent class is constructed with the images sub-directory as
        # its root, so keep the download location separately; download() and
        # _check_integrity() must keep working after construction.
        self._archive_root = root

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # NOTE: zero-argument super() already binds self; passing it again
        # would shift every positional argument by one.
        super().__init__(
            os.path.join(root, self.base_folder),
            transform=transform,
            target_transform=target_transform,
            **kwargs
        )

    def _check_integrity(self):
        """Return True if every sentinel image exists with the expected MD5."""
        root = getattr(self, '_archive_root', self.root)
        return all(
            check_integrity(os.path.join(root, self.base_folder, relative_path), md5)
            for relative_path, md5 in (self.train_list + self.test_list)
        )

    def download(self):
        """Download and extract the archive unless it is already in place."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_root = getattr(self, '_archive_root', self.root)
        download_and_extract_archive(
            self.url, download_root, filename=self.filename, md5=self.tgz_md5
        )
I extended this class for a metric learning task like this:
class CUB2011MetricLearning(CUB2011):
    """CUB-200-2011 split by class for metric-learning experiments.

    The first ``num_training_classes`` entries of ``self.classes`` form the
    training split and the remaining classes form the test split, so the two
    splits share no class. Class indices are NOT re-numbered: samples keep
    the index assigned by the parent dataset.

    Args:
        root: Directory the archive is downloaded to and extracted in.
        train: If true keep the training classes, otherwise the test classes.
        transform: Optional transform applied to each loaded image.
        target_transform: Optional transform applied to each class index.
        download: If true, download the archive when it is not present.
        **kwargs: Forwarded unchanged to ``CUB2011.__init__``.
    """

    num_training_classes = 100

    def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            download=download,
            **kwargs
        )

        split = self.num_training_classes
        self.classes = self.classes[:split] if train else self.classes[split:]

        kept_names = set(self.classes)
        self.class_to_idx = {
            name: index
            for name, index in self.class_to_idx.items()
            if name in kept_names
        }

        # Build the lookup set once instead of scanning dict values per image.
        kept_indices = set(self.class_to_idx.values())
        kept_samples = [
            (path, index) for path, index in self.imgs if index in kept_indices
        ]

        # Indexing and len() in DatasetFolder go through ``samples``; ``imgs``
        # is only an alias. Update all three views so the split takes effect.
        self.samples = kept_samples
        self.imgs = kept_samples
        self.targets = [index for _, index in kept_samples]
@fmassa I would appreciate advice on how to do this properly. Multiple inheritance from CIFAR10 looks fragile to me (I remember that mixing in CIFAR10 functionality was important, but I don't remember why...). I also wonder whether the abstractions have evolved since then, e.g. whether manual download / integrity-check calls are still needed.