Skip to content
This repository has been archived by the owner on May 19, 2022. It is now read-only.

Commit

Permalink
solve #1
Browse files Browse the repository at this point in the history
  • Loading branch information
Separius authored Oct 20, 2020
1 parent c6ede67 commit f4c2f41
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
parser = argparse.ArgumentParser(description='SimCLR converter')
parser.add_argument('tf_path', type=str, help='path of the input tensorflow file (ex: model.ckpt-250228)')
parser.add_argument('--ema', action='store_true')
parser.add_argument('--supervised', action='store_true')
args = parser.parse_args()


Expand Down Expand Up @@ -91,6 +92,10 @@ def main():
assert model.fc.bias.shape == b.shape
model.fc.bias.data = b

if args.supervised:
save_location = f'r{depth}_{width}x_sk{1 if sk_ratio != 0 else 0}{"_ema" if use_ema_model else ""}.pth'
torch.save({'resnet': model.state_dict(), 'head': head.state_dict()}, save_location)
return
sd = {}
for v in contrastive_vars:
sd[v] = ckpt_reader.get_tensor(v)
Expand Down

0 comments on commit f4c2f41

Please sign in to comment.