-
Notifications
You must be signed in to change notification settings - Fork 2
/
save_weights.py
36 lines (29 loc) · 1019 Bytes
/
save_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import click
import torch
import pytorch_lightning as pl
from scipy.fft import dst
from src.lit_module import TextClassificationModule, TextClassificationStudentModule
@click.command()
@click.argument("ckpt_name")
@click.argument("dst_name")
@click.argument("type")
def save_weights_model(ckpt_name: str, dst_name: str, type: str):
    """Re-save a Lightning checkpoint as a slim weights + hyper-parameters file.

    Loads ``ckpt_name`` with the module class selected by ``type`` and writes
    a dict holding only the module's ``state_dict``, its hyper-parameters and
    the Lightning version to ``dst_name``. It then reloads ``dst_name`` through
    the same class as a sanity check.

    Args:
        ckpt_name: Path of the source checkpoint.
        dst_name: Path the slim checkpoint is written to (overwritten by
            ``torch.save`` if it already exists).
        type: ``"student"`` selects ``TextClassificationStudentModule``; any
            other value selects ``TextClassificationModule``. The name shadows
            the builtin but is kept because click derives the Python
            parameter name from the argument declaration.

    Raises:
        Whatever ``load_from_checkpoint`` raises if either the source file or
        the freshly written file cannot be loaded into the selected class.
    """
    click.echo(f"Save {ckpt_name} to {dst_name} (type={type})")

    # Resolve the module class once so that the save and the reload check
    # below are guaranteed to use the same class.
    module_cls = (
        TextClassificationStudentModule if type == "student" else TextClassificationModule
    )

    model = module_cls.load_from_checkpoint(ckpt_name)
    torch.save(
        {
            "state_dict": model.state_dict(),
            # Lightning's load_from_checkpoint uses the hyper-parameters stored
            # under this key as constructor arguments, which keeps the slim
            # file loadable on its own.
            pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY: dict(model.hparams),
            # Trainer-written checkpoints carry this key; stamp it so any
            # version-migration logic on load has a version to read.
            "pytorch-lightning_version": pl.__version__,
        },
        dst_name,
    )

    # Sanity check: rebuild the module from the file just written. Under
    # Lightning's default strict loading, missing/unexpected keys or shape
    # mismatches raise here. The reloaded module itself is not needed.
    module_cls.load_from_checkpoint(dst_name)
if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches to the command.
    save_weights_model()