Skip to content

Commit

Permalink
Update vqa_gen_dataset.py, set max_tgt_length
Browse files Browse the repository at this point in the history
  • Loading branch information
logicwong authored Aug 31, 2023
1 parent b358cd6 commit a36b91c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion data/mm_data/vqa_gen_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __getitem__(self, index):
ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
answer = max(ref_dict, key=ref_dict.get)
conf = torch.tensor([ref_dict[answer]])
tgt_item = self.encode_text(" {}".format(answer))
tgt_item = self.encode_text(" {}".format(answer), length=self.max_tgt_length)

if self.add_object and predict_objects is not None:
predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])
Expand Down

0 comments on commit a36b91c

Please sign in to comment.