
Allow llava_model to be openai model with specific model name for now
pseudotensor committed Oct 30, 2024
1 parent dce9960 commit e77f54a
Showing 2 changed files with 100 additions and 36 deletions.
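
Summary, as read from the diff below (not from official docs): when llava_model carries an 'openai:' prefix, file_to_doc routes document-image understanding to the OpenAI chat completions API under the named model instead of a LLaVa server, and the token accounting now also logs o1 reasoning-token usage. A minimal sketch of the new dispatch, where the model name is an assumed example:

    llava_model = "openai:gpt-4o"  # assumed example; any chat model name may follow the prefix
    if llava_model.startswith("openai:"):
        model_name = llava_model.split("openai:")[1]  # "gpt-4o" goes to the OpenAI chat API
    else:
        model_name = llava_model  # unchanged path: query the configured LLaVa server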
134 changes: 99 additions & 35 deletions src/gpt_langchain.py
@@ -2464,12 +2464,15 @@ def get_num_tokens(self, text: str) -> int:
 
 class GenerateStream:
     def get_count_output_tokens(self, ret):
-        if hasattr(ret, 'llm_output') and 'model_name' in ret.llm_output and ret.llm_output['model_name'] in ['o1-mini', 'o1-preview']:
+        if hasattr(ret, 'llm_output') and 'model_name' in ret.llm_output and ret.llm_output['model_name'] in ['o1-mini',
+                                                                                                               'o1-preview']:
             usage_dict = ret.llm_output['token_usage']
             if 'completion_tokens' in usage_dict:
                 self.count_output_tokens += usage_dict['completion_tokens']
-            if 'completion_tokens_details' in usage_dict and 'reasoning_tokens' in usage_dict['completion_tokens_details']:
-                print("reasoning tokens for %s: %s" % (ret.llm_output['model_name'], usage_dict['completion_tokens_details']['reasoning_tokens']))
+            if 'completion_tokens_details' in usage_dict and 'reasoning_tokens' in usage_dict[
+                'completion_tokens_details']:
+                print("reasoning tokens for %s: %s" % (
+                    ret.llm_output['model_name'], usage_dict['completion_tokens_details']['reasoning_tokens']))
 
     def generate_prompt(
             self,
@@ -2608,12 +2611,15 @@ async def _agenerate(
 
 class GenerateNormal:
     def get_count_output_tokens(self, ret):
-        if hasattr(ret, 'llm_output') and 'model_name' in ret.llm_output and ret.llm_output['model_name'] in ['o1-mini', 'o1-preview']:
+        if hasattr(ret, 'llm_output') and 'model_name' in ret.llm_output and ret.llm_output['model_name'] in ['o1-mini',
+                                                                                                               'o1-preview']:
             usage_dict = ret.llm_output['token_usage']
             if 'completion_tokens' in usage_dict:
                 self.count_output_tokens += usage_dict['completion_tokens']
-            if 'completion_tokens_details' in usage_dict and 'reasoning_tokens' in usage_dict['completion_tokens_details']:
-                print("reasoning tokens for %s: %s" % (ret.llm_output['model_name'], usage_dict['completion_tokens_details']['reasoning_tokens']))
+            if 'completion_tokens_details' in usage_dict and 'reasoning_tokens' in usage_dict[
+                'completion_tokens_details']:
+                print("reasoning tokens for %s: %s" % (
+                    ret.llm_output['model_name'], usage_dict['completion_tokens_details']['reasoning_tokens']))
 
     def generate_prompt(
             self,
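
The two hunks above are the same change applied to GenerateStream and GenerateNormal: accumulate completion tokens and, for o1 models, log the hidden reasoning-token usage that OpenAI reports under completion_tokens_details. A minimal sketch of where those numbers come from when using the raw OpenAI client (assumes an openai>=1.x client recent enough to surface completion_tokens_details; the diff reads the same figures out of langchain's llm_output['token_usage'] dict instead):

    from openai import OpenAI

    client = OpenAI()  # needs OPENAI_API_KEY in the environment
    resp = client.chat.completions.create(
        model="o1-mini",
        messages=[{"role": "user", "content": "What is 17 * 24?"}],
    )
    usage = resp.usage
    print("completion tokens:", usage.completion_tokens)
    details = getattr(usage, "completion_tokens_details", None)  # absent on older servers/models
    if details is not None and details.reasoning_tokens is not None:
        # hidden chain-of-thought is billed inside completion_tokens as reasoning tokens
        print("reasoning tokens:", details.reasoning_tokens)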
@@ -3292,7 +3298,7 @@ def get_llm(use_openai_model=False,
 
     if json_vllm:
         response_format_real = response_format if not (
-            guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
+                guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
         vllm_extra_dict = get_vllm_extra_dict(tokenizer,
                                               stop_sequences=prompter.stop_sequences if prompter else [],
                                               # repetition_penalty=repetition_penalty, # could pass
@@ -3437,7 +3443,8 @@ def get_llm(use_openai_model=False,
         if model_name in ['o1-mini', 'o1-preview']:
             gen_server_kwargs['max_completion_tokens'] = gen_server_kwargs.pop('max_tokens')
             max_reasoning_tokens = int(os.getenv("MAX_REASONING_TOKENS", 25000))
-            gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs['max_completion_tokens'])
+            gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs[
+                'max_completion_tokens'])
             gen_server_kwargs['temperature'] = 1.0
             model_kwargs.pop('presence_penalty', None)
             model_kwargs.pop('n', None)
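
This hunk only rewraps a long line, but the surrounding logic it touches is worth restating: o1 models take max_completion_tokens rather than max_tokens, that budget must also cover hidden reasoning tokens (hence the MAX_REASONING_TOKENS headroom), temperature is pinned to 1.0, and unsupported sampling knobs are dropped. A self-contained restatement of that block (a sketch mirroring the code above, not a separate API):

    import os

    def shape_o1_kwargs(gen_server_kwargs: dict, model_kwargs: dict) -> None:
        # o1 models take max_completion_tokens in place of max_tokens
        gen_server_kwargs['max_completion_tokens'] = gen_server_kwargs.pop('max_tokens')
        # the budget must also cover hidden reasoning tokens, so add generous headroom
        max_reasoning_tokens = int(os.getenv("MAX_REASONING_TOKENS", 25000))
        gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(
            100, gen_server_kwargs['max_completion_tokens'])
        # o1 models only accept the default temperature
        gen_server_kwargs['temperature'] = 1.0
        # sampling knobs the o1 endpoint rejects are dropped rather than sent
        model_kwargs.pop('presence_penalty', None)
        model_kwargs.pop('n', None)

    kwargs = {'max_tokens': 1024}
    shape_o1_kwargs(kwargs, {})
    print(kwargs)  # {'max_completion_tokens': 26024, 'temperature': 1.0}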
@@ -5111,33 +5118,90 @@ def file_to_doc(file,
             print("END: Pix2Struct", flush=True)
         if llava_model and enable_llava and 'vllm' not in llava_model:
             file_llava = fix_image_file(file, do_align=True, do_rotate=True, do_pad=False)
-            # LLaVa
-            if verbose:
-                print("BEGIN: LLaVa", flush=True)
-            try:
-                from vision.utils_vision import get_llava_response
-                res, llava_prompt = get_llava_response(file_llava, llava_model,
-                                                       prompt=llava_prompt,
-                                                       allow_prompt_auto=True,
-                                                       max_time=60,  # not too much time for docQA
-                                                       verbose=verbose,
-                                                       )
-                metadata = dict(source=file, date=str(datetime.now()), input_type='LLaVa')
-                docs1c = [Document(page_content=res, metadata=metadata)]
-                docs1c = [x for x in docs1c if x.page_content]
-                add_meta(docs1c, file, parser='LLaVa: %s' % llava_model, file_as_source=True)
-                # caption didn't set source, so fix-up meta
-                hash_of_file = hash_file(file)
-                [doci.metadata.update(source=file, source_true=file_llava, hashid=hash_of_file,
-                                      llava_prompt=llava_prompt or '') for doci in
-                 docs1c]
-                docs1.extend(docs1c)
-            except BaseException as e0:
-                print("LLaVa: %s: %s" % (str(e0), traceback.print_exception(e0)), flush=True)
-                e = e0
-            handled |= len(docs1) > 0
-            if verbose:
-                print("END: LLaVa", flush=True)
+
+            if llava_model.startswith('openai:'):
+                if verbose:
+                    print("BEGIN: OpenAI docAI", flush=True)
+                try:
+                    from openai import OpenAI
+                    openai_client = OpenAI(base_url=os.getenv('H2OGPT_OPENAI_BASE_URL', 'https://api.openai.com'),
+                                           api_key=os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY'), timeout=60)
+                    if llava_prompt in ['auto', None]:
+                        llava_prompt = "Describe the image and what does the image say?"
+                    from vision.utils_vision import img_to_base64
+                    file_llava_url = img_to_base64(file_llava)
+                    content = [{
+                        'type': 'text',
+                        'text': llava_prompt,
+                    }, {
+                        'type': 'image_url',
+                        'image_url': {
+                            'url':
+                                file_llava_url,
+                        },
+                    }]
+                    messages = [dict(role='system',
+                                     content='You are a keen document vision model that can understand complex images and text and respond to queries or convert text inside images to text.'),
+                                dict(role='user', content=content)]
+                    stream_output = False
+                    gen_server_kwargs = dict()
+                    model_name = llava_model.split('openai:')[1]
+                    responses = openai_client.chat.completions.create(
+                        model=model_name,
+                        messages=messages,
+                        stream=stream_output,
+                        **gen_server_kwargs,
+                    )
+                    if responses.choices is None and responses.model_extra:
+                        raise RuntimeError("OpenAI Chat failed: %s" % responses.model_extra)
+                    res = responses.choices[0].message.content
+                    if not res:
+                        raise RuntimeError("OpenAI Chat had no response")
+
+                    metadata = dict(source=file, date=str(datetime.now()), input_type='OpenAI DocAI')
+                    docs1c = [Document(page_content=res, metadata=metadata)]
+                    docs1c = [x for x in docs1c if x.page_content]
+                    add_meta(docs1c, file, parser='LLaVa: %s' % llava_model, file_as_source=True)
+                    # caption didn't set source, so fix-up meta
+                    hash_of_file = hash_file(file)
+                    [doci.metadata.update(source=file, source_true=file_llava, hashid=hash_of_file,
+                                          llava_prompt=llava_prompt or '') for doci in
+                     docs1c]
+                    docs1.extend(docs1c)
+                except BaseException as e0:
+                    print("LLaVa: %s: %s" % (str(e0), traceback.print_exception(e0)), flush=True)
+                    e = e0
+                handled |= len(docs1) > 0
+                if verbose:
+                    print("END: OpenAI docAI", flush=True)
+            else:
+                # LLaVa
+                if verbose:
+                    print("BEGIN: LLaVa", flush=True)
+                try:
+                    from vision.utils_vision import get_llava_response
+                    res, llava_prompt = get_llava_response(file_llava, llava_model,
+                                                           prompt=llava_prompt,
+                                                           allow_prompt_auto=True,
+                                                           max_time=60,  # not too much time for docQA
+                                                           verbose=verbose,
+                                                           )
+                    metadata = dict(source=file, date=str(datetime.now()), input_type='LLaVa')
+                    docs1c = [Document(page_content=res, metadata=metadata)]
+                    docs1c = [x for x in docs1c if x.page_content]
+                    add_meta(docs1c, file, parser='LLaVa: %s' % llava_model, file_as_source=True)
+                    # caption didn't set source, so fix-up meta
+                    hash_of_file = hash_file(file)
+                    [doci.metadata.update(source=file, source_true=file_llava, hashid=hash_of_file,
+                                          llava_prompt=llava_prompt or '') for doci in
+                     docs1c]
+                    docs1.extend(docs1c)
+                except BaseException as e0:
+                    print("LLaVa: %s: %s" % (str(e0), traceback.print_exception(e0)), flush=True)
+                    e = e0
+                handled |= len(docs1) > 0
+                if verbose:
+                    print("END: LLaVa", flush=True)
 
         doc1 = chunk_sources(docs1)
         if len(doc1) == 0:
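
For readers who want to exercise the new path in isolation, here is a minimal standalone sketch of the OpenAI vision call it performs. It assumes an openai>=1.x client and a vision-capable model; img_to_base64 here is a stand-in for the repo's vision.utils_vision.img_to_base64, which evidently returns a data URL usable as an image_url, and the model and file names are assumed examples:

    import base64
    from openai import OpenAI

    def img_to_base64(path: str) -> str:
        # stand-in for vision.utils_vision.img_to_base64: local file -> data URL
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode()
        return "data:image/jpeg;base64,%s" % encoded

    client = OpenAI()  # needs OPENAI_API_KEY in the environment
    response = client.chat.completions.create(
        model="gpt-4o",  # the diff uses whatever name follows the 'openai:' prefix
        messages=[
            {"role": "system", "content": "You are a keen document vision model."},
            {"role": "user", "content": [
                {"type": "text", "text": "Describe the image and what does the image say?"},
                {"type": "image_url", "image_url": {"url": img_to_base64("page.jpg")}},
            ]},
        ],
    )
    print(response.choices[0].message.content)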
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "9e71f30a01ef47e0f9333f5580a55382b4cd15e2"
__version__ = "dce9960977e52cc03ae07115e858bdbe308773ed"
