File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 12
12
from tranception_pytorch import Tranception
13
13
from tranception_pytorch .data import MaskedProteinDataset
14
14
15
-
16
15
def seed_everything (seed ):
17
16
torch .manual_seed (seed )
18
17
torch .cuda .manual_seed (seed )
@@ -45,7 +44,7 @@ def main():
45
44
46
45
parser = argparse .ArgumentParser ()
47
46
parser .add_argument ('--input' , '-i' , required = True )
48
- parser .add_argument ('--output' , '-o' , required = True )
47
+ parser .add_argument ('--output' , '-o' , help = 'Output prefix.' , required = True )
49
48
parser .add_argument ('--batch-size' , type = int , default = 1024 ) # Taken from Table 8.
50
49
parser .add_argument ('--gradient-accumulation-steps' , type = int , default = 1 )
51
50
parser .add_argument ('--annealing-steps' , type = int , default = 10_000 ) # Taken from Appendix B.3.
@@ -136,6 +135,10 @@ def main():
136
135
})
137
136
running_loss = []
138
137
138
+ if (cnt // args .gradient_accumulation_steps ) % 25000 == 0 :
139
+ idx = cnt // args .gradient_accumulation_steps
140
+ torch .save (model .state_dict (), f'{ args .output } _{ idx } .pt' )
141
+
139
142
cnt += 1
140
143
141
144
if __name__ == '__main__' :
You can’t perform that action at this time.
0 commit comments