Skip to content

[proposal] CUB 200-2011 dataset #1654

Closed
@vadimkantorov

Description

@vadimkantorov

This is a very widely used dataset in metric learning research and fine-grained recognition ("Birds dataset"): http://www.vision.caltech.edu/visipedia/CUB-200-2011.html

A few years ago I abused ImageFolder and CIFAR10 (via multiple inheritance) to add this dataset:

import os

from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive

class CUB2011(ImageFolder):
    """Caltech-UCSD Birds-200-2011 (CUB-200-2011) as an ImageFolder dataset.

    Optionally downloads and extracts ``CUB_200_2011.tgz`` into ``root``, then
    serves ``<root>/CUB_200_2011/images`` through :class:`ImageFolder`, so every
    class subdirectory (e.g. ``001.Black_footed_Albatross``) becomes one label.

    Args:
        root: Directory the archive is downloaded to and extracted in.
        transform: Optional transform applied to each loaded image.
        target_transform: Optional transform applied to each target.
        download: If True, download and extract the archive unless the
            spot-check files below are already present with the right MD5.
        **kwargs: Forwarded unchanged to :class:`ImageFolder`.

    Raises:
        RuntimeError: If the spot-check files are missing or corrupted after
            the optional download.
    """

    # Location of the extracted image tree, relative to ``root``.
    base_folder = 'CUB_200_2011/images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    # (relative image path, md5) pairs checked by _check_integrity as a cheap
    # sanity check of the extraction; these are NOT the train/test splits.
    train_list = [
        ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'],
        ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'],
    ]
    test_list = [
        ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'],
        ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe'],
    ]

    def _check_integrity(self):
        """Return True iff every spot-check file exists under ``self.root`` with its expected MD5."""
        image_dir = os.path.join(self.root, self.base_folder)
        # all() stops at the first missing/corrupted file, like an early return.
        return all(
            check_integrity(os.path.join(image_dir, rel_path), md5)
            for rel_path, md5 in self.train_list + self.test_list
        )

    def download(self):
        """Download and extract the archive into ``self.root`` unless it is already valid."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def __init__(self, root, transform=None, target_transform=None, download=False, **kwargs):
        # self.root must be set before download()/_check_integrity() run, since
        # both resolve paths against it.
        # NOTE(review): ImageFolder's own __init__ is expected to rebind self.root
        # to the images directory, so calling download() after construction would
        # resolve paths against the wrong directory — confirm before relying on it.
        self.root = root
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        # super().__init__ already binds self. Passing self explicitly (as the
        # previous version did) shifted every positional argument and supplied
        # ``transform`` twice, which raised TypeError.
        super().__init__(os.path.join(root, self.base_folder),
                         transform=transform, target_transform=target_transform, **kwargs)

I extended this class for a metric learning task like this:

class CUB2011MetricLearning(CUB2011):
    """CUB-200-2011 split by class, as is customary for metric learning.

    The first ``num_training_classes`` classes (in ImageFolder's class order)
    form the training split; the remaining classes form the test split, so the
    two splits share no classes. Targets keep their class indices from the full
    dataset; test-split indices are not renumbered from 0.

    Args:
        root, transform, target_transform, download, **kwargs: As for CUB2011.
        train: If True, keep the training classes; otherwise keep the held-out ones.
    """

    # Number of leading classes assigned to the training split.
    num_training_classes = 100

    def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs):
        super().__init__(root, transform=transform, target_transform=target_transform,
                         download=download, **kwargs)
        n_train = self.num_training_classes
        self.classes = self.classes[:n_train] if train else self.classes[n_train:]

        kept_labels = set(self.classes)
        self.class_to_idx = {
            label: idx for label, idx in self.class_to_idx.items() if label in kept_labels
        }
        # Set lookup instead of scanning dict.values() once per image.
        kept_indices = set(self.class_to_idx.values())

        # ImageFolder serves __getitem__/__len__ from self.samples, and self.imgs
        # is merely an alias of it. Filtering only self.imgs (as before) left the
        # dataset unfiltered, so filter samples and keep imgs/targets in sync.
        self.samples = [(path, idx) for path, idx in self.samples if idx in kept_indices]
        self.imgs = self.samples
        self.targets = [idx for _, idx in self.samples]

@fmassa I would appreciate advice on how to do this properly. Multiple inheritance from CIFAR10 looks fragile to me (I remember that mixing in CIFAR10 functionality was important, but I don't remember why...). I also wonder whether the abstractions have evolved since then, e.g. whether the manual download / integrity-check calls are still needed.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions