# avg_ckpts.py (forked from mpc001/auto_avsr)
import os

import torch


def average_checkpoints(last):
    """Average the "model.*" weights of the checkpoints listed in `last`."""
    avg = None
    for path in last:
        # Load the Lightning checkpoint on CPU and keep only its state_dict.
        states = torch.load(path, map_location=lambda storage, loc: storage)[
            "state_dict"
        ]
        # Keep only "model.*" entries and strip the "model." prefix from the keys.
        states = {k[6:]: v for k, v in states.items() if k.startswith("model.")}
        if avg is None:
            avg = states
        else:
            # Accumulate the parameter sums across checkpoints.
            for k in avg.keys():
                avg[k] += states[k]
    # Divide the accumulated sums by the number of checkpoints.
    for k in avg.keys():
        if avg[k] is not None:
            if avg[k].is_floating_point():
                avg[k] /= len(last)
            else:
                # Integer tensors (e.g. step counters) use floor division.
                avg[k] //= len(last)
    return avg


def ensemble(args):
    """Average the last 10 epoch checkpoints and save the result to disk."""
    last = [
        os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt")
        for n in range(
            args.trainer.max_epochs - 10,
            args.trainer.max_epochs,
        )
    ]
    model_path = os.path.join(
        args.exp_dir, args.exp_name, "model_avg_10.pth"
    )
    torch.save(average_checkpoints(last), model_path)
    return model_path
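

# Illustrative usage sketch, not part of the original repository code.
# Judging from the attribute access in `ensemble`, `args` must expose
# `exp_dir`, `exp_name`, and `trainer.max_epochs`; the directory, run name,
# and epoch count below are hypothetical placeholders for demonstration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        exp_dir="exp",                           # hypothetical experiment root
        exp_name="my_run",                       # hypothetical run name
        trainer=SimpleNamespace(max_epochs=75),  # would average epoch=65..74 checkpoints
    )
    # Requires exp/my_run/epoch=65.ckpt ... epoch=74.ckpt to exist on disk.
    print(ensemble(example_args))                # prints exp/my_run/model_avg_10.pth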