Skip to content

Commit

Permalink
Handle gemini pro schema and use json mode for flash
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Aug 19, 2024
1 parent a957ffe commit 8aca99f
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2995,7 +2995,9 @@ def evaluate(
not is_empty(
base_model) and base_model in openai_supports_functiontools + openai_supports_parallel_functiontools or \
not is_empty(inference_server) and \
inference_server.startswith('anthropic')
inference_server.startswith('anthropic') or \
not is_empty(inference_server) and \
inference_server.startswith('google') and base_model == 'gemini-1.5-pro-latest'

if supports_schema:
# for vLLM or claude-3, support schema if given
Expand Down
14 changes: 13 additions & 1 deletion src/gpt_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
from h2o_serpapi import H2OSerpAPIWrapper
from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta, \
load_general_summarization_chain, H2OHuggingFaceHubEmbeddings, make_sources_file, select_docs_with_score, \
split_merge_docs
split_merge_docs, convert_to_genai_schema

# to check imports
# find ./src -name '*.py' | xargs awk '{ if (sub(/\\$/, "")) printf "%s ", $0; else print; }' | grep 'from langchain\.' | sed 's/^[ \t]*//' > go.py
Expand Down Expand Up @@ -2484,6 +2484,14 @@ def _generate(
have_tool = True
kwargs.pop('stream', None)
kwargs.pop('streaming', None)
if hasattr(self, 'safety_settings'):
# google
kwargs['safety_settings'] = self.safety_settings
if hasattr(self, 'response_format') and self.response_format == 'json_object':
kwargs['generation_config'] = dict(response_mime_type='application/json')
if self.guided_json and isinstance(self.guided_json, dict) and self.model == 'models/gemini-1.5-pro-latest':
# flash doesn't support, has to be part of prompt
kwargs['generation_config'].update(dict(response_schema=convert_to_genai_schema(self.guided_json)))
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
Expand Down Expand Up @@ -2754,6 +2762,8 @@ class H2OChatGoogle(ChatAGenerateStreamFirst, GenerateStream, ExtraChat, ChatGoo
count_input_tokens: Any = 0
count_output_tokens: Any = 0
prompter: Any = None
response_format: str = 'text'
guided_json: dict | None = {}


class H2OChatMistralAI(ChatAGenerateStreamFirst, GenerateStream2, ExtraChat, ChatMistralAI):
Expand Down Expand Up @@ -3365,6 +3375,8 @@ def get_llm(use_openai_model=False,
verbose=verbose,
tokenizer=tokenizer,
safety_settings=safety_settings,
response_format=response_format if response_format == 'json_object' else 'text',
guided_json=guided_json if response_format == 'json_object' else None,
prompter=prompter,
**kwargs_extra
)
Expand Down
41 changes: 41 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,3 +3014,44 @@ def is_empty(obj):
if hasattr(obj, '__len__'):
return len(obj) == 0
return False


from typing import Any, Dict, List, Union
from typing_extensions import TypedDict

def create_typed_dict(schema: Dict[str, Any], name: str = "Schema") -> type:
    """Build a ``TypedDict`` class from a JSON-schema ``object`` definition.

    Supports string/integer/number/boolean scalars, arrays of scalars or of
    objects, and nested objects (converted recursively).  Properties not
    listed in ``required`` are annotated ``Union[T, None]``, and the class is
    created with ``total=False`` so they may be omitted (the functional
    ``TypedDict`` form has no per-key ``NotRequired`` before Python 3.11).

    :param schema: JSON-schema dict with ``properties``/``required`` keys.
    :param name: class name for the generated TypedDict; nested types get
        derived names such as ``f"{name}Item"``.
    :return: a freshly created TypedDict subclass.
    """
    properties = schema.get("properties", {})
    required = set(schema.get("required", []))

    # JSON-schema scalar type -> Python type
    scalar_types = {"string": str, "integer": int, "number": float, "boolean": bool}

    fields: Dict[str, Union[type, Any]] = {}
    # total=True only when every property is required
    total = len(required) == len(properties)

    for prop, details in properties.items():
        prop_type = details.get("type")
        if prop_type in scalar_types:
            field_type = scalar_types[prop_type]
        elif prop_type == "array":
            items = details.get("items", {})
            item_type = items.get("type")
            if item_type in scalar_types:
                # generalized: arrays of any scalar type, not only strings
                field_type = List[scalar_types[item_type]]
            elif item_type == "object":
                field_type = List[create_typed_dict(items, f"{name}Item")]
            else:
                field_type = List[Any]
        elif prop_type == "object":
            field_type = create_typed_dict(details, f"{name}{prop.capitalize()}")
        else:
            # unknown or missing "type": accept anything
            field_type = Any

        if prop in required:
            fields[prop] = field_type
        else:
            # optional keys additionally allow None in their annotation
            fields[prop] = Union[field_type, None]

    return TypedDict(name, fields, total=total)
66 changes: 66 additions & 0 deletions src/utils_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,69 @@ def make_sources_file(langchain_mode, source_files_added):
with open(sources_file, "wt", encoding="utf-8") as f:
f.write(source_files_added)
return sources_file


from typing import Dict, Any, List
from google.ai.generativelanguage_v1beta.types import Schema, Type


def convert_to_genai_schema(json_schema: Dict[str, Any], name: str = "Root") -> Schema:
    """Convert a JSON-schema ``object`` dict into a google-genai ``Schema``.

    Used to pass ``guided_json`` as a ``response_schema`` to Gemini models.

    :param json_schema: JSON-schema dict with ``properties``/``required``.
    :param name: logical name for this object, used to derive names for
        nested objects (``f"{name}{Prop}"``) — kept for parity with
        ``create_typed_dict`` even though genai schemas are anonymous.
    :return: a ``google.ai.generativelanguage`` ``Schema`` of type OBJECT.
    """
    properties = json_schema.get("properties", {})
    required = json_schema.get("required", [])

    schema_properties = {}
    for prop, details in properties.items():
        prop_schema = _to_genai_property(details, f"{name}{prop.capitalize()}")
        if "nullable" in details:
            prop_schema.nullable = details["nullable"]
        schema_properties[prop] = prop_schema

    return Schema(
        type_=Type.OBJECT,
        properties=schema_properties,
        required=required,
        description=json_schema.get("description", ""),
    )


def _to_genai_property(details: Dict[str, Any], name: str) -> Schema:
    """Convert one JSON-schema property definition into a genai ``Schema``."""
    prop_type = details.get("type")
    description = details.get("description", "")

    if prop_type == "string":
        if "enum" in details:
            return Schema(type_=Type.STRING, enum=details["enum"], description=description)
        return Schema(type_=Type.STRING, description=description)
    if prop_type == "integer":
        return Schema(type_=Type.INTEGER, format_=details.get("format", "int32"), description=description)
    if prop_type == "number":
        return Schema(type_=Type.NUMBER, format_=details.get("format", "float"), description=description)
    if prop_type == "boolean":
        return Schema(type_=Type.BOOLEAN, description=description)
    if prop_type == "array":
        items = details.get("items", {})
        # BUG FIX: the item schema used to be wrapped in a synthetic object
        # {"properties": {"item": items}}, so an array of strings became an
        # array of objects with a single "item" field, and a nested object's
        # own "required" list was dropped.  Convert the element schema
        # directly instead.
        if items.get("type") == "object":
            item_schema = convert_to_genai_schema(items, f"{name}Item")
        else:
            item_schema = _to_genai_property(items, f"{name}Item")
        return Schema(type_=Type.ARRAY, items=item_schema, description=description)
    if prop_type == "object":
        return convert_to_genai_schema(details, name)
    # unknown or missing "type"
    return Schema(type_=Type.UNSPECIFIED)
2 changes: 1 addition & 1 deletion src/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4761d64812136f2fc5f370b4740c9e33197c40f7"
__version__ = "a957ffe1f7a3370c6bd9043864934f2d60ea49b9"
2 changes: 1 addition & 1 deletion tests/test_client_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6323,7 +6323,7 @@ def test_get_image_file():

def get_test_server_client(base_model):
inference_server = os.getenv('TEST_SERVER', 'https://gpt.h2o.ai')
# inference_server = 'http://localhost:7860'
inference_server = 'http://localhost:7860'
# inference_server = 'http://localhost:7863'

if inference_server == 'https://gpt.h2o.ai':
Expand Down
90 changes: 89 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import tempfile
import time
import typing
import uuid

import pytest
Expand All @@ -14,7 +15,7 @@
from src.vision.utils_vision import process_file_list
from src.utils import get_list_or_str, read_popen_pipes, get_token_count, reverse_ucurve_list, undo_reverse_ucurve_list, \
is_uuid4, has_starting_code_block, extract_code_block_content, looks_like_json, get_json, is_full_git_hash, \
deduplicate_names, handle_json, check_input_type, start_faulthandler, remove, get_gradio_depth
deduplicate_names, handle_json, check_input_type, start_faulthandler, remove, get_gradio_depth, create_typed_dict
from src.enums import invalid_json_str, user_prompt_for_fake_system_prompt0
from src.prompter import apply_chat_template
import subprocess as sp
Expand Down Expand Up @@ -1219,3 +1220,90 @@ def test_depth():

example_list = [[[[[1]]]], [[[[2]]]], [[[3]]], [[4]], [5], []]
assert get_gradio_depth(example_list) == 4


def test_schema_to_typed():
    """create_typed_dict should turn a JSON schema into a usable TypedDict
    with annotations matching the schema's declared types."""
    TEST_SCHEMA = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": {
                "type": "array",
                "items": {"type": "string", "maxLength": 10},
                "minItems": 3
            },
            "work history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {"type": "string"},
                        "duration": {"type": "string"},
                        "position": {"type": "string"}
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work history"]
    }

    Schema = create_typed_dict(TEST_SCHEMA)

    # the generated annotations should mirror the schema types
    assert Schema.__annotations__["name"] is str
    assert Schema.__annotations__["age"] is int
    assert Schema.__annotations__["skills"] == typing.List[str]
    # all properties are required, so the class must be total
    assert Schema.__total__ is True
    # "work history" must be a list of some (nested TypedDict) element type
    assert typing.get_origin(Schema.__annotations__["work history"]) is list

    # the dict type is usable for a conforming value (runtime check only;
    # TypedDicts do no runtime validation, this just exercises construction)
    person: Schema = {
        "name": "John Doe",
        "age": 30,
        "skills": ["Python", "TypeScript", "Docker"],
        "work history": [
            {"company": "TechCorp", "position": "Developer", "duration": "2 years"},
            {"company": "DataInc", "position": "Data Scientist"}
        ]
    }
    assert person["name"] == "John Doe"
    assert len(person["work history"]) == 2


def test_genai_schema():
    """convert_to_genai_schema should map a JSON schema into a genai Schema
    preserving properties, required keys, and enum values."""
    TEST_SCHEMA = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": {
                "type": "array",
                "items": {"type": "string", "maxLength": 10},
                "minItems": 3
            },
            "work history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {"type": "string"},
                        "duration": {"type": "string"},
                        "position": {"type": "string"}
                    },
                    "required": ["company", "position"]
                }
            },
            "status": {
                "type": "string",
                "enum": ["active", "inactive", "on leave"]
            }
        },
        "required": ["name", "age", "skills", "work history", "status"]
    }

    from src.utils_langchain import convert_to_genai_schema
    genai_schema = convert_to_genai_schema(TEST_SCHEMA)

    # every declared property must survive the conversion
    assert set(genai_schema.properties) == {"name", "age", "skills", "work history", "status"}
    # required list must be preserved in order
    assert list(genai_schema.required) == ["name", "age", "skills", "work history", "status"]
    # enum values must be carried over verbatim
    assert list(genai_schema.properties["status"].enum) == ["active", "inactive", "on leave"]

    # This schema can be passed to the Gemini API, e.g.:
    # response = model.generate_content(prompt, response_schema=genai_schema)

0 comments on commit 8aca99f

Please sign in to comment.