diff --git a/models/__init__.py b/models/__init__.py
index 6b765dde..289de91b 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,4 +1,4 @@
-from .models import ArcMarginModel
-from .models import ResNet
-from .models import IRBlock
-from .models import SEBlock
\ No newline at end of file
+from .arcface_models import ArcMarginModel
+from .arcface_models import ResNet
+from .arcface_models import IRBlock
+from .arcface_models import SEBlock
\ No newline at end of file
diff --git a/models/arcface_models.py b/models/arcface_models.py
new file mode 100644
index 00000000..39a6ac54
--- /dev/null
+++ b/models/arcface_models.py
@@ -0,0 +1,191 @@
+"""ArcFace identity network: SE/IR residual blocks, backbone and margin head."""
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from .config import num_classes
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """Return a bias-free 3x3 convolution with padding 1."""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class SEBlock(nn.Module):
+    """Channel gate: rescale every channel by a learned weight in (0, 1)."""
+
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction),
+            nn.PReLU(),
+            nn.Linear(channel // reduction, channel),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        # Pool to (b, c), compute per-channel gates, broadcast over H and W.
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+class IRBlock(nn.Module):
+    """Residual block: BN-conv-BN-PReLU-conv-BN, optional SE gate, shortcut."""
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+        super(IRBlock, self).__init__()
+        self.bn0 = nn.BatchNorm2d(inplanes)
+        self.conv1 = conv3x3(inplanes, inplanes)
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.prelu = nn.PReLU()
+        self.conv2 = conv3x3(inplanes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.use_se = use_se
+        if self.use_se:
+            self.se = SEBlock(planes)
+
+    def forward(self, x):
+        residual = x
+        out = self.bn0(x)
+        out = self.conv1(out)
+        out = self.bn1(out)
+        out = self.prelu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.use_se:
+            out = self.se(out)
+
+        # Project the shortcut when a downsample module was supplied.
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.prelu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    """Backbone mapping 3-channel images to 512-d embeddings.
+
+    ``fc`` expects ``layer4`` to emit a 512 x 7 x 7 map.
+    """
+
+    def __init__(self, block, layers, use_se=True):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.use_se = use_se
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.prelu = nn.PReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.bn2 = nn.BatchNorm2d(512)
+        self.dropout = nn.Dropout()
+        self.fc = nn.Linear(512 * 7 * 7, 512)
+        self.bn3 = nn.BatchNorm1d(512)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.xavier_normal_(m.weight)
+            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        """Stack ``blocks`` blocks; only the first one strides/projects."""
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.bn2(x)
+        x = self.dropout(x)
+        # Flatten to (batch, features) for the embedding layer.
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        x = self.bn3(x)
+
+        return x
+
+
+class ArcMarginModel(nn.Module):
+    """Margin head producing scaled cosine logits with an angular margin.
+
+    ``args`` must provide ``emb_size``, ``easy_margin``, ``margin_m`` and
+    ``margin_s``.
+    """
+
+    def __init__(self, args):
+        super(ArcMarginModel, self).__init__()
+
+        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
+        nn.init.xavier_uniform_(self.weight)
+
+        self.easy_margin = args.easy_margin
+        self.m = args.margin_m
+        self.s = args.margin_s
+
+        self.cos_m = math.cos(self.m)
+        self.sin_m = math.sin(self.m)
+        self.th = math.cos(math.pi - self.m)
+        self.mm = math.sin(math.pi - self.m) * self.m
+
+    def forward(self, input, label):
+        x = F.normalize(input)
+        W = F.normalize(self.weight)
+        cosine = F.linear(x, W)
+        # Rounding can push cosine**2 just above 1; clamp so sqrt never
+        # receives a negative value and yields NaN.
+        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
+        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
+        if self.easy_margin:
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
+        # Build the mask next to the logits so it follows their device.
+        one_hot = torch.zeros_like(cosine)
+        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
+        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
+        output *= self.s
+        return output
\ No newline at end of file
diff --git a/models/fs_model.py b/models/fs_model.py
index 548dd922..2ff18d1e 100644
--- a/models/fs_model.py
+++ b/models/fs_model.py
@@ -65,6 +65,7 @@ def initialize(self, opt):
         # Id network
         netArc_checkpoint = opt.Arc_path
-        netArc_checkpoint = torch.load(netArc_checkpoint)
-        self.netArc = netArc_checkpoint['model'].module
+        # NOTE: the checkpoint is a whole pickled module; load trusted files only.
+        netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
+        self.netArc = netArc_checkpoint
         self.netArc = self.netArc.to(device)
         self.netArc.eval()
diff --git a/options/test_options.py b/options/test_options.py
index 5b43c9a7..c36849bb 100644
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -22,7 +22,7 @@ def initialize(self):
         self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
         self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
         self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
-        self.parser.add_argument("--Arc_path", type=str, default='models/BEST_checkpoint.tar', help="run ONNX model via TRT")
+        self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="path to the ArcFace checkpoint")
         self.parser.add_argument("--pic_a_path", type=str, default='./crop_224/gdg.jpg', help="Person who provides identity information")
         self.parser.add_argument("--pic_b_path", type=str, default='./crop_224/zrf.jpg', help="Person who provides information other than their identity")
         self.parser.add_argument("--pic_specific_path", type=str, default='./crop_224/zrf.jpg', help="The specific person to be swapped")
diff --git a/test_one_image.py b/test_one_image.py
index 4eabd8ee..88e2f7fc 100644
--- a/test_one_image.py
+++ b/test_one_image.py
@@ -83,4 +83,4 @@ def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0

         output = output*255

-        cv2.imwrite(opt.output_path + 'result.jpg',output)
\ No newline at end of file
+        cv2.imwrite(opt.output_path + 'result.jpg', output)
\ No newline at end of file