|
| 1 | +import argparse |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.optim as optim |
| 7 | +import torchvision |
| 8 | +from torch.optim.lr_scheduler import StepLR |
| 9 | +from torch.utils.data import DataLoader, Dataset |
| 10 | +from torchvision import datasets |
| 11 | + |
| 12 | +from ignite.contrib.handlers import ProgressBar |
| 13 | +from ignite.engine import Engine, Events |
| 14 | +from ignite.handlers.param_scheduler import LRScheduler |
| 15 | +from ignite.metrics import Accuracy, RunningAverage |
| 16 | +from ignite.utils import manual_seed |
| 17 | + |
| 18 | + |
| 19 | +class SiameseNetwork(nn.Module): |
| 20 | + # update Siamese Network implementation in accordance with the dataset |
| 21 | + """ |
| 22 | + Siamese network for image similarity estimation. |
| 23 | + The network is composed of two identical networks, one for each input. |
| 24 | + The output of each network is concatenated and passed to a linear layer. |
| 25 | + The output of the linear layer passed through a sigmoid function. |
| 26 | + `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network. |
| 27 | + This implementation varies from FaceNet as we use the `ResNet-18` model from |
| 28 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>` |
| 29 | + as our feature extractor. |
| 30 | + In addition we use CIFAR10 dataset along with TripletMarginLoss |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self): |
| 34 | + super(SiameseNetwork, self).__init__() |
| 35 | + # get resnet model |
| 36 | + self.resnet = torchvision.models.resnet34(weights=None) |
| 37 | + fc_in_features = self.resnet.fc.in_features |
| 38 | + |
| 39 | + # changing the FC layer of resnet model to a linear layer |
| 40 | + self.resnet.fc = nn.Identity() |
| 41 | + |
| 42 | + # add linear layers to compare between the features of the two images |
| 43 | + self.fc = nn.Sequential( |
| 44 | + nn.Linear(fc_in_features, 256), |
| 45 | + nn.ReLU(inplace=True), |
| 46 | + nn.Linear(256, 10), |
| 47 | + nn.ReLU(inplace=True), |
| 48 | + ) |
| 49 | + |
| 50 | + # initialise relu activation |
| 51 | + self.relu = nn.ReLU() |
| 52 | + |
| 53 | + # initialize the weights |
| 54 | + self.resnet.apply(self.init_weights) |
| 55 | + self.fc.apply(self.init_weights) |
| 56 | + |
| 57 | + def init_weights(self, m): |
| 58 | + if isinstance(m, nn.Linear): |
| 59 | + nn.init.xavier_uniform_(m.weight) |
| 60 | + m.bias.data.fill_(0.01) |
| 61 | + |
| 62 | + def forward_once(self, x): |
| 63 | + output = self.resnet(x) |
| 64 | + output = output.view(output.size()[0], -1) |
| 65 | + return output |
| 66 | + |
| 67 | + def forward(self, input1, input2, input3): |
| 68 | + |
| 69 | + # pass the input through resnet |
| 70 | + output1 = self.forward_once(input1) |
| 71 | + output2 = self.forward_once(input2) |
| 72 | + output3 = self.forward_once(input3) |
| 73 | + |
| 74 | + # pass the output of resnet to sigmoid layer |
| 75 | + output1 = self.fc(output1) |
| 76 | + output2 = self.fc(output2) |
| 77 | + output3 = self.fc(output3) |
| 78 | + |
| 79 | + return output1, output2, output3 |
| 80 | + |
| 81 | + |
| 82 | +class MatcherDataset(Dataset): |
| 83 | + # following class implements data downloading and handles preprocessing |
| 84 | + def __init__(self, root, train, download=False): |
| 85 | + super(MatcherDataset, self).__init__() |
| 86 | + |
| 87 | + # get CIFAR10 dataset |
| 88 | + self.dataset = datasets.CIFAR10(root, train=train, download=download) |
| 89 | + |
| 90 | + # convert data from numpy array to Tensor |
| 91 | + self.data = torch.from_numpy(self.dataset.data) |
| 92 | + |
| 93 | + # shift the dimensions of dataset to match the initial input layer dimensions |
| 94 | + self.data = torch.movedim(self.data, (0, 1, 2, 3), (0, 2, 3, 1)) |
| 95 | + |
| 96 | + # convert targets list to torch Tensor |
| 97 | + self.dataset.targets = torch.tensor(self.dataset.targets) |
| 98 | + |
| 99 | + self.group_examples() |
| 100 | + |
| 101 | + def group_examples(self): |
| 102 | + """ |
| 103 | + To ease the accessibility of data based on the class, we will use `group_examples` to group |
| 104 | + examples based on class. The data classes have already been mapped to numeric values and |
| 105 | + so are the target outputs for each training input |
| 106 | +
|
| 107 | + Every key in `grouped_examples` corresponds to a class in CIFAR10 dataset. For every key in |
| 108 | + `grouped_examples`, every value will conform to all of the indices for the CIFAR10 |
| 109 | + dataset examples that correspond to that key. |
| 110 | + """ |
| 111 | + |
| 112 | + # get the targets from CIFAR10 dataset |
| 113 | + np_arr = np.array(self.dataset.targets) |
| 114 | + |
| 115 | + # group examples based on class |
| 116 | + self.grouped_examples = {} |
| 117 | + for i in range(0, 10): |
| 118 | + self.grouped_examples[i] = np.where((np_arr == i))[0] |
| 119 | + |
| 120 | + def __len__(self): |
| 121 | + return self.data.shape[0] |
| 122 | + |
| 123 | + def __getitem__(self, index): |
| 124 | + """ |
| 125 | + For every sample in the batch we select 3 images. First one is the anchor image |
| 126 | + which is the image obtained from the current index. We also obtain the label of |
| 127 | + anchor image. |
| 128 | +
|
| 129 | + Now we select two random images, one belonging to the same class as that of the |
| 130 | + anchor image (named as positive_image) and the other belonging to a different class |
| 131 | + than that of the anchor image (named as negative_image). We return the anchor image, |
| 132 | + positive image, negative image and anchor label. |
| 133 | + """ |
| 134 | + |
| 135 | + # obtain the anchor image |
| 136 | + anchor_image = self.data[index].float() |
| 137 | + |
| 138 | + # obtain the class label of the anchor image |
| 139 | + anchor_label = self.dataset.targets[index] |
| 140 | + anchor_label = int(anchor_label.item()) |
| 141 | + |
| 142 | + # find a label which is different from anchor_label |
| 143 | + labels = list(range(0, 10)) |
| 144 | + labels.remove(anchor_label) |
| 145 | + neg_index = torch.randint(0, 9, (1,)).item() |
| 146 | + neg_label = labels[neg_index] |
| 147 | + |
| 148 | + # get a random index from the range range of indices |
| 149 | + random_index = torch.randint(0, len(self.grouped_examples[anchor_label]), (1,)).item() |
| 150 | + |
| 151 | + # get the index of image in actual data using the anchor label and random index |
| 152 | + positive_index = self.grouped_examples[anchor_label][random_index] |
| 153 | + |
| 154 | + # choosing a random image using positive_index |
| 155 | + positive_image = self.data[positive_index].float() |
| 156 | + |
| 157 | + # get a random index from the range range of indices |
| 158 | + random_index = torch.randint(0, len(self.grouped_examples[neg_label]), (1,)).item() |
| 159 | + |
| 160 | + # get the index of image in actual data using the negative label and random index |
| 161 | + negative_index = self.grouped_examples[neg_label][random_index] |
| 162 | + |
| 163 | + # choosing a random image using negative_index |
| 164 | + negative_image = self.data[negative_index].float() |
| 165 | + |
| 166 | + return anchor_image, positive_image, negative_image, anchor_label |
| 167 | + |
| 168 | + |
| 169 | +def pairwise_distance(input1, input2): |
| 170 | + dist = input1 - input2 |
| 171 | + dist = torch.pow(dist, 2) |
| 172 | + return dist |
| 173 | + |
| 174 | + |
| 175 | +def calculate_loss(input1, input2): |
| 176 | + output = pairwise_distance(input1, input2) |
| 177 | + loss = torch.sum(output, 1) |
| 178 | + loss = torch.sqrt(loss) |
| 179 | + return loss |
| 180 | + |
| 181 | + |
| 182 | +def run(args, model, device, optimizer, train_loader, test_loader, lr_scheduler): |
| 183 | + |
| 184 | + # using Triplet Margin Loss |
| 185 | + criterion = nn.TripletMarginLoss(p=2, margin=2.8) |
| 186 | + |
| 187 | + # define model training step |
| 188 | + def train_step(engine, batch): |
| 189 | + model.train() |
| 190 | + anchor_image, positive_image, negative_image, anchor_label = batch |
| 191 | + anchor_image = anchor_image.to(device) |
| 192 | + positive_image, negative_image = positive_image.to(device), negative_image.to(device) |
| 193 | + anchor_label = anchor_label.to(device) |
| 194 | + optimizer.zero_grad() |
| 195 | + anchor_out, positive_out, negative_out = model(anchor_image, positive_image, negative_image) |
| 196 | + loss = criterion(anchor_out, positive_out, negative_out) |
| 197 | + loss.backward() |
| 198 | + optimizer.step() |
| 199 | + return loss |
| 200 | + |
| 201 | + # define model testing step |
| 202 | + def test_step(engine, batch): |
| 203 | + model.eval() |
| 204 | + with torch.no_grad(): |
| 205 | + anchor_image, _, _, anchor_label = batch |
| 206 | + anchor_image = anchor_image.to(device) |
| 207 | + anchor_label = anchor_label.to(device) |
| 208 | + other_image = [] |
| 209 | + other_label = [] |
| 210 | + y_true = [] |
| 211 | + for i in range(anchor_image.shape[0]): |
| 212 | + index = torch.randint(0, anchor_image.shape[0], (1,)).item() |
| 213 | + img = anchor_image[index] |
| 214 | + label = anchor_label[index] |
| 215 | + other_image.append(img) |
| 216 | + other_label.append(label) |
| 217 | + if anchor_label[i] == other_label[i]: |
| 218 | + y_true.append(1) |
| 219 | + else: |
| 220 | + y_true.append(0) |
| 221 | + other = torch.stack(other_image) |
| 222 | + other_label = torch.tensor(other_label) |
| 223 | + other, other_label = other.to(device), other_label.to(device) |
| 224 | + anchor_out, other_out, _ = model(anchor_image, other, other) |
| 225 | + test_loss = calculate_loss(anchor_out, other_out) |
| 226 | + y_pred = torch.where(test_loss < 3, 1, 0) |
| 227 | + y_true = torch.tensor(y_true) |
| 228 | + return [y_pred, y_true] |
| 229 | + |
| 230 | + # create engines for trainer and evaluator |
| 231 | + trainer = Engine(train_step) |
| 232 | + evaluator = Engine(test_step) |
| 233 | + |
| 234 | + # attach Running Average Loss metric to trainer and evaluator engines |
| 235 | + RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") |
| 236 | + Accuracy(output_transform=lambda x: x).attach(evaluator, "accuracy") |
| 237 | + |
| 238 | + # attach progress bar to trainer with loss |
| 239 | + pbar1 = ProgressBar() |
| 240 | + pbar1.attach(trainer, metric_names=["loss"]) |
| 241 | + |
| 242 | + # attach progress bar to evaluator |
| 243 | + pbar2 = ProgressBar() |
| 244 | + pbar2.attach(evaluator) |
| 245 | + |
| 246 | + # attach LR Scheduler to trainer engine |
| 247 | + trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) |
| 248 | + |
| 249 | + # event handler triggers evauator at end of every epoch |
| 250 | + @trainer.on(Events.EPOCH_COMPLETED(every=args.log_interval)) |
| 251 | + def test(engine): |
| 252 | + state = evaluator.run(test_loader) |
| 253 | + print(f'Test Accuracy: {state.metrics["accuracy"]}') |
| 254 | + |
| 255 | + # run the trainer |
| 256 | + trainer.run(train_loader, max_epochs=args.epochs) |
| 257 | + |
| 258 | + |
| 259 | +def main(): |
| 260 | + # adds training defaults and support for terminal arguments |
| 261 | + parser = argparse.ArgumentParser(description="PyTorch Siamese network Example") |
| 262 | + parser.add_argument( |
| 263 | + "--batch-size", type=int, default=256, metavar="N", help="input batch size for training (default: 64)" |
| 264 | + ) |
| 265 | + parser.add_argument( |
| 266 | + "--test-batch-size", type=int, default=256, metavar="N", help="input batch size for testing (default: 1000)" |
| 267 | + ) |
| 268 | + parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 14)") |
| 269 | + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") |
| 270 | + parser.add_argument( |
| 271 | + "--gamma", type=float, default=0.95, metavar="M", help="Learning rate step gamma (default: 0.7)" |
| 272 | + ) |
| 273 | + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") |
| 274 | + parser.add_argument("--no-mps", action="store_true", default=False, help="disables macOS GPU training") |
| 275 | + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") |
| 276 | + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") |
| 277 | + parser.add_argument( |
| 278 | + "--log-interval", |
| 279 | + type=int, |
| 280 | + default=1, |
| 281 | + metavar="N", |
| 282 | + help="how many batches to wait before logging training status", |
| 283 | + ) |
| 284 | + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") |
| 285 | + parser.add_argument("--num-workers", default=4, help="number of processes generating parallel batches") |
| 286 | + args = parser.parse_args() |
| 287 | + |
| 288 | + # set manual seed |
| 289 | + manual_seed(args.seed) |
| 290 | + |
| 291 | + # set device |
| 292 | + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
| 293 | + |
| 294 | + # data loading |
| 295 | + train_dataset = MatcherDataset("../data", train=True, download=True) |
| 296 | + test_dataset = MatcherDataset("../data", train=False) |
| 297 | + train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers) |
| 298 | + test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, num_workers=args.num_workers) |
| 299 | + |
| 300 | + # set model parameters |
| 301 | + model = SiameseNetwork().to(device) |
| 302 | + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) |
| 303 | + scheduler = StepLR(optimizer, step_size=15, gamma=args.gamma) |
| 304 | + lr_scheduler = LRScheduler(scheduler) |
| 305 | + |
| 306 | + # call run function |
| 307 | + run(args, model, device, optimizer, train_loader, test_loader, lr_scheduler) |
| 308 | + |
| 309 | + |
| 310 | +if __name__ == "__main__": |
| 311 | + main() |
0 commit comments