Skip to content

Commit

Permalink
Update model_vqa_loader.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
haotian-liu committed Nov 5, 2023
1 parent b7a4865 commit 5900c2a
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions llava/eval/model_vqa_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def eval_model(args):
idx = line["question_id"]
cur_prompt = line["text"]

stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
input_ids = input_ids.to(device='cuda', non_blocking=True)

with torch.inference_mode():
Expand All @@ -103,7 +102,7 @@ def eval_model(args):
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
max_new_tokens=128,
max_new_tokens=args.max_new_tokens,
use_cache=True)

input_token_len = input_ids.shape[1]
Expand All @@ -112,9 +111,6 @@ def eval_model(args):
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()

ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
Expand All @@ -139,6 +135,7 @@ def eval_model(args):
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
parser.add_argument("--max_new_tokens", type=int, default=128)
args = parser.parse_args()

eval_model(args)

0 comments on commit 5900c2a

Please sign in to comment.