Skip to content

Commit

Permalink
fix bug and add pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
Dechao Meng committed Jan 20, 2021
1 parent bea989d commit d83d58f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 25 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ python generate_pkl.py veriwild --input-path <VeRi_PATH> --output-path ../output
python generate_pkl.py vehicleid --input-path <VeRi_PATH> --output-path ../outputs/vehicleid.pkl
```


## Training the parsing model
<!-- We provide the pre-trained segmentation model on `examples/parsing/best_model_trainval.pth` which you can use to generate parsing masks for different datasets.
If you want to use the model directly, just skip this section.
Expand Down Expand Up @@ -101,6 +100,10 @@ python main.py train -c configs/vehicleid_b256_pven.yml
python main.py train -c configs/veriwild_b256_224_pven.yml
```

## Pretrained Models
We provide the pretrained parsing model, VeRi776 ReID model and VERIWild ReID model ( the classification layer has been removed ) for your convinient.

链接: https://pan.baidu.com/s/1Q2NMVfGZPCskh-E6vmy9Cw 密码: iiw1
## Evaluate PVEN
```shell
cd examples/parsing_reid
Expand Down
10 changes: 10 additions & 0 deletions examples/parsing_reid/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,5 +513,15 @@ def eval_vehicle_id_(model, valid_loader, query_length, cfg):
logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.2%}")
return cmc, mAP

@clk.command()
@click.option('-i', '--model-path')
@click.option('-o', '--output-path')
def drop_linear(model_path, output_path):
model = torch.load(model_path)
for key in model.keys():
if 'classifier' in key:
model[key] = None
torch.save(model, output_path)

if __name__ == '__main__':
clk()
50 changes: 26 additions & 24 deletions examples/parsing_reid/math_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import ipdb

import torch
from torch.nn import functional as F
Expand All @@ -12,31 +13,32 @@
from vehicle_reid_pytorch.metrics.rerank import re_ranking


def calc_dist_split(qf, gf, split=0):
qf = qf
m = qf.shape[0]
n = gf.shape[0]
distmat = gf.new(m, n)
# def calc_dist_split(qf, gf, split=0):
# qf = qf
# m = qf.shape[0]
# n = gf.shape[0]
# distmat = gf.new(m, n)

if split == 0:
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# if split == 0:
# distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
# torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# distmat.addmm_(x, y.t(), beta=1, alpha=-2)

# 用于测试时控制显存
else:
start = 0
while start < n:
end = start + split if (start + split) < n else n
num = end - start
# # 用于测试时控制显存
# else:
# start = 0
# while start < n:
# end = start + split if (start + split) < n else n
# num = end - start

sub_distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, num) + \
torch.pow(gf[start:end], 2).sum(dim=1, keepdim=True).expand(num, m).t()
# sub_distmat.addmm_(1, -2, qf, gf[start:end].t())
sub_distmat.addmm_(qf, gf[start:end].t(), beta=1, alpha=-2)
distmat[:, start:end] = sub_distmat.cpu()
start += num
# sub_distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, num) + \
# torch.pow(gf[start:end], 2).sum(dim=1, keepdim=True).expand(num, m).t()
# # sub_distmat.addmm_(1, -2, qf, gf[start:end].t())
# sub_distmat.addmm_(qf, gf[start:end].t(), beta=1, alpha=-2)
# distmat[:, start:end] = sub_distmat.cpu()
# start += num

return distmat
# return distmat


def clck_dist(feat1, feat2, vis_score1, vis_score2, split=0):
Expand All @@ -58,7 +60,7 @@ def clck_dist(feat1, feat2, vis_score1, vis_score2, split=0):
parse_feat2 = feat2[:, :, i]
ckcl_ = torch.mm(vis_score1[:, i].view(-1, 1), vis_score2[:, i].view(1, -1)) # [N, N]
ckcl += ckcl_
dist_mat += calc_dist_split(parse_feat1, parse_feat2, split=split).sqrt() * ckcl_
dist_mat += euclidean_dist(parse_feat1, parse_feat2, split=split).sqrt() * ckcl_

return dist_mat / ckcl

Expand Down Expand Up @@ -165,7 +167,7 @@ def compute(self, split=0):
if split == 0:
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(1, -2, qf, gf.t())
distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
else:
distmat = gf.new(m, n)
start = 0
Expand Down Expand Up @@ -207,7 +209,7 @@ def compute(self, split=0):
torch.save(outputs, os.path.join(self.output_path, 'test_output.pkl'), pickle_protocol=4)

print('Eval...')
cmc, mAP, all_AP = eval_func_mp(distmat + self.lambda_ * local_distmat, q_pids, g_pids, q_camids, g_camids,
cmc, mAP, all_AP = eval_func_mp(distmat + self.lambda_ * (local_distmat ** 2), q_pids, g_pids, q_camids, g_camids,
remove_junk=self.remove_junk)

return {
Expand Down

0 comments on commit d83d58f

Please sign in to comment.