Skip to content

Commit 9b781c0

Browse files
committed
wandb support
1 parent baf75bc commit 9b781c0

File tree

5 files changed

+51
-12
lines changed

5 files changed

+51
-12
lines changed

models.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
77
from einops import rearrange, repeat
88

9-
109
class ViTransformerWrapper(nn.Module):
1110
def __init__(
1211
self,
@@ -90,6 +89,9 @@ def get_model(args):
9089
heads=args.heads,
9190
cross_attend=True
9291
)),
93-
pad_value=args.pad_token_id
92+
pad_value=args.pad_token
9493
).to(args.device)
94+
if args.wandb:
95+
import wandb
96+
wandb.watch((encoder, decoder))
9597
return Model(encoder, decoder, args)

settings/default.yaml

+12-5
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@
22
data: "dataset/data/dataset.pkl"
33
output_path: "outputs"
44
model_path: "checkpoints"
5+
save_freq: 5 # save every nth epoch
56
name: "pix2tex"
67

7-
# Token ids
8-
pad_token: 0
9-
bos_token: 1
10-
eos_token: 2
11-
128
# Training parameters
139
epochs: 10
1410
batchsize: 8
@@ -29,3 +25,14 @@ num_layers: 4
2925
heads: 8
3026
num_tokens: 8000
3127
max_seq_len: 512
28+
29+
# Other
30+
seed: 42
31+
id: null
32+
sample_freq: 50
33+
debug: True
34+
35+
# Token ids
36+
pad_token: 0
37+
bos_token: 1
38+
eos_token: 2

train.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
from dataset.dataset import Im2LatexDataset
1616
from models import get_model
17+
from utils import *
1718

1819

1920
def train(args):
2021
dataloader = Im2LatexDataset().load(args.data)
2122
dataloader.update(args)
2223
device = args.device
23-
args.pad_token_id = dataloader.pad_token_id
24+
os.makedirs(args.model_path, exist_ok=True)
25+
2426
model = get_model(args)
2527
encoder, decoder = model.encoder, model.decoder
2628
opt = optim.Adam(model.parameters(), args.lr)
@@ -36,10 +38,18 @@ def train(args):
3638
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
3739
opt.step()
3840
dset.set_description('Loss: %.4f' % loss.item())
39-
if i % 15 == 0:
40-
print(''.join(dataloader.tokenizer.decode(decoder.generate(torch.LongTensor([dataloader.bos_token_id]).to(
41-
device), args.max_seq_len, eos_token=dataloader.eos_token_id, context=encoded[:1])[:-1]).split(' ')).replace('Ġ', ' ').strip())
42-
print(dataloader.pairs[dataloader.i][0][0])
41+
if args.wandb:
42+
wandb.log({'train/loss': loss.item()})
43+
if i % args.sample_freq == 0:
44+
pred = ''.join(dataloader.tokenizer.decode(decoder.generate(torch.LongTensor([dataloader.bos_token_id]).to(
45+
device), args.max_seq_len, eos_token=dataloader.eos_token_id, context=encoded[:1])[:-1]).split(' ')).replace('Ġ', ' ').strip()
46+
truth = dataloader.pairs[dataloader.i][0][0]
47+
if args.wandb:
48+
table = wandb.Table(columns=["Truth", "Prediction"])
49+
table.add_data(truth, pred)
50+
wandb.log({"test/examples": table})
51+
if (e+1) % args.save_freq == 0:
52+
torch.save(model.state_dict(), os.path.join(args.model_path, '%s_e%02d' % (args.name, e+1)))
4353

4454

4555
if __name__ == '__main__':
@@ -54,6 +64,12 @@ def train(args):
5464
with parsed_args.config as f:
5565
params = yaml.load(f, Loader=yaml.FullLoader)
5666
args = Munch(params)
67+
args.wandb = not parsed_args.debug and not args.debug
5768
logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
5869
args.device = torch.device('cuda' if torch.cuda.is_available() and not parsed_args.no_cuda else 'cpu')
70+
seed_everything(args.seed)
71+
if args.wandb:
72+
if not parsed_args.resume:
73+
args.id = wandb.util.generate_id()
74+
wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
5975
train(args)

utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from utils.utils import *

utils/utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
def seed_everything(seed: int):
    """Seed every RNG the project uses (random, numpy, torch) for reproducible runs.

    Also pins PYTHONHASHSEED and switches cuDNN into deterministic mode.

    Args:
        seed: the integer seed applied to all random number generators.
    """
    import random
    import os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seed every visible GPU, not only the current device (no-op on CPU-only hosts)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune kernel choices, which is non-deterministic;
    # it must be False for reproducibility (the original set it to True, defeating
    # the deterministic=True line above).
    torch.backends.cudnn.benchmark = False

0 commit comments

Comments
 (0)