Skip to content

Commit

Permalink
Update prune_model_ckpt script.
Browse files (browse the repository at this point in the history)
  • Loading branch information
kohjingyu committed Jul 7, 2023
1 parent 7653cf5 commit 7e9a9bc
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions scripts/prune_model_ckpt.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
"""Prunes model weights to keep just the necessary trained weights.
Example usage:
python gill/prune_model_ckpt.py runs/gill_exp
python scripts/prune_model_ckpt.py runs/gill_exp
"""

import json
Expand All
@@ -19,19 +19,18 @@
with open(os.path.join(model_dir, 'model_args.json'), 'rb') as f:
model_args = json.load(f)

assert model_args['share_ret_gen']
assert model_args['num_gen_tokens'] == model_args['num_ret_tokens']

del checkpoint['epoch']
del checkpoint['best_acc1']
del checkpoint['optimizer']
del checkpoint['scheduler']

with open(os.path.join(model_dir, 'pretrained_ckpt.pth.tar'), 'wb') as f:
for k, v in checkpoint['state_dict'].items():
checkpoint['state_dict'][k.replace('module.', '')] = v.detach().clone()
state_dict = {}
for k, v in checkpoint['state_dict'].items():
state_dict[k.replace('module.', '')] = v.detach().clone()

finetuned_tokens = checkpoint['state_dict']['model.input_embeddings.weight'][-model_args['num_gen_tokens']:, :].detach().clone()
checkpoint['state_dict']['model.input_embeddings.weight'] = finetuned_tokens
checkpoint['state_dict'] = state_dict
finetuned_tokens = checkpoint['state_dict']['model.input_embeddings.weight'][-model_args['num_gen_tokens']:, :].detach().clone()
checkpoint['state_dict']['model.input_embeddings.weight'] = finetuned_tokens

with open(os.path.join(model_dir, 'pretrained_ckpt.pth.tar'), 'wb') as f:
torch.save(checkpoint, f)

0 comments on commit 7e9a9bc

Please sign in to comment.