
shyhyawJou/ArcFace-Pytorch


Overview

ArcFace

Usage

You need to alter your model:

  • add a label argument to the model's forward function

  • pass label to the classifier as an input

  • I recommend adding an if/else to the forward function, so that switching between ArcFace training and ordinary training needs no code changes

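    import torch.nn as nn

    class Your_Model(nn.Module):
        def __init__(self, backbone, classifier):
            super().__init__()
            self.backbone = backbone      # feature extractor, e.g. a CNN without its final fc layer
            self.classifier = classifier  # ArcFace(...) for ArcFace training, or an ordinary nn.Linear(...)

        def forward(self, x, label=None):
            x = self.backbone(x)
            # Pass the label to the classifier only when training with ArcFace
            output = self.classifier(x) if label is None else self.classifier(x, label)
            return output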
    
  • Training phase

  1. Replace the last fully connected layer of your model with ArcFace.
    For example, with ResNet18:
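from torchvision import models
from arcface import ArcFace  # the ArcFace module from this repository

num_classes = 10  # placeholder: number of classes in your dataset

# weights= replaces the deprecated pretrained=True in torchvision >= 0.13
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model_ft.fc.in_features
model_ft.fc = ArcFace(num_ftrs, num_classes, m=0.5)
# torchvision's ResNet.forward(x) takes no label argument, so wrap the network
# in a forward(x, label=None) as described above before calling model(img, label).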
  2. In the training loop, pass label to the model as an input argument.
    Typically, loss_func is nn.CrossEntropyLoss().
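import torch.nn as nn

loss_func = nn.CrossEntropyLoss()
for img, label in torch_dataloader:
    if use_arcface:
        output = model(img, label)  # the label lets ArcFace add its margin to the true class
    else:
        output = model(img)
    loss = loss_func(output, label)
    optimizer.zero_grad()  # optimizer: your torch.optim optimizer
    loss.backward()
    optimizer.step()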
  • Evaluation/Inference phase
    No changes are needed: call the model without a label, e.g. model(img).
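For reference, here is a minimal PyTorch sketch of what a head with the forward(x, label=None) interface can look like, written from the standard ArcFace formulation (scaled cosine logits, with an additive angular margin m applied to the true class only). ArcMarginHead is a hypothetical name, not this repository's arcface.py, whose internals may differ; the default s=8 simply mirrors the Note, and the sketch omits the theta + m > pi correction that many implementations add.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcMarginHead(nn.Module):
    """Hypothetical ArcFace-style head (standard additive angular margin).

    Not this repository's ArcFace class; s=8 mirrors the default in the Note.
    """

    def __init__(self, in_features, num_classes, s=8.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        # One learnable weight vector per class
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, label=None):
        # Cosine similarity between L2-normalized features and class weights
        cos = F.linear(F.normalize(x), F.normalize(self.weight))
        if label is None:
            # Inference: plain scaled cosine logits, no margin
            return self.s * cos
        # Angle to every class; the clamp keeps acos finite at +/-1
        theta = torch.acos(cos.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
        # Add the angular margin m only to the ground-truth class
        one_hot = F.one_hot(label, num_classes=cos.size(1)).to(cos.dtype)
        return self.s * torch.cos(theta + one_hot * self.m)


# Toy usage: a linear "backbone" plus the head, one training step, then inference
torch.manual_seed(0)
backbone = nn.Linear(16, 4)
head = ArcMarginHead(4, num_classes=3)
opt = torch.optim.SGD(list(backbone.parameters()) + list(head.parameters()), lr=0.1)
x, y = torch.randn(8, 16), torch.randint(0, 3, (8,))

loss = F.cross_entropy(head(backbone(x), y), y)  # training: pass the label
opt.zero_grad()
loss.backward()
opt.step()

preds = head(backbone(x)).argmax(dim=1)  # inference: no label, no margin
```

Without a label the head returns plain scaled cosines, which is why the inference phase needs no changes.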

Note

  • If training is hard to converge, set m to a smaller value (closer to zero).
  • If your dataset is difficult, I also recommend a smaller m.
  • ArcFace's scale argument s is 64 in the original paper, but a larger s also enlarges the gradients, so s defaults to 8 here.
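To see what m and s change, the plain-Python sketch below applies the standard ArcFace logit, s·cos(θ + m) for the true class and s·cos(θ) for the others, to one fixed set of cosines. The helpers arcface_logits and cross_entropy are hypothetical and the numbers are illustrative only. A larger m lowers the true-class logit, so the loss stays higher until the feature moves closer to its class, which is one way to see why a smaller m can ease convergence. Every logit is s times a cosine term, so by the chain rule the gradients with respect to the cosines pick up a factor of s.

```python
import math


def arcface_logits(cosines, target, s=8.0, m=0.5):
    """Hypothetical helper: s*cos(theta + m) for the target class, s*cos(theta) elsewhere."""
    logits = []
    for j, c in enumerate(cosines):
        theta = math.acos(c)  # angle between the feature and the class-j weight
        logits.append(s * math.cos(theta + m) if j == target else s * c)
    return logits


def cross_entropy(logits, target):
    """Numerically stable -log(softmax(logits)[target])."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum - logits[target]


cosines = [0.7, 0.2, 0.1]  # the feature is closest to class 0, the true class
for s, m in [(8.0, 0.0), (8.0, 0.5), (64.0, 0.5)]:
    z = arcface_logits(cosines, target=0, s=s, m=m)
    print(f"s={s:>4}, m={m}: target logit {z[0]:.2f}, loss {cross_entropy(z, 0):.3f}")
```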

About

PyTorch implementation of ArcFace
