Skip to content

Commit 88f8170

Browse files
committed
fix bugs
1 parent c8b0846 commit 88f8170

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tools/publish_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import argparse
22
import subprocess
3+
from collections import OrderedDict
34

45
import torch
56

6-
77
def parse_args():
88
parser = argparse.ArgumentParser(
99
description='Process a checkpoint to be published')
@@ -18,6 +18,13 @@ def process_checkpoint(in_file, out_file):
1818
# remove optimizer for smaller file size
1919
if 'optimizer' in checkpoint:
2020
del checkpoint['optimizer']
21+
if 'state_dict' in checkpoint:
22+
in_state_dict = checkpoint.pop('state_dict')
23+
out_state_dict = OrderedDict()
24+
for key, val in in_state_dict.items():
25+
key = key.replace('backbone.','')
26+
out_state_dict[key] = val
27+
checkpoint['state_dict'] = out_state_dict
2128
# if it is necessary to remove some sensitive data in checkpoint['meta'],
2229
# add the code here.
2330
torch.save(checkpoint, out_file)

0 commit comments

Comments
 (0)