Account for bug in model definition w.r.t. stopping tokens
pseudotensor committed Aug 18, 2024
1 parent cad6cc3 commit 9375285
Showing 8 changed files with 40 additions and 10 deletions.
22 changes: 22 additions & 0 deletions src/enums.py
@@ -443,6 +443,28 @@ def is_vision_model(base_model, all_visible_models=[], visible_vision_models=[])
         base_model in ['liuhaotian/llava-v1.6-34b', 'liuhaotian/llava-v1.6-vicuna-13b']


+# https://github.com/vllm-project/vllm/issues/7628
+# https://github.com/vllm-project/vllm/blob/ce143353c622318a9abf113bebee1cfebc274e0f/examples/offline_inference_vision_language.py#L126-L148
+def extra_stop_token_ids(base_model, tokenizer=None, as_ids=False):
+    if base_model is None:
+        return []
+    assert tokenizer is not None or not as_ids
+    if base_model in ["OpenGVLab/InternVL-Chat-V1-5", "OpenGVLab/Mini-InternVL-Chat-2B-V1-5",
+                      "OpenGVLab/Mini-InternVL-Chat-4B-V1-5", "OpenGVLab/InternVL-Chat-V1-5-Int8",
+                      "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-4B",
+                      "OpenGVLab/InternVL2-8B", "OpenGVLab/InternVL2-26B", "OpenGVLab/InternVL2-40",
+                      "OpenGVLab/InternVL2-Llama3-76B",
+                      "OpenGVLab/InternVL2-40B-AWQ", "OpenGVLab/InternVL2-26B-AWQ", "OpenGVLab/InternVL2-8B-AWQ",
+                      "OpenGVLab/InternVL2-2B-AWQ",
+                      "OpenGVLab/InternVL2-Llama3-76B-AWQ"]:
+        words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
+        if as_ids:
+            return tokenizer.encode(words, add_special_tokens=False)
+        else:
+            return words
+    return []
+
+
 def tokens_per_image(base_model):
     if not is_vision_model(base_model):
         return 0
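
For context, a minimal usage sketch of the new helper (not part of the commit; the model name and the vLLM SamplingParams wiring are illustrative assumptions):

    # Sketch only: shows both return modes of extra_stop_token_ids().
    from transformers import AutoTokenizer
    from enums import extra_stop_token_ids

    base_model = "OpenGVLab/InternVL2-8B"  # assumed example from the list above
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    # As plain words: the mode used by Prompter and get_stopping in this commit.
    stop_words = extra_stop_token_ids(base_model, as_ids=False)

    # As token ids, e.g. for vLLM's SamplingParams(stop_token_ids=...).
    stop_ids = extra_stop_token_ids(base_model, tokenizer=tokenizer, as_ids=True)
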
2 changes: 1 addition & 1 deletion src/eval.py
@@ -324,7 +324,7 @@ def get_response(*args, exi=0):
     if score_with_prompt:
         data_point = dict(instruction=instruction, input=iinput, context=context)
         prompter = Prompter(prompt_type, prompt_dict,
-                            debug=debug, stream_output=stream_output)
+                            debug=debug, stream_output=stream_output, base_model=base_model)
         prompt = prompter.generate_prompt(data_point, context_from_history=False, image_file=image_file)
     else:
         # just raw input and output
4 changes: 2 additions & 2 deletions src/gen.py
@@ -3069,7 +3069,7 @@ def evaluate(

     # get prompter
     prompter = Prompter(prompt_type, prompt_dict, debug=debug, stream_output=stream_output,
-                        system_prompt=system_prompt, tokenizer=tokenizer)
+                        system_prompt=system_prompt, tokenizer=tokenizer, base_model=base_model)

     # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
     assert langchain_mode in langchain_modes, "Invalid langchain_mode %s not in %s" % (langchain_mode, langchain_modes)
@@ -5339,7 +5339,7 @@ def get_limited_prompt(instruction,
     debug = False
     stream_output = False  # doesn't matter
     prompter = Prompter(prompt_type, prompt_dict, debug=debug, stream_output=stream_output,
-                        system_prompt=system_prompt, tokenizer=tokenizer)
+                        system_prompt=system_prompt, tokenizer=tokenizer, base_model=base_model)
     if prompt_type != generate_prompt_type:
         # override just this attribute, keep system_prompt etc. from original prompt_type
         prompter.prompt_type = generate_prompt_type
4 changes: 2 additions & 2 deletions src/gpt_langchain.py
@@ -3788,7 +3788,7 @@ def get_llm(use_openai_model=False,
         prompt_type = prompter.prompt_type
     else:
         prompter = Prompter(prompt_type, prompt_dict, debug=False, stream_output=stream_output,
-                            tokenizer=tokenizer)
+                            tokenizer=tokenizer, base_model=model_name)
         pass  # assume inputted prompt_type is correct
     from gpt4all_llm import get_llm_gpt4all
     llm = get_llm_gpt4all(model_name=model_name,
@@ -7096,7 +7096,7 @@ def _run_qa_db(query=None,
     # get prompter
     chat = True  # FIXME?
     prompter = Prompter(prompt_type, prompt_dict, debug=False, stream_output=stream_output,
-                        system_prompt=system_prompt, tokenizer=tokenizer)
+                        system_prompt=system_prompt, tokenizer=tokenizer, base_model=model_name)

     scores = []
     chain = None
3 changes: 2 additions & 1 deletion src/h2oai_pipeline.py
@@ -69,7 +69,8 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
            assert self.prompter.prompt_type is not None
        else:
            self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug,
-                                    stream_output=stream_output, tokenizer=self.tokenizer)
+                                    stream_output=stream_output, tokenizer=self.tokenizer,
+                                    base_model=base_model)
        self.human = self.prompter.humanstr
        self.bot = self.prompter.botstr
        self.can_stop = True
10 changes: 8 additions & 2 deletions src/prompter.py
@@ -5,7 +5,9 @@
 import traceback

 # also supports imports from this file from other files
-from enums import PromptType, gpt_token_mapping, anthropic_mapping, google_mapping, mistralai_mapping, groq_mapping, noop_prompt_type, unknown_prompt_type, user_prompt_for_fake_system_prompt0, template_prompt_type, empty_prompt_type  # keep single line
+from enums import PromptType, gpt_token_mapping, anthropic_mapping, google_mapping, mistralai_mapping, groq_mapping, \
+    noop_prompt_type, unknown_prompt_type, user_prompt_for_fake_system_prompt0, template_prompt_type, empty_prompt_type, \
+    extra_stop_token_ids  # keep single line
 from prompter_utils import get_use_chat_template
 from utils import FakeTokenizer
 from stopping import update_terminate_responses
@@ -1647,7 +1649,8 @@ def inject_chatsep(prompt_type, prompt, chat_sep=None):

 class Prompter(object):
     def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, repeat_penalty=False,
-                 allowed_repeat_line_length=10, system_prompt=None, tokenizer=None, verbose=False):
+                 allowed_repeat_line_length=10, system_prompt=None, tokenizer=None,
+                 base_model=None, verbose=False):
         self.prompt_type = prompt_type
         self.prompt_dict = prompt_dict
         self.debug = debug
@@ -1671,6 +1674,9 @@ def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, r
         self.use_chat_template = get_use_chat_template(tokenizer, prompt_type=prompt_type)
         self.terminate_response = update_terminate_responses(self.terminate_response,
                                                              tokenizer=tokenizer)
+        self.base_model = base_model
+        self.terminate_response.extend(extra_stop_token_ids(self.base_model, as_ids=False))
+
         self.pre_response = self.PreResponse
         self.verbose = verbose

3 changes: 2 additions & 1 deletion src/stopping.py
@@ -4,7 +4,7 @@
 import torch
 from transformers import StoppingCriteria, StoppingCriteriaList, GenerationConfig

-from enums import PromptType, t5_type
+from enums import PromptType, t5_type, extra_stop_token_ids


 def update_terminate_responses(terminate_response, tokenizer=None, trust_remote_code=True):
@@ -171,6 +171,7 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
         handle_newlines += [False] * len(stop)

     stop_words = update_terminate_responses(stop_words, tokenizer=tokenizer)
+    stop_words.extend(extra_stop_token_ids(base_model, as_ids=False))

     # get stop tokens
     stop_words_ids = [
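
The vLLM example linked from the new enums.py helper derives stop token ids per word; a hedged sketch of that pattern (assumed from the linked example, not code in this commit, with the model name as an illustrative choice):

    # Sketch only: per-word conversion to ids, mirroring the referenced vLLM example.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL2-8B", trust_remote_code=True)  # assumed model
    stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(w) for w in stop_words]
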
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
-__version__ = "f40cd1fc088dae5af824658bdc95f2e48f8d533f"
+__version__ = "cad6cc3418d2318717e43c16a918c2c125c31fd0"
