Skip to content

Commit 0804c81

Browse files
committed
update model saving logic
1 parent d5ab1d2 commit 0804c81

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
spliceai_train_code/
22

33
data/
4+
ckpts/
45
note/.ipynb_checkpoints/
56

67
*.egg-info/

tranception_pytorch/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def main():
136136
running_loss = []
137137

138138
if (cnt // args.gradient_accumulation_steps) % 25000 == 0:
139+
if not os.path.exists(os.path.dirname(args.output)):
140+
os.makedirs(os.path.dirname(args.output))
141+
139142
idx = cnt // args.gradient_accumulation_steps
140143
torch.save(model.state_dict(), f'{args.output}_{idx}.pt')
141144

0 commit comments

Comments
 (0)