from dataset.dataset import Im2LatexDataset
from models import get_model
+ from utils import *


def train(args):
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(args)
    device = args.device
-   args.pad_token_id = dataloader.pad_token_id
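+   # make sure the checkpoint directory exists before training starts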
+   os.makedirs(args.model_path, exist_ok=True)
+
    model = get_model(args)
    encoder, decoder = model.encoder, model.decoder
    opt = optim.Adam(model.parameters(), args.lr)
@@ -36,10 +38,18 @@ def train(args):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()
            dset.set_description('Loss: %.4f' % loss.item())
-           if i % 15 == 0:
-               print(''.join(dataloader.tokenizer.decode(decoder.generate(torch.LongTensor([dataloader.bos_token_id]).to(
-                   device), args.max_seq_len, eos_token=dataloader.eos_token_id, context=encoded[:1])[:-1]).split(' ')).replace('Ġ', ' ').strip())
-               print(dataloader.pairs[dataloader.i][0][0])
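+           # log the training loss and, every sample_freq steps, a decoded example to W&B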
+           if args.wandb:
+               wandb.log({'train/loss': loss.item()})
+           if i % args.sample_freq == 0:
+               pred = ''.join(dataloader.tokenizer.decode(decoder.generate(torch.LongTensor([dataloader.bos_token_id]).to(
+                   device), args.max_seq_len, eos_token=dataloader.eos_token_id, context=encoded[:1])[:-1]).split(' ')).replace('Ġ', ' ').strip()
+               truth = dataloader.pairs[dataloader.i][0][0]
+               if args.wandb:
+                   table = wandb.Table(columns=["Truth", "Prediction"])
+                   table.add_data(truth, pred)
+                   wandb.log({"test/examples": table})
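+       # checkpoint the model every save_freq epochs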
+       if (e + 1) % args.save_freq == 0:
+           torch.save(model.state_dict(), os.path.join(args.model_path, '%s_e%02d' % (args.name, e + 1)))


if __name__ == '__main__':
@@ -54,6 +64,12 @@ def train(args):
    with parsed_args.config as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = Munch(params)
+   args.wandb = not parsed_args.debug and not args.debug
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
    args.device = torch.device('cuda' if torch.cuda.is_available() and not parsed_args.no_cuda else 'cpu')
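+   # seed all RNGs for reproducibility, then set up the W&B run (reusing the stored id when resuming)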
+   seed_everything(args.seed)
+   if args.wandb:
+       if not parsed_args.resume:
+           args.id = wandb.util.generate_id()
+       wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
    train(args)
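
The diff calls a seed_everything helper pulled in by "from utils import *", which is not shown here. A minimal sketch of what such a helper typically does, assuming only the Python, NumPy and PyTorch RNGs need seeding (the project's actual utils.py may differ):

    import random
    import numpy as np
    import torch

    def seed_everything(seed: int):
        # seed every RNG the training loop can touch so runs are repeatable
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)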