forked from neuralchen/SimSwap
Commit d4bf5f9 (1 parent: dd1ecdd)
Showing 5 changed files with 174 additions and 11 deletions.

```diff
@@ -1,4 +1,4 @@
-from .models import ArcMarginModel
-from .models import ResNet
-from .models import IRBlock
-from .models import SEBlock
+from .arcface_models import ArcMarginModel
+from .arcface_models import ResNet
+from .arcface_models import IRBlock
+from .arcface_models import SEBlock
```
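This hunk redirects the package's re-exports from the old `models` module to the new `arcface_models` module, so downstream code keeps importing from the package unchanged. A minimal sketch of a consumer, assuming the hunk is the `__init__.py` of a package named `models` (the file name is an assumption; it is not shown above):

```python
# Consumer code is unaffected by the rename: the names still resolve
# through the package, now backed by arcface_models.py.
from models import ArcMarginModel, ResNet, IRBlock, SEBlock
```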
@@ -0,0 +1,163 @@
```python
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

from .config import device, num_classes


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (used by IRBlock below)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        # Pre-activation residual unit: BN comes before the first convolution.
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # For 112x112 inputs the feature map is 7x7 at this point,
        # hence the 512 * 7 * 7 flattened size.
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        # feature = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)

        return x


class ArcMarginModel(nn.Module):
    def __init__(self, args):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        # Clamp before the sqrt: rounding can push cosine slightly past 1,
        # which would otherwise produce NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Apply the angular margin only to the target class, then rescale.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
```
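For orientation, here is a minimal training-step sketch of how these pieces compose: the IR-SE ResNet maps aligned face crops to 512-d embeddings, and ArcMarginModel turns a batch of embeddings plus identity labels into margin-penalized logits for cross-entropy. The package name `models`, the `Args` fields, and the [3, 4, 14, 3] depth layout are assumptions for illustration; the margin values follow the usual ArcFace settings:

```python
import torch
from torch import nn

from models import ResNet, IRBlock, ArcMarginModel  # via the re-exports above
from models.config import device, num_classes       # the .config values the module imports


class Args:
    """Hypothetical stand-in for the argparse namespace ArcMarginModel reads."""
    emb_size = 512
    easy_margin = False
    margin_m = 0.50   # angular margin m, the common ArcFace setting
    margin_s = 64.0   # feature scale s


# IR-SE backbone with the [3, 4, 14, 3] layout commonly used for ArcFace-style
# 50-layer models; it expects 112x112 crops (the fc layer assumes 7x7 maps).
model = ResNet(IRBlock, [3, 4, 14, 3], use_se=True).to(device)
margin = ArcMarginModel(Args()).to(device)

images = torch.randn(4, 3, 112, 112, device=device)          # aligned faces
labels = torch.randint(0, num_classes, (4,), device=device)  # identity labels

embeddings = model(images)           # (4, 512) identity embeddings
logits = margin(embeddings, labels)  # margin-adjusted, scaled class logits
loss = nn.CrossEntropyLoss()(logits, labels)
```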