12
12
from tqdm .auto import tqdm
13
13
import wandb
14
14
15
- from models import get_model
15
+ from models import get_model , Model
16
16
from utils import *
17
17
18
18
@@ -29,7 +29,7 @@ def detokenize(tokens, tokenizer):
29
29
30
30
31
31
@torch .no_grad ()
32
- def evaluate (model : torch . nn . Module , dataset : Im2LatexDataset , args : Munch , num_batches : int = None , name : str = 'test' ):
32
+ def evaluate (model : Model , dataset : Im2LatexDataset , args : Munch , num_batches : int = None , name : str = 'test' ):
33
33
"""evaluates the model. Returns bleu score on the dataset
34
34
35
35
Args:
@@ -53,7 +53,7 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
53
53
encoded = model .encoder (im .to (device ))
54
54
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
55
55
dec = model .decoder .generate (torch .LongTensor ([args .bos_token ]* len (encoded ))[:, None ].to (device ), args .max_seq_len ,
56
- eos_token = args .pad_token , context = encoded )
56
+ eos_token = args .pad_token , context = encoded , temperature = ( args . temperature if 'temperature' in args else 1 ) )
57
57
pred = detokenize (dec , dataset .tokenizer )
58
58
truth = detokenize (seq ['input_ids' ], dataset .tokenizer )
59
59
bleus .append (metrics .bleu_score (pred , [alternatives (x ) for x in truth ]))
@@ -84,13 +84,15 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
84
84
parser .add_argument ('--no-cuda' , action = 'store_true' , help = 'Use CPU' )
85
85
parser .add_argument ('-b' , '--batchsize' , type = int , default = 10 , help = 'Batch size' )
86
86
parser .add_argument ('--debug' , action = 'store_true' , help = 'DEBUG' )
87
+ parser .add_argument ('-t' , '--temperature' , type = float , default = .333 , help = 'sampling temperature' )
87
88
88
89
parsed_args = parser .parse_args ()
89
90
with parsed_args .config as f :
90
91
params = yaml .load (f , Loader = yaml .FullLoader )
91
92
args = parse_args (Munch (params ))
92
93
args .testbatchsize = parsed_args .batchsize
93
94
args .wandb = False
95
+ args .temperature = parsed_args .temperature
94
96
logging .getLogger ().setLevel (logging .DEBUG if parsed_args .debug else logging .WARNING )
95
97
seed_everything (args .seed if 'seed' in args else 42 )
96
98
model = get_model (args )
0 commit comments