Skip to content

Commit d5ab1d2

Browse files
committed
save model every 25000 steps
1 parent 38682cd commit d5ab1d2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tranception_pytorch/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from tranception_pytorch import Tranception
1313
from tranception_pytorch.data import MaskedProteinDataset
1414

15-
1615
def seed_everything(seed):
1716
torch.manual_seed(seed)
1817
torch.cuda.manual_seed(seed)
@@ -45,7 +44,7 @@ def main():
4544

4645
parser = argparse.ArgumentParser()
4746
parser.add_argument('--input', '-i', required=True)
48-
parser.add_argument('--output', '-o', required=True)
47+
parser.add_argument('--output', '-o', help='Output prefix.', required=True)
4948
parser.add_argument('--batch-size', type=int, default=1024) # Taken from Table 8.
5049
parser.add_argument('--gradient-accumulation-steps', type=int, default=1)
5150
parser.add_argument('--annealing-steps', type=int, default=10_000) # Taken from Appendix B.3.
@@ -136,6 +135,10 @@ def main():
136135
})
137136
running_loss = []
138137

138+
if (cnt // args.gradient_accumulation_steps) % 25000 == 0:
139+
idx = cnt // args.gradient_accumulation_steps
140+
torch.save(model.state_dict(), f'{args.output}_{idx}.pt')
141+
139142
cnt += 1
140143

141144
if __name__ == '__main__':

0 commit comments

Comments
 (0)