Commit 451f6cb, committed by F on Jul 30, 2018 (0 parents).
Showing 51 changed files with 5,963 additions and 0 deletions.
@@ -0,0 +1,258 @@
from data.data_pipe import de_preprocess, get_train_loader, get_val_data
from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
from verifacation import evaluate
import torch
from torch import optim
import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
plt.switch_backend('agg')
from utils import get_time, gen_plot, hflip_batch
from PIL import Image
from torchvision import transforms as trans
import math
import bcolz

class face_learner(object):
    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num, self.test_transform = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num).to(conf.device)

            print('model head generated')

            # keep batch-norm parameters out of weight decay
            paras_only_bn = []
            paras_wo_bn = []
            for para in self.model.named_parameters():
                if 'bn' in para[0]:
                    paras_only_bn.append(para[1])
                else:
                    paras_wo_bn.append(para[1])
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                    {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                    {'params': paras_only_bn}
                ], lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                    {'params': paras_only_bn}
                ], lr=conf.lr, momentum=conf.momentum)

            print('optimizer generated')
            # logging, evaluation, and checkpoint cadence, derived from epoch length
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(conf.data_path)
        else:
            self.threshold = conf.threshold

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None, model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False, model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(torch.load(save_path/'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path/'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor, val, val_std, far):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
        self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(flipped.to(conf.device))
                    # .cpu() so the tensor can be assigned into the numpy array
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(flipped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds, val, val_std, far = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, val, val_std, far
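
    # The learning-rate finder below follows the LR range test popularized by
    # Leslie Smith and fastai: lr is swept geometrically across the loader
    # while a bias-corrected moving average of the loss is tracked, and the
    # sweep stops once the smoothed loss explodes.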
    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                num=None):
        if not num:
            num = len(self.loader)
        # grow lr geometrically from init_value to final_value over num batches
        mult = (final_value / init_value) ** (1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # compute the smoothed loss: moving average with bias correction
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 3 * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            # do the SGD step
            loss.backward()
            self.optimizer.step()

            # update the lr for the next batch
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
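
    # Training loop: backbone embeddings are passed through the Arcface head,
    # and cross-entropy is applied to the margin-adjusted logits (thetas).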
    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        accuracy = 0.  # guard: save_state can run before the first evaluation
        for e in range(epochs):
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item() / conf.batch_size
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    self.writer.add_scalar('train_loss', running_loss / self.board_loss_every, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor, val, val_std, far = self.evaluate(conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold, roc_curve_tensor, val, val_std, far)
                    accuracy, best_threshold, roc_curve_tensor, val, val_std, far = self.evaluate(conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor, val, val_std, far)
                    accuracy, best_threshold, roc_curve_tensor, val, val_std, far = self.evaluate(conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold, roc_curve_tensor, val, val_std, far)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)

                self.step += 1

            if e in self.milestones:  # drop lr by 10x at each milestone epoch
                self.schedule_lr()
                self.save_state(conf, accuracy, to_save_folder=True, extra='{} epochs'.format(e))

        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (hflip only)
        returns : index of the nearest facebank entry per face, or -1 when
                  the distance exceeds self.threshold
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        # squared L2 distance between every source and every facebank embedding
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx
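
For orientation, a minimal sketch of driving these classes end to end. The diff view hides the new file names, so the module names Learner and config below are assumptions, and the script presumes the datasets referenced by the config are already in place:

# hypothetical driver script; module names are assumptions, not shown in the diff
from config import get_config
from Learner import face_learner

conf = get_config(training=True)           # data paths, lr=6e-4, milestones [4, 6, 7]
learner = face_learner(conf)               # builds backbone, Arcface head, SGD optimizer
# log_lrs, losses = learner.find_lr(conf)  # optional LR range test before training
learner.train(conf, epochs=8)              # 8 epochs covers the lr-drop milestones

For inference, the counterpart is face_learner(get_config(training=False), inference=True), followed by load_state to restore a checkpoint and infer against a facebank of precomputed embeddings.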
Empty file.
@@ -0,0 +1,51 @@
from easydict import EasyDict as edict
from pathlib import Path
import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms as trans

def get_config(training=True):
    conf = edict()
    conf.data_path = Path('data')
    conf.work_path = Path('work_space/')
    conf.model_path = conf.work_path/'models'
    conf.log_path = conf.work_path/'log'
    conf.save_path = conf.work_path/'save'
    conf.input_size = [112, 112]
    conf.embedding_size = 512
    conf.use_mobilfacenet = False
    conf.net_depth = 50
    conf.net_mode = 'ir_se'  # or 'ir'
    conf.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # -------------------- Training Config --------------------
    if training:
        conf.data_mode = 'concat'
        conf.vgg_folder = conf.data_path/'vgg_dataset'
        conf.ms1m_folder = conf.data_path/'ms1m_dataset'
        conf.log_path = conf.work_path/'log'
        conf.save_path = conf.work_path/'save'
        # conf.weight_decay = 5e-4
        conf.lr = 6e-4
        # conf.milestones = [3, 4, 5]  # mobilefacenet
        conf.milestones = [4, 6, 7]  # arcface
        conf.momentum = 0.9
        conf.drop_ratio = 0.6
        conf.batch_size = 84  # ir_se net, depth 50
        # conf.batch_size = 200  # mobilefacenet
        conf.pin_memory = True
        # conf.num_workers = 4  # when batch_size is 200
        conf.num_workers = 3
        conf.ce_loss = CrossEntropyLoss()
    # -------------------- Inference Config --------------------
    else:
        conf.facebank_path = conf.data_path/'facebank'
        conf.test_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        conf.threshold = 1.5
        conf.face_limit = 10
        # detect at most 10 faces at inference time (keeps it fast on a slow machine)
        conf.min_face_size = 30
        # the larger this value, the faster detection, at the cost of missing small faces
    return conf
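
A short sketch of consuming the inference side of this config. The module name config and the image path are placeholders, not part of the commit:

# hypothetical usage; module name and image path are assumptions
from PIL import Image
from config import get_config

conf = get_config(training=False)   # threshold, facebank path, test_transform
img = Image.open('face.jpg')        # placeholder image path
tensor = conf.test_transform(img)   # ToTensor then Normalize: pixels map to [-1, 1]
print(conf.threshold, conf.device)  # 1.5, and cuda:0 when available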
Empty file.