# model.py
import math

import timm
import torch
import torch.nn.functional as F
from torch import nn

from configuration import *

class ArcMarginProduct(nn.Module):
    """ArcFace margin head: adds an additive angular margin to the target
    class logit, so training maximises angular separation between classes."""

    def __init__(self, opt, in_features, out_features, scale=30.0, margin=0.50,
                 easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.opt = opt

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against small negative values from floating-point error
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Keep the logit monotonic in theta when theta + m exceeds pi
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device=self.opt.DEVICE)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # torch.where semantics: out_i = x_i if condition_i else y_i
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale
        return output, F.cross_entropy(output, label)
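
# Minimal standalone sketch of the margin head (assumption: `opt` is any object
# exposing a DEVICE attribute, e.g. Config from configuration; the shapes and
# class count below are illustrative only):
#
#   head = ArcMarginProduct(opt, in_features=512, out_features=10).to(opt.DEVICE)
#   feats = torch.randn(4, 512, device=opt.DEVICE)      # batch of embeddings
#   labels = torch.randint(0, 10, (4,), device=opt.DEVICE)
#   logits, loss = head(feats, labels)                  # scaled logits + CE loss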

class KfashionModel(nn.Module):
    """timm backbone + optional FC/BN neck + ArcFace margin head."""

    def __init__(
            self,
            opt,
            margin=Config.MARGIN,
            scale=Config.SCALE,
            use_fc=True,
            pretrained=True):
        super(KfashionModel, self).__init__()
        self.n_classes = opt.CLASSES
        self.model_name = opt.MODEL_NAME
        self.fc_dim = opt.FC_DIM
        self.backbone = timm.create_model(self.model_name, pretrained=pretrained)
        # Strip the backbone's pooling and classifier so it returns a feature map
        if self.model_name == 'resnext50_32x4d':
            final_in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
            self.backbone.global_pool = nn.Identity()
        elif 'efficientnet' in self.model_name:
            final_in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
            self.backbone.global_pool = nn.Identity()
        elif 'nfnet' in self.model_name:
            final_in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
            self.backbone.head.global_pool = nn.Identity()
        else:
            raise ValueError(f'unsupported backbone: {self.model_name}')
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.use_fc = use_fc
        self.typ = opt.TYP
        if use_fc:
            self.dropout = nn.Dropout(p=0.0)
            self.fc = nn.Linear(final_in_features, self.fc_dim)
            self.bn = nn.BatchNorm1d(self.fc_dim)
            self._init_params()
            final_in_features = self.fc_dim
        self.final = ArcMarginProduct(
            opt,
            final_in_features,
            self.n_classes,
            scale=scale,
            margin=margin,
            easy_margin=False,
            ls_eps=0.0,
        )

    def _init_params(self):
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, image, label):
        feature = self.extract_feat(image)
        if self.typ == 'train':
            # ArcMarginProduct returns (scaled logits, cross-entropy loss)
            return self.final(feature, label)
        else:
            return feature

    def extract_feat(self, x):
        batch_size = x.shape[0]
        x = self.backbone(x)
        x = self.pooling(x).view(batch_size, -1)
        if self.use_fc:
            x = self.dropout(x)
            x = self.fc(x)
            x = self.bn(x)
        return x
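

if __name__ == '__main__':
    # Smoke test (assumption: Config from configuration exposes the attributes
    # used above: CLASSES, MODEL_NAME, FC_DIM, DEVICE, TYP; the 224x224 input
    # size is illustrative and should match the chosen backbone).
    opt = Config
    model = KfashionModel(opt).to(opt.DEVICE)
    images = torch.randn(2, 3, 224, 224, device=opt.DEVICE)
    labels = torch.randint(0, opt.CLASSES, (2,), device=opt.DEVICE)
    if opt.TYP == 'train':
        logits, loss = model(images, labels)
        print(logits.shape, loss.item())
    else:
        embeddings = model(images, labels)
        print(embeddings.shape)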