-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
46 lines (38 loc) · 1.53 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Maximum number of caption token ids prep_strings keeps when it truncates.
CAPTION_LENGTH = 25
# Default prompt prefix used by prep_strings when no retrieved captions are
# given; postprocess_preds also splits decoded predictions on this string.
SIMPLE_PREFIX = "This image shows "
def prep_strings(text, tokenizer, template=None, retrieved_caps=None,
                 k=None, is_test=False, max_length=None):
    """Tokenize a prompt prefix (and, for training, a caption) into id lists.

    The prefix is ``template`` with its ``'||'`` placeholder replaced by the
    first ``k`` retrieved captions (joined by blank lines and followed by a
    period) when ``retrieved_caps`` is given, and ``SIMPLE_PREFIX`` otherwise.

    Args:
        text: Caption string to encode. Ignored when ``is_test`` is true.
        tokenizer: Object providing ``encode(str)``,
            ``encode(str, add_special_tokens=False)``, ``pad_token_id`` and
            ``eos_token_id``.
        template: Prompt containing ``'||'``; required with ``retrieved_caps``.
        retrieved_caps: Optional sequence of caption strings for the prompt.
        k: Number of retrieved captions to use (``None`` uses all of them).
        is_test: If true, return only the prefix ids, with no padding.
        max_length: Length both lists are padded to; required for training.

    Returns:
        ``input_ids`` (the encoded prefix) when ``is_test`` is true, otherwise
        the tuple ``(input_ids, label_ids)``, both padded up to ``max_length``.

    Raises:
        ValueError: If ``retrieved_caps`` is given without ``template``, or if
            ``max_length`` is missing while ``is_test`` is false.
    """
    # Fail early with a clear message instead of an AttributeError/TypeError
    # on None further down.
    if retrieved_caps is not None and template is None:
        raise ValueError("template is required when retrieved_caps is given")
    if not is_test and max_length is None:
        raise ValueError("max_length is required when is_test is False")

    if retrieved_caps is not None:
        infix = '\n\n'.join(retrieved_caps[:k]) + '.'
        prefix = template.replace('||', infix)
    else:
        prefix = SIMPLE_PREFIX
    prefix_ids = tokenizer.encode(prefix)

    if is_test:
        # Inference only needs the prompt: skip encoding `text` and building
        # labels, neither of which is returned in this mode.
        return prefix_ids

    # Training: the caption is capped at CAPTION_LENGTH token ids.
    text_ids = tokenizer.encode(text, add_special_tokens=False)[:CAPTION_LENGTH]
    input_ids = prefix_ids + text_ids

    # we ignore the prefix (minus one as the first subtoken in the prefix is
    # not predicted); -100 is presumably the loss's ignore index -- confirm
    # against the training code.
    label_ids = ([-100] * (len(prefix_ids) - 1)
                 + text_ids
                 + [tokenizer.eos_token_id])

    # A non-positive count multiplies to an empty list, so a sequence already
    # at or beyond max_length is returned as-is (it is NOT truncated).
    input_ids += [tokenizer.pad_token_id] * (max_length - len(input_ids))
    label_ids += [-100] * (max_length - len(label_ids))
    return input_ids, label_ids
def postprocess_preds(pred, tokenizer):
    """Strip the prompt and special-token text from a decoded prediction.

    Args:
        pred: Decoded model output string.
        tokenizer: Object exposing ``pad_token``, ``bos_token`` and
            ``eos_token`` as strings (any of them may be ``None`` or empty).

    Returns:
        The text after the last occurrence of ``SIMPLE_PREFIX`` (the whole
        string if the prefix is absent), with every pad token removed and one
        leading BOS token and one trailing EOS token stripped if present.
    """
    # Keep only what follows the last occurrence of the prompt prefix.
    pred = pred.split(SIMPLE_PREFIX)[-1]

    pad = tokenizer.pad_token
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token

    # Unset (None) or empty special tokens are skipped: str methods raise
    # TypeError on None, and an empty EOS would make `pred[:-0]` erase the
    # whole prediction.
    if pad:
        pred = pred.replace(pad, '')
    if bos and pred.startswith(bos):
        pred = pred[len(bos):]
    if eos and pred.endswith(eos):
        pred = pred[:-len(eos)]
    return pred