Skip to content

Commit f3d59a1

Browse files
committed
validate with custom temperature
1 parent 0ae4998 commit f3d59a1

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

eval.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tqdm.auto import tqdm
1313
import wandb
1414

15-
from models import get_model
15+
from models import get_model, Model
1616
from utils import *
1717

1818

@@ -29,7 +29,7 @@ def detokenize(tokens, tokenizer):
2929

3030

3131
@torch.no_grad()
32-
def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
32+
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
3333
"""evaluates the model. Returns bleu score on the dataset
3434
3535
Args:
@@ -53,7 +53,7 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
5353
encoded = model.encoder(im.to(device))
5454
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
5555
dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
56-
eos_token=args.pad_token, context=encoded)
56+
eos_token=args.pad_token, context=encoded, temperature=(args.temperature if 'temperature' in args else 1))
5757
pred = detokenize(dec, dataset.tokenizer)
5858
truth = detokenize(seq['input_ids'], dataset.tokenizer)
5959
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
@@ -84,13 +84,15 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
8484
parser.add_argument('--no-cuda', action='store_true', help='Use CPU')
8585
parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size')
8686
parser.add_argument('--debug', action='store_true', help='DEBUG')
87+
parser.add_argument('-t', '--temperature', type=float, default=.333, help='sampling temperature')
8788

8889
parsed_args = parser.parse_args()
8990
with parsed_args.config as f:
9091
params = yaml.load(f, Loader=yaml.FullLoader)
9192
args = parse_args(Munch(params))
9293
args.testbatchsize = parsed_args.batchsize
9394
args.wandb = False
95+
args.temperature = parsed_args.temperature
9496
logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
9597
seed_everything(args.seed if 'seed' in args else 42)
9698
model = get_model(args)

models.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,6 @@
1010
from einops import rearrange, repeat
1111

1212

13-
class Model(nn.Module):
14-
def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args, temp: float = .333):
15-
super().__init__()
16-
self.encoder = encoder
17-
self.decoder = decoder
18-
self.bos_token = args.bos_token
19-
self.eos_token = args.eos_token
20-
self.max_seq_len = args.max_seq_len
21-
self.temperature = temp
22-
23-
@torch.no_grad()
24-
def forward(self, x: torch.Tensor):
25-
device = x.device
26-
encoded = self.encoder(x.to(device))
27-
dec = self.decoder.generate(torch.LongTensor([self.bos_token]*len(x))[:, None].to(device), self.max_seq_len,
28-
eos_token=self.eos_token, context=encoded, temperature=self.temperature)
29-
return dec
30-
3113

3214
class CustomARWrapper(AutoregressiveWrapper):
3315
def __init__(self, *args, **kwargs):
@@ -106,6 +88,25 @@ def forward_features(self, x):
10688
return x
10789

10890

91+
class Model(nn.Module):
92+
def __init__(self, encoder: CustomVisionTransformer, decoder: CustomARWrapper, args, temp: float = .333):
93+
super().__init__()
94+
self.encoder = encoder
95+
self.decoder = decoder
96+
self.bos_token = args.bos_token
97+
self.eos_token = args.eos_token
98+
self.max_seq_len = args.max_seq_len
99+
self.temperature = temp
100+
101+
@torch.no_grad()
102+
def forward(self, x: torch.Tensor):
103+
device = x.device
104+
encoded = self.encoder(x.to(device))
105+
dec = self.decoder.generate(torch.LongTensor([self.bos_token]*len(x))[:, None].to(device), self.max_seq_len,
106+
eos_token=self.eos_token, context=encoded, temperature=self.temperature)
107+
return dec
108+
109+
109110
def get_model(args):
110111
backbone = ResNetV2(
111112
layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,

settings/default.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ name: "pix2tex"
1111
# Training parameters
1212
epochs: 10
1313
batchsize: 8
14+
15+
# Testing parameters
1416
testbatchsize: 20
1517
valbatches: 100
18+
temperature: 0.2
1619

1720
# Optimizer configurations
1821
optimizer: "Adam"

0 commit comments

Comments
 (0)