Fix bugs in inference

yanwei-li committed Mar 31, 2024
1 parent 6319148 commit c62b905
Showing 6 changed files with 10 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -34,7 +34,7 @@ We provide some selected examples in this section. More examples can be found in
 ## Install
 Please follow the instructions below to install the required packages.
 
-NOTE: If you want to use Mini-Gemini-2B, please ensure to install the latest version Transformers (>=4.28.0).
+NOTE: If you want to use Mini-Gemini-2B, please ensure to install the latest version Transformers (>=4.38.0).
 
 1. Clone this repository
 ```bash
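
The version bump above is load-bearing: Gemma support only landed in Transformers 4.38.0, so Mini-Gemini-2B cannot load on older releases. A minimal sketch of a guard one could run before loading the 2B checkpoint (the `packaging` dependency and the error message are my own choices, not part of this commit):

```python
# Hedged sketch: fail fast if the installed Transformers is too old for
# Mini-Gemini-2B (Gemma architectures require >= 4.38.0).
from packaging import version
import transformers

MIN_TRANSFORMERS = "4.38.0"  # matches the note added to README.md

if version.parse(transformers.__version__) < version.parse(MIN_TRANSFORMERS):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old for Mini-Gemini-2B; "
        f"please upgrade to >= {MIN_TRANSFORMERS}."
    )
```
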
13 changes: 1 addition & 12 deletions minigemini/conversation.py
@@ -397,17 +397,6 @@ def dict(self):
sep2="<|endoftext|>",
)

# conv_mistral_instruct = Conversation(
# system="",
# roles=("USER", "ASSISTANT"),
# version="llama_v2",
# messages=(),
# offset=0,
# sep_style=SeparatorStyle.LLAMA_2,
# sep="",
# sep2="</s>",
# )

conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
@@ -426,7 +415,7 @@ def dict(self):
 messages=(),
 offset=0,
 sep_style=SeparatorStyle.GEMMA,
-sep="<bos>",
+sep="",
 sep2="<eos>",
 )

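For the Gemma template, dropping `sep="<bos>"` avoids a doubled beginning-of-sequence token: Gemma's tokenizer already prepends `<bos>` when it encodes the prompt, so also emitting it from the conversation separator would duplicate it. A quick way to confirm the tokenizer's behaviour (the checkpoint name is only an example, and downloading it requires access to the gated Gemma weights):

```python
# Sketch: show that the Gemma tokenizer inserts <bos> on its own,
# so the conversation template should not add another one via `sep`.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")  # example checkpoint
ids = tok("Hello there").input_ids
print(tok.convert_ids_to_tokens(ids))  # expected to start with '<bos>'
```
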
2 changes: 1 addition & 1 deletion minigemini/eval/model_vqa_loader.py
@@ -106,7 +106,7 @@ def __len__(self):


 # DataLoader
-def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
+def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=0):
 assert batch_size == 1, "batch_size must be 1"
 dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
 data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
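
Setting `num_workers=0` makes the evaluation loader fetch batches in the main process rather than in forked worker processes, which sidesteps the hangs and CUDA re-initialisation problems multiprocessing workers can hit, at the cost of some loading throughput. A self-contained illustration of the single-process mode (the toy dataset is purely for demonstration):

```python
# Sketch: num_workers=0 keeps data loading in the main process;
# num_workers>0 spawns worker processes, which can deadlock or fail
# when the dataset holds non-picklable state (e.g. CUDA handles).
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx

loader = DataLoader(ToyDataset(4), batch_size=1, num_workers=0, shuffle=False)
for batch in loader:
    print(batch)  # tensor([0]), tensor([1]), ...
```
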
14 changes: 5 additions & 9 deletions minigemini/serve/cli.py
@@ -144,9 +144,6 @@ def main(args):
 image_tensor = None
 image_tensor_aux = []
 
-# debug use
-# import ipdb; ipdb.set_trace()
-
 
 while True:
 try:
@@ -191,9 +188,6 @@ def main(args):
 prompt = final_str
 
 input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-keywords = [stop_str]
-stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 with torch.inference_mode():
@@ -204,11 +198,13 @@ def main(args):
 do_sample=True if args.temperature > 0 else False,
 temperature=args.temperature,
 max_new_tokens=args.max_new_tokens,
+bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
+eos_token_id=tokenizer.eos_token_id, # End of sequence token
+pad_token_id=tokenizer.pad_token_id, # Pad token
 streamer=streamer,
-use_cache=True,
-stopping_criteria=[stopping_criteria])
+use_cache=True)
 
-outputs = tokenizer.decode(output_ids[0]).strip()
+outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 conv.messages[-1][-1] = outputs
 
 if args.gen and '<h>' in outputs and '</h>' in outputs:
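
The CLI now relies on the model's native special tokens instead of `KeywordsStoppingCriteria`: passing `eos_token_id` tells `generate` when to stop, and `batch_decode(..., skip_special_tokens=True)` strips the terminator from the output. A runnable sketch of the same pattern, using a tiny public checkpoint as a stand-in for the Mini-Gemini model (the checkpoint choice and prompt are illustrative only, not part of this commit):

```python
# Sketch of the decoding path introduced by this commit, on a tiny stand-in model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "sshleifer/tiny-gpt2"  # illustrative stand-in, not a Mini-Gemini checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

input_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=16,
        eos_token_id=tokenizer.eos_token_id,  # stop on the native EOS token
        pad_token_id=tokenizer.eos_token_id,  # tiny-gpt2 defines no pad token
        use_cache=True,
    )

# skip_special_tokens removes <eos>/<pad>, so no manual stop-string trimming is needed
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```
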
1 change: 0 additions & 1 deletion minigemini/serve/model_worker.py
@@ -261,7 +261,6 @@ def generate_stream(self, params):
 do_sample = True if temperature > 0.001 else False
 
 input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
-keywords = [stop_str]
 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30)
 
 max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
3 changes: 2 additions & 1 deletion minigemini/train/train.py
@@ -549,7 +549,8 @@ def preprocess_gemma(
 assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
 
 # Mask targets
-sep = "<start_of_turn>" + conv.sep + conv.roles[1] + "\n"
+# sep = "<start_of_turn>" + conv.sep + conv.roles[1] + "\n"
+sep = "<start_of_turn>" + conv.roles[1] + "\n"
 for conversation, target in zip(conversations, targets):
 total_len = int(target.ne(tokenizer.pad_token_id).sum())
 
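In `preprocess_gemma`, `sep` marks where the model's turn starts so that everything before it (the user turn) can be masked out of the loss. With `conv.sep` now empty, concatenating it is pointless, and under the old `<bos>` value the marker would read `<start_of_turn><bos>model\n`, a string that never appears in the rendered conversation, so the mask would silently break. A small stand-alone sketch of the splitting logic (all strings are illustrative, following the Gemma chat format):

```python
# Sketch: locate the assistant turn with the corrected separator so the
# instruction part can be masked. All strings below are illustrative.
roles = ("user", "model")
conversation = (
    "<start_of_turn>user\nWhat is in the picture?<end_of_turn>\n"
    "<start_of_turn>model\nA cat sitting on a sofa.<end_of_turn>\n"
)

sep = "<start_of_turn>" + roles[1] + "\n"  # "<start_of_turn>model\n"
instruction, found, answer = conversation.partition(sep)

print(repr(instruction))  # part whose labels would be masked out of the loss
print(repr(answer))       # part that stays supervised
print(found == sep)       # True: the marker actually occurs in the conversation
```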
