# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from backbone.ResNet import ResNet18
import torch.nn.functional as F
from utils.conf import base_path
from PIL import Image
from datasets.utils.validation import get_train_val
from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
from datasets.utils.continual_dataset import get_previous_train_loader
from datasets.transforms.denormalization import DeNormalize
import torch
from typing import Tuple
import numpy as np


class MyCIFAR100(CIFAR100):
    """CIFAR100 wrapper that additionally returns the non-augmented image."""

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        super(MyCIFAR100, self).__init__(root, train, transform,
                                         target_transform, download)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        original_img = img.copy()

        not_aug_img = self.not_aug_transform(original_img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            return img, target, not_aug_img, self.logits[index]

        return img, target, not_aug_img


class SequentialCIFAR100(ContinualDataset):
    """Sequential CIFAR-100: 10 class-incremental tasks of 10 classes each."""

    NAME = 'seq-cifar100'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 10
    N_TASKS = 10
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.5071, 0.4865, 0.4409),
                              (0.2673, 0.2564, 0.2761))])

    def get_data_loaders(self, nomask=False):
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = CIFAR100(base_path() + 'CIFAR100', train=False,
                                    download=True, transform=test_transform)

        if not nomask:
            # torchvision stores targets as a Python list; convert to long
            # tensors before building the per-task masks
            if isinstance(train_dataset.targets, list):
                train_dataset.targets = torch.tensor(train_dataset.targets,
                                                     dtype=torch.long)
            if isinstance(test_dataset.targets, list):
                test_dataset.targets = torch.tensor(test_dataset.targets,
                                                    dtype=torch.long)
            train, test = store_masked_loaders(train_dataset, test_dataset, self)
            return train, test
        else:
            # nomask=True: plain loaders over all 100 classes (joint setting)
            train_loader = DataLoader(train_dataset, batch_size=32,
                                      shuffle=True, num_workers=2)
            test_loader = DataLoader(test_dataset, batch_size=32,
                                     shuffle=False, num_workers=2)
            return train_loader, test_loader

    def get_joint_loaders(self, nomask=False):
        return self.get_data_loaders(nomask=True)

    def not_aug_dataloader(self, args, batch_size):
        if hasattr(args, 'iba') and args.iba:
            transform = transforms.Compose([transforms.ToTensor()])
        else:
            transform = transforms.Compose([transforms.ToTensor(),
                                            self.get_normalization_transform()])

        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True, transform=transform)
        if isinstance(train_dataset.targets, list):
            train_dataset.targets = torch.tensor(train_dataset.targets,
                                                 dtype=torch.long)

        # keep only the samples whose labels fall in the class range
        # [(self.i - 1) * N_CLASSES_PER_TASK, self.i * N_CLASSES_PER_TASK)
        train_mask = np.logical_and(
            np.array(train_dataset.targets) >= (self.i - 1) * self.N_CLASSES_PER_TASK,
            np.array(train_dataset.targets) < self.i * self.N_CLASSES_PER_TASK)

        train_dataset.data = train_dataset.data[train_mask]
        train_dataset.targets = np.array(train_dataset.targets)[train_mask]

        train_loader = get_previous_train_loader(train_dataset, batch_size, self)
        return train_loader

    @staticmethod
    def get_transform():
        transform = transforms.Compose(
            [transforms.ToPILImage(), SequentialCIFAR100.TRANSFORM])
        return transform

    @staticmethod
    def get_backbone():
        return ResNet18(SequentialCIFAR100.N_CLASSES_PER_TASK
                        * SequentialCIFAR100.N_TASKS)

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        transform = transforms.Normalize((0.5071, 0.4865, 0.4409),
                                         (0.2673, 0.2564, 0.2761))
        return transform

    @staticmethod
    def get_denormalization_transform():
        transform = DeNormalize((0.5071, 0.4865, 0.4409),
                                (0.2673, 0.2564, 0.2761))
        return transform
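

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the benchmark itself. It assumes, based on
# the use of self.args.validation above, that the ContinualDataset constructor
# accepts an argument namespace exposing a `validation` flag; the Namespace
# below is hypothetical, so check the repo's argument parser before relying on
# it. The loop also assumes that each call to get_data_loaders() advances to
# the next task via store_masked_loaders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(validation=False)   # hypothetical minimal namespace
    dataset = SequentialCIFAR100(args)   # assumed constructor signature

    for task in range(SequentialCIFAR100.N_TASKS):
        train_loader, test_loader = dataset.get_data_loaders()
        # MyCIFAR100 batches carry (augmented image, label, non-augmented image)
        imgs, labels, not_aug_imgs = next(iter(train_loader))
        print(task, imgs.shape, labels.min().item(), labels.max().item())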