Skip to content

Commit 8b9d314

Browse files
committed
Stop evaluation after n batches
+ work around a weird import error (import albumentations before torch)
1 parent cc5b642 commit 8b9d314

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

dataset/dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import albumentations as alb
2+
from albumentations.pytorch import ToTensorV2
13
import torch
24
import torch.nn as nn
35
import torch.nn.functional as F
@@ -15,8 +17,6 @@
1517
import cv2
1618
from transformers import PreTrainedTokenizerFast
1719
from tqdm.auto import tqdm
18-
import albumentations as alb
19-
from albumentations.pytorch import ToTensorV2
2020

2121

2222
train_transform = alb.Compose(

eval.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ def detokenize(tokens, tokenizer):
2727

2828

2929
@torch.no_grad()
30-
def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, name: str = 'test'):
30+
def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
3131
"""evaluates the model. Returns bleu score on the dataset
3232
3333
Args:
3434
model (torch.nn.Module): the model
3535
dataset (Im2LatexDataset): test dataset
3636
args (Munch): arguments
37+
num_batches (int, optional): How many batches to evaluate on. Defaults to None (all batches).
3738
name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
3839
3940
Returns:
@@ -53,6 +54,8 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, name
5354
truth = detokenize(seq['input_ids'], dataset.tokenizer)
5455
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
5556
pbar.set_description('BLEU: %.2f' % (np.mean(bleus)))
57+
if num_batches is not None and i >= num_batches:
58+
break
5659
bleu_score = np.mean(bleus)
5760
# samples
5861
pred = token2str(dec, dataset.tokenizer)

settings/default.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ name: "pix2tex"
1212
epochs: 10
1313
batchsize: 8
1414
testbatchsize: 20
15+
valbatches: 100
1516

1617
# Optimizer configurations
1718
optimizer: "Adam"

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def train(args):
5151
if args.wandb:
5252
wandb.log({'train/loss': loss.item()})
5353
if (i+1) % args.sample_freq == 0:
54-
evaluate(model, valdataloader, args, name='val')
54+
evaluate(model, valdataloader, args, num_batches=args.valbatches, name='val')
5555
if (e+1) % args.save_freq == 0:
5656
torch.save(model.state_dict(), os.path.join(args.out_path, '%s_e%02d.pth' % (args.name, e+1)))
5757
yaml.dump(dict(args), open(os.path.join(args.out_path, 'config.yaml'), 'w+'))

0 commit comments

Comments
 (0)