-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathtest.py
46 lines (34 loc) · 1.04 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from nb.torch.backbones.mobileone import MobileOneNet, make_mobileone_s0
from nb.torch.utils.checkpoint import load_ckp_unwrap_module
import sys
import torch
import os
# Usage: python test.py <checkpoint>
#
# Loads a training-structure MobileOne-S0 checkpoint, exports it as
# TorchScript, calls switch_to_deploy() on every module that provides it,
# compares the result against a freshly built deploy-mode network, and saves
# the converted state dict plus a TorchScript trace of the deploy model next
# to the input checkpoint.


def _sibling_path(src, suffix):
    """Return a path in the same directory as *src*: src's stem + *suffix*.

    ``os.path.splitext`` strips only the final extension, so checkpoint names
    that contain extra dots (e.g. ``s0.epoch10.pth``) keep their full stem.
    """
    stem = os.path.splitext(os.path.basename(src))[0]
    return os.path.join(os.path.dirname(src), stem + suffix)


if len(sys.argv) < 2:
    sys.exit("usage: %s <checkpoint>" % sys.argv[0])
a = sys.argv[1]
sd = load_ckp_unwrap_module(a)
x = torch.randn(1, 3, 224, 224)

model = make_mobileone_s0(deploy=False)
model.load_state_dict(sd)
# Switch to inference mode BEFORE any forward pass. In training mode,
# BatchNorm normalises with batch statistics and overwrites its running
# statistics, which the conversion below presumably folds into the fused
# weights (TODO confirm in nb.torch.backbones.mobileone).
model.eval()
print("original model loaded.")

with torch.no_grad():
    # Export the un-converted (training-structure) network.
    mod = torch.jit.trace(model, x)
    mod.save(_sibling_path(a, "_noreparam.pt"))
    o1 = model(x)

    # Ask every module that supports it to convert to its deploy form.
    # The conversion is assumed to happen in place, since the converted
    # weights are read back from `model.state_dict()` below.
    for module in model.modules():
        if hasattr(module, "switch_to_deploy"):
            module.switch_to_deploy()

    # Build a native deploy-mode network and load the converted weights;
    # load_state_dict (strict by default) checks the key layout matches.
    deploy_model = make_mobileone_s0(deploy=True)
    deploy_model.eval()
    deploy_model.load_state_dict(model.state_dict())
    o = deploy_model(x)
    # Report the max absolute difference: a signed sum of differences can
    # cancel out and hide a real mismatch.
    print("max abs diff (original vs reparam):", (o1 - o).abs().max().item())

    torch.save(model.state_dict(), _sibling_path(a, "_reparam.pth"))
    mod = torch.jit.trace(deploy_model, x)
    mod.save(_sibling_path(a, "_reparam.pt"))