@@ -41,7 +41,7 @@ def train(
41
41
epoch_size = train_sampler .itersize (config ['batch_size_s' ] * config ['backprop_every' ], config ['batch_size_h' ])
42
42
collator = Collator (pad_value = - 1 , device = device , allow_self_loops = config ['allow_self_loops' ])
43
43
44
- model = Trainer (config ['model_config' ]).to (device ( device ) ).to (dtype )
44
+ model = Trainer (config ['model_config' ]).to (device ).to (dtype )
45
45
optimizer = AdamW (params = model .parameters (), lr = 1 , weight_decay = 1e-02 )
46
46
schedule = make_schedule (warmup_steps = config ['warmup_epochs' ] * epoch_size ,
47
47
warmdown_steps = config ['warmdown_epochs' ] * epoch_size ,
@@ -86,19 +86,17 @@ def parse_args():
86
86
default = '../data/model.pt' )
87
87
parser .add_argument ('--log_path' , type = str , help = 'Where to log results' ,
88
88
default = '../data/log.txt' )
89
- parser .add_argument ('--use_half' , type = bool , help = 'Whether to use half-precision floats' ,
90
- default = False )
91
89
return parser .parse_args ()
92
90
93
91
94
92
if __name__ == '__main__' :
95
93
args = parse_args ()
96
- train_cfg = json .load (open (args .config_path , 'r' ))
94
+ train_cfg : TrainCfg = json .load (open (args .config_path , 'r' ))
97
95
train (
98
96
config = train_cfg ,
99
97
data_path = args .data_path ,
100
98
store_path = args .store_path ,
101
99
log_path = args .log_path ,
102
100
device = 'cuda' ,
103
- dtype = torch .half if args . use_half else torch .float
101
+ dtype = torch .half if train_cfg [ 'half_precision' ] else torch .float
104
102
)
0 commit comments