-
Notifications
You must be signed in to change notification settings - Fork 12
/
open_world_cifar.py
123 lines (106 loc) · 4.6 KB
/
open_world_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import torchvision
import torch.utils.data as data
from torchvision import transforms
import itertools
from torch.utils.data.sampler import Sampler
class OPENWORLDCIFAR100(torchvision.datasets.CIFAR100):
    """CIFAR-100 training split partitioned into labeled / unlabeled subsets.

    Classes ``0 .. labeled_num - 1`` are the "labeled" classes. Every training
    sample of such a class is placed in the labeled subset with probability
    ``labeled_ratio``; every other sample (the remaining samples of those
    classes and all samples of the other classes) is placed in the unlabeled
    subset. After construction the instance holds only one of the two subsets
    in ``self.data`` / ``self.targets``.

    Args:
        root: Dataset root directory, forwarded to the torchvision base class.
        labeled: If True, compute the split and keep the labeled subset. If
            False, keep exactly the samples listed in ``unlabeled_idxs``.
        labeled_num: Number of leading class ids treated as labeled classes.
        labeled_ratio: Probability that a sample of a labeled class is put in
            the labeled subset.
        rand_number: Seed passed to ``np.random.seed`` before the split.
        transform: Forwarded to the torchvision base class.
        target_transform: Forwarded to the torchvision base class.
        download: Forwarded to the torchvision base class.
        unlabeled_idxs: Indices into the full training set to keep when
            ``labeled`` is False, e.g. the ``unlabeled_idxs`` attribute of an
            instance that was built with ``labeled=True``.

    Attributes:
        labeled_idxs, unlabeled_idxs: Index lists into the full training set.
            Only set when ``labeled`` is True.

    Raises:
        ValueError: If ``labeled`` is False and ``unlabeled_idxs`` is None.

    Note:
        Construction reseeds NumPy's *global* random generator with
        ``rand_number``, whether or not ``labeled`` is True.
    """

    def __init__(self, root, labeled=True, labeled_num=50, labeled_ratio=0.5, rand_number=0,
                 transform=None, target_transform=None, download=False, unlabeled_idxs=None):
        # The second positional argument is the base class' ``train`` flag:
        # this dataset is always built from the training batches.
        super(OPENWORLDCIFAR100, self).__init__(root, True, transform, target_transform, download)

        downloaded_list = self.train_list

        self.data = []
        self.targets = []

        # Now load the pickled numpy arrays of every training batch file.
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # Batch files store their labels under either key.
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        labeled_classes = range(labeled_num)
        np.random.seed(rand_number)

        if labeled:
            self.labeled_idxs, self.unlabeled_idxs = self.get_labeled_index(labeled_classes, labeled_ratio)
            self.shrink_data(self.labeled_idxs)
        else:
            # Indexing a NumPy array with None inserts a new axis instead of
            # selecting samples, which would silently turn the whole dataset
            # into one nested "sample". Fail loudly instead.
            if unlabeled_idxs is None:
                raise ValueError("unlabeled_idxs must be provided when labeled=False")
            self.shrink_data(unlabeled_idxs)

    def get_labeled_index(self, labeled_classes, labeled_ratio):
        """Split all sample indices into labeled and unlabeled index lists.

        Args:
            labeled_classes: Container of class ids eligible for labeling.
            labeled_ratio: Probability of labeling a sample of such a class.

        Returns:
            A ``(labeled_idxs, unlabeled_idxs)`` tuple of index lists that
            together cover every position of ``self.targets`` exactly once.
        """
        labeled_idxs = []
        unlabeled_idxs = []
        for idx, label in enumerate(self.targets):
            # ``and`` short-circuits: a random number is only drawn for
            # samples whose class is in ``labeled_classes``.
            if label in labeled_classes and np.random.rand() < labeled_ratio:
                labeled_idxs.append(idx)
            else:
                unlabeled_idxs.append(idx)
        return labeled_idxs, unlabeled_idxs

    def shrink_data(self, idxs):
        """Keep only the samples at ``idxs`` in ``self.targets`` and ``self.data``."""
        targets = np.array(self.targets)
        self.targets = targets[idxs].tolist()
        self.data = self.data[idxs, ...]
class OPENWORLDCIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 training split partitioned into labeled / unlabeled subsets.

    Classes ``0 .. labeled_num - 1`` are the "labeled" classes. Every training
    sample of such a class is placed in the labeled subset with probability
    ``labeled_ratio``; every other sample (the remaining samples of those
    classes and all samples of the other classes) is placed in the unlabeled
    subset. After construction the instance holds only one of the two subsets
    in ``self.data`` / ``self.targets``.

    Args:
        root: Dataset root directory, forwarded to the torchvision base class.
        labeled: If True, compute the split and keep the labeled subset. If
            False, keep exactly the samples listed in ``unlabeled_idxs``.
        labeled_num: Number of leading class ids treated as labeled classes.
        labeled_ratio: Probability that a sample of a labeled class is put in
            the labeled subset.
        rand_number: Seed passed to ``np.random.seed`` before the split.
        transform: Forwarded to the torchvision base class.
        target_transform: Forwarded to the torchvision base class.
        download: Forwarded to the torchvision base class.
        unlabeled_idxs: Indices into the full training set to keep when
            ``labeled`` is False, e.g. the ``unlabeled_idxs`` attribute of an
            instance that was built with ``labeled=True``.

    Attributes:
        labeled_idxs, unlabeled_idxs: Index lists into the full training set.
            Only set when ``labeled`` is True.

    Raises:
        ValueError: If ``labeled`` is False and ``unlabeled_idxs`` is None.

    Note:
        Construction reseeds NumPy's *global* random generator with
        ``rand_number``, whether or not ``labeled`` is True.
    """

    def __init__(self, root, labeled=True, labeled_num=5, labeled_ratio=0.5, rand_number=0,
                 transform=None, target_transform=None, download=False, unlabeled_idxs=None):
        # The second positional argument is the base class' ``train`` flag:
        # this dataset is always built from the training batches.
        super(OPENWORLDCIFAR10, self).__init__(root, True, transform, target_transform, download)

        downloaded_list = self.train_list

        self.data = []
        self.targets = []

        # Now load the pickled numpy arrays of every training batch file.
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # Batch files store their labels under either key.
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        labeled_classes = range(labeled_num)
        np.random.seed(rand_number)

        if labeled:
            self.labeled_idxs, self.unlabeled_idxs = self.get_labeled_index(labeled_classes, labeled_ratio)
            self.shrink_data(self.labeled_idxs)
        else:
            # Indexing a NumPy array with None inserts a new axis instead of
            # selecting samples, which would silently turn the whole dataset
            # into one nested "sample". Fail loudly instead.
            if unlabeled_idxs is None:
                raise ValueError("unlabeled_idxs must be provided when labeled=False")
            self.shrink_data(unlabeled_idxs)

    def get_labeled_index(self, labeled_classes, labeled_ratio):
        """Split all sample indices into labeled and unlabeled index lists.

        Args:
            labeled_classes: Container of class ids eligible for labeling.
            labeled_ratio: Probability of labeling a sample of such a class.

        Returns:
            A ``(labeled_idxs, unlabeled_idxs)`` tuple of index lists that
            together cover every position of ``self.targets`` exactly once.
        """
        labeled_idxs = []
        unlabeled_idxs = []
        for idx, label in enumerate(self.targets):
            # ``and`` short-circuits: a random number is only drawn for
            # samples whose class is in ``labeled_classes``.
            if label in labeled_classes and np.random.rand() < labeled_ratio:
                labeled_idxs.append(idx)
            else:
                unlabeled_idxs.append(idx)
        return labeled_idxs, unlabeled_idxs

    def shrink_data(self, idxs):
        """Keep only the samples at ``idxs`` in ``self.targets`` and ``self.data``."""
        targets = np.array(self.targets)
        self.targets = targets[idxs].tolist()
        self.data = self.data[idxs, ...]
# Image preprocessing pipelines, looked up by name.
# Per-channel mean / std handed to Normalize in both pipelines.
# NOTE(review): these look like CIFAR-100 training-set statistics and are
# applied regardless of which dataset class is used -- confirm intended.
_NORM_MEAN = (0.5071, 0.4867, 0.4408)
_NORM_STD = (0.2675, 0.2565, 0.2761)

dict_transform = dict(
    # Training: random 32x32 crop after 4-pixel padding, random horizontal
    # flip, then tensor conversion and normalisation.
    cifar_train=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_NORM_MEAN, _NORM_STD),
        ]
    ),
    # Evaluation: no augmentation, only tensor conversion and normalisation.
    cifar_test=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(_NORM_MEAN, _NORM_STD),
        ]
    ),
)