You need to alter your model:
- Add a `label` argument to the model's forward function.
- Pass `label` to the classifier as an additional input.
- I recommend adding an `if`/`else` to the forward function, so that you don't need to change the code when you want to train without ArcFace:

      class Your_Model(nn.Module):
          def __init__(self):
              ..............
              self.classifier = ....

          def forward(self, x, label=None):
              ...................
              # Pass the label only when it is given (ArcFace training).
              output = self.classifier(x) if label is None else self.classifier(x, label)
              return output

- Training phase
  - Replace the last fully-connected layer of your model with `ArcFace`. Take ResNet18 as an example:

        from arcface import ArcFace
        from torchvision import models

        model_ft = models.resnet18(pretrained=True)
        num_ftrs = model_ft.fc.in_features  # input size of the original fc layer
        # num_classes is the number of classes in your dataset
        model_ft.fc = ArcFace(num_ftrs, num_classes, m=0.5)

  - In the training loop, pass `label` to the model as an input argument. Generally, `loss_func` is `nn.CrossEntropyLoss()`.

        for img, label in torch_dataloader:
            if use_arcface:
                # ArcFace needs the label to apply the margin to the target class.
                output = model(img, label)
            else:
                output = model(img)
            loss = loss_func(output, label)
            ......

- Evaluation/Inference phase
  - You don't need to alter anything.
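
To illustrate why only training needs the label, here is a minimal pure-Python sketch of an ArcFace-style logit computation. It is not the code of the `arcface` package: the function name `arcface_logits`, its arguments, the example cosine values, and the assumption that the head falls back to plain scaled cosines when no label is given are all illustrative.

```python
import math

def arcface_logits(cosines, label=None, s=8.0, m=0.5):
    """Hypothetical helper: scale cosines, adding margin m to the target angle.

    cosines: cos(theta_j) between one embedding and each class weight vector.
    label:   index of the ground-truth class, or None at inference time.
    """
    logits = []
    for j, c in enumerate(cosines):
        if label is not None and j == label:
            # Clamp before acos to guard against rounding outside [-1, 1].
            theta = math.acos(max(-1.0, min(1.0, c)))
            # Real implementations also treat theta + m > pi specially;
            # that case is omitted in this sketch.
            c = math.cos(theta + m)
        logits.append(s * c)
    return logits

cosines = [0.9, 0.3, -0.2]  # made-up similarities for a 3-class problem

train_logits = arcface_logits(cosines, label=0)  # margin applied to class 0
eval_logits = arcface_logits(cosines)            # no label: plain scaled cosines

# The margin only lowers the target logit; the other logits are unchanged.
print(train_logits[0] < eval_logits[0])      # True
print(train_logits[1:] == eval_logits[1:])   # True
# The predicted class comes from the unmargined logits.
print(eval_logits.index(max(eval_logits)))   # 0
```

Under these assumptions the margin exists only to make the training objective harder, so prediction at evaluation time is an ordinary argmax over the logits.
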
- If your training is hard to converge, you can set `m` to a smaller value (close to zero).
- If your dataset is difficult, I recommend setting `m` to a smaller value.
- ArcFace's argument `s` is 64 in the original paper, but the gradient grows as `s` becomes bigger, so I set `s` to 8 by default.
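
The effect of `m` and `s` can be checked with a small calculation. The sketch below is plain Python, not the package's implementation: `target_logit` and `grad_wrt_target_cosine` are hypothetical helpers, and the cosine values are made up. It relies on the standard softmax cross-entropy result that, with the margin ignored, the derivative of the loss with respect to the target cosine is `s * (p_target - 1)`.

```python
import math

def target_logit(cos_target, s, m):
    """Scaled target logit s * cos(theta + m); hypothetical helper."""
    theta = math.acos(cos_target)
    return s * math.cos(theta + m)

def grad_wrt_target_cosine(cosines, label, s):
    """d(cross-entropy)/d(cos_label) for logits s * cos, margin ignored."""
    logits = [s * c for c in cosines]
    top = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(z - top) for z in logits]
    p_label = exps[label] / sum(exps)
    return s * (p_label - 1.0)

# A larger margin lowers the target logit, which makes the task harder.
for m in (0.0, 0.1, 0.5):
    print(f"m={m}: target logit = {target_logit(0.8, s=8.0, m=m):.3f}")

# A sample the model currently gets wrong (class 1 scores higher than class 0).
cosines = [0.2, 0.5]
for s in (8.0, 64.0):
    g = grad_wrt_target_cosine(cosines, label=0, s=s)
    print(f"s={s}: gradient w.r.t. target cosine = {g:.3f}")
```

With these made-up numbers, the gradient magnitude for the misclassified sample is roughly 7.3 at `s=8` and close to 64 at `s=64`, which is consistent with the note above about larger `s` enlarging the gradient. The target logit shrinks as `m` grows, which is why a smaller `m` is easier to train.
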