Skip to content

Commit c1c9825

Browse files
DeepC004Deepvfdev-5
authored
Issue#2878: Adds Siamese Network example (#2882)
* Issue#2878: Adds Siamese Network example * Update README.md * Updated code formatting * Updated more code formatting errors * Updated some more code formatting errors * Update dataset, loss function and minor fixes * Code refactoring and bottleneck removal * Added accuracy measures * added ignite.metrics.Accuracy + minor changes * code formatting * minor fixes --------- Co-authored-by: Deep <deepchordia004@gmail.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 771564b commit c1c9825

File tree

3 files changed

+324
-0
lines changed

3 files changed

+324
-0
lines changed

examples/siamese_network/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Siamese Network example on CIFAR10 dataset
2+
3+
This example is ported over from [pytorch/examples/siamese_network](https://github.com/pytorch/examples/tree/main/siamese_network)
4+
5+
Usage:
6+
7+
```
8+
pip install -r requirements.txt
9+
python siamese_network.py
10+
```
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
torch
2+
torchvision
3+
pytorch-ignite
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
import argparse
2+
3+
import numpy as np
4+
import torch
5+
import torch.nn as nn
6+
import torch.optim as optim
7+
import torchvision
8+
from torch.optim.lr_scheduler import StepLR
9+
from torch.utils.data import DataLoader, Dataset
10+
from torchvision import datasets
11+
12+
from ignite.contrib.handlers import ProgressBar
13+
from ignite.engine import Engine, Events
14+
from ignite.handlers.param_scheduler import LRScheduler
15+
from ignite.metrics import Accuracy, RunningAverage
16+
from ignite.utils import manual_seed
17+
18+
19+
class SiameseNetwork(nn.Module):
    """
    Siamese network producing embeddings for image similarity estimation.

    One shared set of weights is applied to every input: a ``ResNet-34`` backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    (created with ``weights=None`` and with its final ``fc`` layer replaced by ``nn.Identity``)
    followed by a small fully connected head that maps the backbone features to a
    10-dimensional, ReLU-activated embedding.

    ``forward`` takes three images (anchor, positive, negative) and returns their three
    embeddings, which makes the module directly usable with ``nn.TripletMarginLoss``.
    `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network.
    """

    def __init__(self):
        super().__init__()

        # backbone: keep the size of the features feeding the original classifier,
        # then drop that classifier so the backbone outputs raw features
        backbone = torchvision.models.resnet34(weights=None)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # embedding head applied to the backbone features of each image
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 10),
            nn.ReLU(inplace=True),
        )

        # standalone activation; not used by ``forward`` but kept as a module attribute
        self.relu = nn.ReLU()

        # re-initialise every linear layer found in the backbone and in the head
        for submodule in (self.resnet, self.fc):
            submodule.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise the weight of a linear layer and set its bias to 0.01; ignore other modules."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

    def forward_once(self, x):
        """Run one batch through the backbone and flatten everything but the batch dimension."""
        features = self.resnet(x)
        return features.view(features.size(0), -1)

    def forward(self, input1, input2, input3):
        """Return the embeddings of the three inputs, in the order they were given."""
        # backbone first for all three inputs, then the shared embedding head
        features = [self.forward_once(image) for image in (input1, input2, input3)]
        output1, output2, output3 = (self.fc(feature) for feature in features)
        return output1, output2, output3
80+
81+
82+
class MatcherDataset(Dataset):
    """
    CIFAR10 wrapper that downloads/loads the data and serves triplets.

    Each item is ``(anchor_image, positive_image, negative_image, anchor_label)`` where the
    positive image is drawn at random from the anchor's class and the negative image is
    drawn at random from a different, randomly chosen class.
    """

    def __init__(self, root, train, download=False):
        super().__init__()

        # underlying CIFAR10 dataset
        self.dataset = datasets.CIFAR10(root, train=train, download=download)

        # numpy array -> tensor, then move the last axis to position 1 so the
        # layout matches what the network's first layer expects
        images = torch.from_numpy(self.dataset.data)
        self.data = images.permute(0, 3, 1, 2)

        # targets come as a list; keep them as a tensor from here on
        self.dataset.targets = torch.tensor(self.dataset.targets)

        self.group_examples()

    def group_examples(self):
        """
        Build ``self.grouped_examples``, a lookup from class id to sample indices.

        Keys are the ten class ids ``0..9``; each value is the numpy array of positions in
        the dataset whose target equals that class id. This makes drawing a random sample
        of a given class a simple index operation.
        """
        all_targets = np.array(self.dataset.targets)
        self.grouped_examples = {class_id: np.where(all_targets == class_id)[0] for class_id in range(10)}

    def __len__(self):
        return self.data.size(0)

    def _random_image_of(self, label):
        # draw one random position inside the class bucket and fetch that image as float
        bucket = self.grouped_examples[label]
        pick = torch.randint(0, len(bucket), (1,)).item()
        return self.data[bucket[pick]].float()

    def __getitem__(self, index):
        """
        Return the triplet built around the sample at ``index``.

        The anchor is the image stored at ``index`` together with its integer label. A
        negative class is picked uniformly among the nine other classes, then one random
        image of the anchor's class (positive) and one random image of the negative class
        (negative) are drawn. All three images are returned as float tensors.
        """
        anchor_image = self.data[index].float()
        anchor_label = int(self.dataset.targets[index].item())

        # choose one of the nine labels that differ from the anchor's
        candidate_labels = list(range(10))
        candidate_labels.remove(anchor_label)
        neg_label = candidate_labels[torch.randint(0, 9, (1,)).item()]

        positive_image = self._random_image_of(anchor_label)
        negative_image = self._random_image_of(neg_label)

        return anchor_image, positive_image, negative_image, anchor_label
167+
168+
169+
def pairwise_distance(input1, input2):
    """Return the element-wise squared difference between ``input1`` and ``input2``."""
    return torch.pow(input1 - input2, 2)
173+
174+
175+
def calculate_loss(input1, input2):
    """Return, for each row, the Euclidean distance between ``input1`` and ``input2`` (reduced over dim 1)."""
    squared_diff = torch.pow(input1 - input2, 2)
    return torch.sqrt(torch.sum(squared_diff, 1))
180+
181+
182+
def run(args, model, device, optimizer, train_loader, test_loader, lr_scheduler):
    """
    Train ``model`` with a triplet margin loss and periodically report test accuracy.

    Args:
        args: parsed command-line namespace. ``epochs`` and ``log_interval`` are read; if a
            truthy ``dry_run`` attribute is present, training and evaluation are each limited
            to a single iteration.
        model: module called as ``model(anchor, positive, negative)`` returning three embeddings.
        device: device every batch is moved to before the forward pass.
        optimizer: optimizer stepped once per training iteration.
        train_loader: iterable of ``(anchor, positive, negative, label)`` batches.
        test_loader: iterable of the same kind of batches; only anchors and labels are used.
        lr_scheduler: ignite handler attached to ``Events.ITERATION_STARTED`` of the trainer.
    """
    dry_run = getattr(args, "dry_run", False)

    # using Triplet Margin Loss
    criterion = nn.TripletMarginLoss(p=2, margin=2.8)

    # two embeddings closer than this Euclidean distance are predicted to be the same class
    match_threshold = 3

    # define model training step
    def train_step(engine, batch):
        model.train()
        anchor_image, positive_image, negative_image, _ = batch
        anchor_image = anchor_image.to(device)
        positive_image = positive_image.to(device)
        negative_image = negative_image.to(device)
        optimizer.zero_grad()
        anchor_out, positive_out, negative_out = model(anchor_image, positive_image, negative_image)
        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()
        # hand a plain float to the engine so its state does not keep the autograd graph alive
        return loss.item()

    # define model testing step
    def test_step(engine, batch):
        model.eval()
        with torch.no_grad():
            anchor_image, _, _, anchor_label = batch
            anchor_image = anchor_image.to(device)
            anchor_label = anchor_label.to(device)

            # pair every anchor with a randomly chosen image of the same batch;
            # the pair is a positive example when both labels agree
            batch_size = anchor_image.shape[0]
            partner = torch.randint(0, batch_size, (batch_size,), device=anchor_label.device)
            other = anchor_image[partner]
            y_true = (anchor_label == anchor_label[partner]).long()

            # the model expects three inputs; the third embedding is not needed here
            anchor_out, other_out, _ = model(anchor_image, other, other)
            distance = calculate_loss(anchor_out, other_out)
            y_pred = torch.where(distance < match_threshold, 1, 0)
        return [y_pred, y_true]

    # create engines for trainer and evaluator
    trainer = Engine(train_step)
    evaluator = Engine(test_step)

    # running average of the loss on the trainer, accuracy on the evaluator
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: x).attach(evaluator, "accuracy")

    # attach progress bar to trainer with loss
    pbar1 = ProgressBar()
    pbar1.attach(trainer, metric_names=["loss"])

    # attach progress bar to evaluator
    pbar2 = ProgressBar()
    pbar2.attach(evaluator)

    # attach LR Scheduler to trainer engine
    trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    # event handler triggers the evaluator every ``args.log_interval`` epochs
    @trainer.on(Events.EPOCH_COMPLETED(every=args.log_interval))
    def test(engine):
        state = evaluator.run(test_loader, epoch_length=1 if dry_run else None)
        print(f'Test Accuracy: {state.metrics["accuracy"]}')

    # run the trainer; a dry run only checks that a single pass works end to end
    if dry_run:
        trainer.run(train_loader, max_epochs=1, epoch_length=1)
    else:
        trainer.run(train_loader, max_epochs=args.epochs)
257+
258+
259+
def _build_arg_parser():
    """Create the command-line parser holding the training defaults of this example."""
    parser = argparse.ArgumentParser(description="PyTorch Siamese network Example")
    parser.add_argument(
        "--batch-size", type=int, default=256, metavar="N", help="input batch size for training (default: 256)"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, metavar="N", help="input batch size for testing (default: 256)"
    )
    parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)")
    parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
    parser.add_argument(
        "--gamma", type=float, default=0.95, metavar="M", help="Learning rate step gamma (default: 0.95)"
    )
    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
    # accepted for command-line compatibility; this script never selects an MPS device
    parser.add_argument("--no-mps", action="store_true", default=False, help="disables macOS GPU training")
    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=1,
        metavar="N",
        help="how many epochs to wait between evaluations on the test set (default: 1)",
    )
    parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
    # type=int is required: without it a value given on the command line reaches DataLoader as a string
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        metavar="N",
        help="number of processes generating parallel batches (default: 4)",
    )
    return parser


def main():
    """
    Parse the command line, build data loaders, model and optimizer, then train.

    The random seed is fixed with ``--seed``. CUDA is used when it is available unless
    ``--no-cuda`` is given. With ``--save-model`` the trained weights are written to
    ``siamese_network.pt`` in the current directory once training has finished.
    """
    args = _build_arg_parser().parse_args()

    # set manual seed
    manual_seed(args.seed)

    # set device, honouring --no-cuda
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # data loading; the training split is created first so the download also serves the test split
    train_dataset = MatcherDataset("../data", train=True, download=True)
    test_dataset = MatcherDataset("../data", train=False)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, num_workers=args.num_workers)

    # set model parameters
    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=15, gamma=args.gamma)
    lr_scheduler = LRScheduler(scheduler)

    # call run function
    run(args, model, device, optimizer, train_loader, test_loader, lr_scheduler)

    # persist the trained weights when requested
    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")
308+
309+
310+
# Script entry point: start training only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)