Skip to content

Commit bdd62e2

Browse files
committed
Train and test prototypical networks on MiniImageNet
1 parent 64226e0 commit bdd62e2

21 files changed

+849
-0
lines changed

dataloaders/__init__.py

Whitespace-only changes.
165 Bytes
Binary file not shown.
Binary file not shown.
1.71 KB
Binary file not shown.
Binary file not shown.

dataloaders/few_shot.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import torchvision.datasets.folder as datasets
2+
3+
4+
class ImageFolderFewShot(datasets.ImageFolder):
5+
"""A generic data loader where the images are arranged in this way: ::
6+
root/dog/xxx.png
7+
root/dog/xxy.png
8+
root/dog/xxz.png
9+
root/cat/123.png
10+
root/cat/nsdf3.png
11+
root/cat/asd932_.png
12+
Args:
13+
root (string): Root directory path.
14+
transform (callable, optional): A function/transform that takes in an PIL image
15+
and returns a transformed version. E.g, ``transforms.RandomCrop``
16+
target_transform (callable, optional): A function/transform that takes in the
17+
target and transforms it.
18+
loader (callable, optional): A function to load an image given its path.
19+
Attributes:
20+
classes (list): List of the class names.
21+
class_to_idx (dict): Dict with items (class_name, class_index).
22+
imgs (list): List of (image path, class_index) tuples
23+
"""
24+
25+
def __init__(self, root, transform=None, target_transform=None):
26+
super(ImageFolderFewShot, self).__init__(root, transform=transform, target_transform=target_transform)
27+
self.labels = [sample[1] for sample in self.samples]

dataloaders/mini_imagenet_loader.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os.path as osp
2+
from PIL import Image
3+
4+
from torch.utils.data import Dataset
5+
from torchvision import transforms
6+
7+
8+
ROOT_PATH = './mini_imagenet/'
9+
10+
11+
class MiniImageNet(Dataset):
12+
13+
def __init__(self, setname):
14+
csv_path = osp.join(ROOT_PATH, setname + '.csv')
15+
lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
16+
17+
data = []
18+
label = []
19+
lb = -1
20+
21+
self.wnids = []
22+
self.class_idx_to_sample_idx = {}
23+
24+
for idx, l in enumerate(lines):
25+
name, wnid = l.split(',')
26+
path = osp.join(ROOT_PATH, 'images', name)
27+
if wnid not in self.wnids:
28+
self.wnids.append(wnid)
29+
lb += 1
30+
self.class_idx_to_sample_idx.update({lb: []})
31+
data.append(path)
32+
label.append(lb)
33+
self.class_idx_to_sample_idx[lb].append(idx)
34+
35+
self.data = data
36+
self.labels = label
37+
38+
self.transform = transforms.Compose([
39+
transforms.Resize(84),
40+
transforms.CenterCrop(84),
41+
transforms.ToTensor(),
42+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
43+
std=[0.229, 0.224, 0.225])
44+
])
45+
46+
def __len__(self):
47+
return len(self.data)
48+
49+
def __getitem__(self, i):
50+
path, label = self.data[i], self.labels[i]
51+
image = self.transform(Image.open(path).convert('RGB'))
52+
return image, label

models/__init__.py

Whitespace-only changes.
160 Bytes
Binary file not shown.
1.13 KB
Binary file not shown.

0 commit comments

Comments
 (0)