
Commit

Merge pull request #1534 from h2oai/fix_llava_tokens
Fix llava token counting
pseudotensor authored Apr 8, 2024
2 parents dd8b90c + 3041e29 commit 9b5e26e
Showing 5 changed files with 180 additions and 67 deletions.
40 changes: 5 additions & 35 deletions src/gen.py
@@ -76,7 +76,8 @@
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, \
have_langchain, set_openai, cuda_vis_check, H2O_Fire, lg_to_gr, str_to_list, str_to_dict, get_token_count, \
url_alive, have_wavio, have_soundfile, have_deepspeed, have_doctr, have_librosa, have_TTS, have_flash_attention_2, \
have_diffusers, sanitize_filename, get_gradio_tmp, get_is_gradio_h2oai, is_gradio_version4, get_json, is_json_vllm
have_diffusers, sanitize_filename, get_gradio_tmp, get_is_gradio_h2oai, is_gradio_version4, get_json, is_json_vllm, \
get_docs_tokens

start_faulthandler()
import_matplotlib()
@@ -4665,7 +4666,10 @@ def evaluate(
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
min_max_new_tokens=min_max_new_tokens,
tokenizer=tokenizer,
client=gr_client if not regenerate_gradio_clients else None,
verbose=verbose,
)
if not stream_output and img_file == 1:
from src.vision.utils_vision import get_llava_response
@@ -6295,40 +6299,6 @@ def count_overhead_tokens(tokenizer, doing_grounding=False):
return 0


def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
"""
max_input_tokens: Over all LLM calls, upper limit of total token count,
or single LLM call if want to know what docs fit into single call
"""
if text_context_list is None or len(text_context_list) == 0:
return 0, None, 0
assert max_input_tokens is not None, "Must set max_input_tokens"
tokens = [get_token_count(x + docs_joiner_default, tokenizer) for x in text_context_list]
tokens_cumsum = np.cumsum(tokens)
where_res = np.where(tokens_cumsum < max_input_tokens)[0]
# if below condition fails, then keep top_k_docs=-1 and trigger special handling next
if where_res.shape[0] > 0:
top_k_docs = 1 + where_res[-1]
one_doc_size = None
num_doc_tokens = tokens_cumsum[top_k_docs - 1] # by index
else:
# if here, means 0 and just do best with 1 doc
top_k_docs = 1
text_context_list = text_context_list[:top_k_docs]
# critical protection
from src.h2oai_pipeline import H2OTextGenerationPipeline
doc_content = text_context_list[0]
doc_content, new_tokens0 = H2OTextGenerationPipeline.limit_prompt(doc_content,
tokenizer,
max_prompt_length=max_input_tokens)
text_context_list[0] = doc_content
one_doc_size = len(doc_content)
num_doc_tokens = get_token_count(doc_content + docs_joiner_default, tokenizer)
print("Unexpected large chunks and can't add to context, will add 1 anyways. Tokens %s -> %s" % (
tokens[0], new_tokens0), flush=True)
return top_k_docs, one_doc_size, num_doc_tokens


def get_on_disk_models(llamacpp_path, use_auth_token, trust_remote_code):
print("Begin auto-detect HF cache text generation models", flush=True)
from huggingface_hub import scan_cache_dir
8 changes: 6 additions & 2 deletions src/gpt_langchain.py
@@ -64,7 +64,7 @@
get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list, get_size, \
get_test_name_core, download_simple, have_fiftyone, have_librosa, return_good_url, n_gpus_global, \
get_accordion_named, hyde_titles, have_cv2, FullSet, create_relative_symlink, split_list, get_gradio_tmp, \
merge_dict
merge_dict, get_docs_tokens
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent, docs_joiner_default, \
@@ -74,7 +74,7 @@
geminiimage_num_max, claude3image_num_max, gpt4image_num_max, llava_num_max, summary_prefix, extract_prefix, \
noop_prompt_type, unknown_prompt_type, template_prompt_type
from evaluate_params import gen_hyper, gen_hyper0
from gen import SEED, get_limited_prompt, get_docs_tokens, get_relaxed_max_new_tokens, get_model_retry, gradio_to_llm, \
from gen import SEED, get_limited_prompt, get_relaxed_max_new_tokens, get_model_retry, gradio_to_llm, \
get_client_from_inference_server
from prompter import non_hf_types, PromptType, Prompter, get_vllm_extra_dict, system_docqa, system_summary, \
is_vision_model, is_gradio_vision_model, is_json_model
@@ -1040,7 +1040,9 @@ def setup_call(self, prompt):
top_k=self.top_k,
penalty_alpha=self.penalty_alpha,
max_new_tokens=self.max_new_tokens,
min_max_new_tokens=self.min_max_new_tokens,
min_new_tokens=self.min_new_tokens,
verbose=self.verbose,
)
# NOTE: Don't handle self.context
if not self.add_chat_history_to_context:
@@ -1064,6 +1066,7 @@
max_new_tokens=client_kwargs['max_new_tokens'],
client=self.client,
max_time=self.max_time,
tokenizer=self.tokenizer,
)
max_new_tokens = get_relaxed_max_new_tokens(prompt, tokenizer=self.tokenizer,
max_new_tokens=self.max_new_tokens,
@@ -3800,6 +3803,7 @@ def file_to_doc(file,
prompt=llava_prompt,
allow_prompt_auto=True,
max_time=60, # not too much time for docQA
verbose=verbose,
)
metadata = dict(source=file, date=str(datetime.now()), input_type='LLaVa')
docs1c = [Document(page_content=res, metadata=metadata)]
81 changes: 78 additions & 3 deletions src/utils.py
@@ -39,7 +39,7 @@
from joblib import Parallel
from tqdm.auto import tqdm

from src.enums import split_google, invalid_json_str
from src.enums import split_google, invalid_json_str, docs_joiner_default
from src.utils_procs import reulimit

reulimit()
@@ -1211,7 +1211,7 @@ def __init__(self, model_max_length=2048,
is_llama_cpp=False):
if model_max_length is None:
assert not (
is_openai or is_anthropic or is_google), "Should have set model_max_length for OpenAI or Anthropic or Google"
is_openai or is_anthropic or is_google), "Should have set model_max_length for OpenAI or Anthropic or Google"
model_max_length = 2048
self.is_openai = is_openai
self.is_anthropic = is_anthropic
@@ -2162,7 +2162,9 @@ def has_starting_code_block(text):


pattern_extract_codeblock = re.compile(r"```[a-zA-Z]*\s*(.*?)(```|$)", re.DOTALL)
#pattern_extract_codeblock = re.compile(r"```(?:[a-zA-Z]*\s*)(.*?)(?=```|$)", re.DOTALL)


# pattern_extract_codeblock = re.compile(r"```(?:[a-zA-Z]*\s*)(.*?)(?=```|$)", re.DOTALL)


def extract_code_block_content(stream_content):
Expand Down Expand Up @@ -2234,3 +2236,76 @@ def get_vllm_version(openai_client, inference_server, verbose=False):
if verbose:
print(f"Failed to retrieve version, status code: {response.status_code}")
return vllm_version


def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None, docs_joiner=docs_joiner_default):
"""
max_input_tokens: Over all LLM calls, upper limit of total token count,
or single LLM call if want to know what docs fit into single call
"""
if text_context_list is None or len(text_context_list) == 0:
return 0, None, 0
assert max_input_tokens is not None, "Must set max_input_tokens"
tokens = [get_token_count(x + docs_joiner, tokenizer) for x in text_context_list]
tokens_cumsum = np.cumsum(tokens)
where_res = np.where(tokens_cumsum < max_input_tokens)[0]
# if below condition fails, then keep top_k_docs=-1 and trigger special handling next
if where_res.shape[0] > 0:
top_k_docs = 1 + where_res[-1]
one_doc_size = None
num_doc_tokens = tokens_cumsum[top_k_docs - 1] # by index
else:
# if here, means 0 and just do best with 1 doc
top_k_docs = 1
text_context_list = text_context_list[:top_k_docs]
# critical protection
from src.h2oai_pipeline import H2OTextGenerationPipeline
doc_content = text_context_list[0]
doc_content, new_tokens0 = H2OTextGenerationPipeline.limit_prompt(doc_content,
tokenizer,
max_prompt_length=max_input_tokens)
text_context_list[0] = doc_content
one_doc_size = len(doc_content)
num_doc_tokens = get_token_count(doc_content + docs_joiner, tokenizer)
print("Unexpected large chunks and can't add to context, will add 1 anyways. Tokens %s -> %s" % (
tokens[0], new_tokens0), flush=True)
return top_k_docs, one_doc_size, num_doc_tokens
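
For reference, a minimal usage sketch of the relocated helper (not part of the commit; it assumes any Hugging Face tokenizer works with get_token_count, and "gpt2" plus the chunk strings are made-up examples):

    from transformers import AutoTokenizer
    from src.utils import get_docs_tokens

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    chunks = ["first retrieved chunk ...", "second retrieved chunk ...", "third retrieved chunk ..."]
    # top_k_docs: how many chunks fit under max_input_tokens (joiner overhead included);
    # one_doc_size is only non-None when even the first chunk had to be truncated.
    top_k_docs, one_doc_size, num_doc_tokens = get_docs_tokens(
        tokenizer, text_context_list=chunks, max_input_tokens=256)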


def get_limited_text(hard_limit_tokens, text, tokenizer, verbose=False):
if tokenizer is None:
return text[:4 * hard_limit_tokens]

low = 0
high = len(text)
best_guess = text # Initialize best_guess to ensure it's defined
ntokens0 = len(tokenizer.tokenize(best_guess))
ntokens = None

max_steps = 5
steps = 0
while low <= high:
mid = low + (high - low) // 2 # Calculate midpoint for current search interval
# Estimate a trial cut of the text based on mid
trial_text_length = max(int(mid * 4), 1) # Using mid * 4 as an estimation, ensuring at least 1 character
trial_text = text[-trial_text_length:] # Take text from the end, based on trial_text_length

# Tokenize the trial text and count tokens
ntokens = len(tokenizer.tokenize(trial_text))

if ntokens > hard_limit_tokens:
# If the trial exceeds the token limit, reduce 'high' to exclude the current trial length
high = mid - 1
else:
# If the trial does not exceed the token limit, update 'best_guess' and increase 'low'
best_guess = trial_text # Update best_guess with the current trial_text
low = mid + 1 # Attempt to include more text in the next trial
if steps >= max_steps:
break
steps += 1

# 'best_guess' now contains the text that best fits the criteria
if verbose:
print("steps: %s ntokens0: %s/%s text0: %s ntokens: %s/%s text: %s" % (
steps, ntokens0, hard_limit_tokens, len(text), ntokens, hard_limit_tokens, len(best_guess)))
return best_guess
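
A similar sketch for the new get_limited_text helper, which binary-searches a character cut so that roughly the last hard_limit_tokens tokens of a long prompt are kept (again assuming a Hugging Face tokenizer; the budget of 128 and the sample text are arbitrary). Because the search stops after max_steps probes, the result is best-effort rather than an exact fit:

    from transformers import AutoTokenizer
    from src.utils import get_limited_text

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    long_text = "some long OCR or LLaVA output " * 1000
    # keeps the tail of the text; verbose=True prints before/after token counts
    trimmed = get_limited_text(128, long_text, tokenizer, verbose=True)
    print(len(tokenizer.tokenize(long_text)), "->", len(tokenizer.tokenize(trimmed)))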