-
Notifications
You must be signed in to change notification settings - Fork 484
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
290 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Integration with Gemini's API.""" | ||
from enum import EnumMeta | ||
from functools import singledispatchmethod | ||
from types import NoneType | ||
from typing import Optional, Union | ||
|
||
from pydantic import BaseModel | ||
from typing_extensions import _TypedDictMeta # type: ignore | ||
|
||
from outlines.prompts import Vision | ||
from outlines.types import Choice, Json | ||
|
||
__all__ = ["Gemini"] | ||
|
||
|
||
class GeminiBase:
    """Base class for the Gemini clients.

    `GeminiBase` is responsible for preparing the arguments to Gemini's
    `generate_content` method: the input (prompt and possibly image), as well
    as the output type (a JSON schema, a list of JSON objects, or a choice).

    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the `contents` argument to pass to the client.

        This is the fallback used when no handler is registered for the type
        of `model_input`.

        Parameters
        ----------
        model_input
            The input passed by the user.

        Raises
        ------
        NotImplementedError
            Always; only `str` and `Vision` inputs have a registered handler.

        """
        # Report the type of the value the user passed, not the `input` builtin.
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with Gemini. "
            "The only available types are `str` and `Vision`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Generate the `contents` argument to pass to the client when the
        user only passes a prompt.

        """
        return {"contents": [model_input]}

    @format_input.register(Vision)
    def format_vision_input(self, model_input: Vision):
        """Generate the `contents` argument to pass to the client when the
        user passes a prompt and an image.

        """
        return {"contents": [model_input.prompt, model_input.image]}

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Generate the generation-config arguments for an output type.

        This fallback handles `list[Json(...)]`; every other output type that
        has no registered handler is rejected.

        Parameters
        ----------
        output_type
            The output type passed by the user.

        Returns
        -------
        dict
            The `response_mime_type` and `response_schema` entries.

        Raises
        ------
        NotImplementedError
            If `output_type` is not a `list[...]` alias.
        TypeError
            If the list alias does not have exactly one `Json` argument.

        """
        # Plain classes and instances have no `__origin__`; `getattr` keeps
        # them on the NotImplementedError path instead of an AttributeError.
        if getattr(output_type, "__origin__", None) is not list:
            raise NotImplementedError(
                f"The output type {output_type} is not available with Gemini."
            )

        args = getattr(output_type, "__args__", ())
        if len(args) != 1 or not isinstance(args[0], Json):
            raise TypeError(
                "List output types must have exactly one `Json` item type, "
                f"got {output_type}."
            )

        return {
            "response_mime_type": "application/json",
            "response_schema": list[args[0].original_definition],
        }

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type):
        """Return no generation-config arguments when no output type is set."""
        return {}

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type):
        """Generate the generation-config arguments for a `Json` output type.

        Raises
        ------
        NotImplementedError
            If the original definition is neither a Pydantic model class nor
            a `TypedDict` class.

        """
        definition = output_type.original_definition

        # `issubclass` raises TypeError on non-class definitions, so check
        # that the definition is a class before calling it.
        is_pydantic_model = isinstance(definition, type) and issubclass(
            definition, BaseModel
        )
        if is_pydantic_model or isinstance(definition, _TypedDictMeta):
            return {
                "response_mime_type": "application/json",
                "response_schema": definition,
            }

        raise NotImplementedError(
            f"The Json definition type {type(definition)} is not available "
            "with Gemini. The only available types are Pydantic models and "
            "`TypedDict` classes."
        )

    @format_output_type.register(Choice)
    def format_enum_output_type(self, output_type):
        """Generate the generation-config arguments for a `Choice` output
        type.

        """
        return {
            "response_mime_type": "text/x.enum",
            "response_schema": output_type.definition,
        }
|
||
|
||
class Gemini(GeminiBase):
    """Client that sends formatted inputs to a Gemini `GenerativeModel`."""

    def __init__(self, model_name: str, *args, **kwargs):
        """Create the underlying `GenerativeModel` client.

        Parameters
        ----------
        model_name
            The name of the model, passed as the first argument to
            `GenerativeModel`.
        *args, **kwargs
            Forwarded unchanged to `GenerativeModel`.

        """
        # Imported inside the method, so the package is only needed once a
        # client is actually created.
        import google.generativeai as genai

        self.client = genai.GenerativeModel(model_name, *args, **kwargs)

    def generate(
        self,
        model_input: Union[str, Vision],
        output_type: Optional[Union[Json, EnumMeta]] = None,
        **inference_kwargs,
    ):
        """Generate a completion and return its text.

        Parameters
        ----------
        model_input
            The prompt, or the prompt and image, passed by the user.
        output_type
            The requested output type; `None` requests no output constraint.
        **inference_kwargs
            Forwarded unchanged to the client's `generate_content` method.

        Returns
        -------
        The `text` attribute of the completion returned by the client.

        """
        import google.generativeai as genai

        request_contents = self.format_input(model_input)
        output_config = self.format_output_type(output_type)
        response = self.client.generate_content(
            generation_config=genai.GenerationConfig(**output_config),
            **request_contents,
            **inference_kwargs,
        )

        return response.text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import io | ||
import json | ||
from enum import Enum | ||
|
||
import PIL | ||
import pytest | ||
import requests | ||
from pydantic import BaseModel | ||
from typing_extensions import TypedDict | ||
|
||
from outlines.models.gemini import Gemini | ||
from outlines.prompts import Vision | ||
from outlines.types import Choice, Json | ||
|
||
MODEL_NAME = "gemini-1.5-flash-latest" | ||
|
||
|
||
def test_gemini_wrong_init_parameters():
    """An unknown keyword argument to the constructor raises `TypeError`."""
    unknown_kwargs = {"foo": 10}
    with pytest.raises(TypeError, match="got an unexpected"):
        Gemini(MODEL_NAME, **unknown_kwargs)
|
||
|
||
def test_gemini_wrong_inference_parameters():
    """An unknown keyword argument to `generate` raises `TypeError`."""
    # Build the model outside the `raises` block so that only the `generate`
    # call can satisfy the expected TypeError.
    model = Gemini(MODEL_NAME)
    with pytest.raises(TypeError, match="got an unexpected"):
        model.generate("prompt", foo=10)
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_call():
    """A plain text prompt returns a string."""
    prompt = "Respond with one word. Not more."
    answer = Gemini(MODEL_NAME).generate(prompt)
    assert isinstance(answer, str)
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_vision():
    """A prompt accompanied by an image returns a string."""
    # Import the submodule explicitly: `import PIL` alone does not guarantee
    # that `PIL.Image` has been loaded.
    from PIL import Image

    model = Gemini(MODEL_NAME)

    url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png"
    # Fail with a clear HTTP error instead of a NameError on an unbound
    # `image` when the download does not succeed; never wait indefinitely.
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    image = Image.open(io.BytesIO(r.content))

    result = model.generate(Vision("What does this logo represent?", image))
    assert isinstance(result, str)
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_pydantic():
    """A Pydantic model wrapped in `Json` yields JSON with the model's field."""

    class Foo(BaseModel):
        bar: int

    client = Gemini(MODEL_NAME)
    raw_answer = client.generate("foo?", Json(Foo))

    assert isinstance(raw_answer, str)
    parsed = json.loads(raw_answer)
    assert "bar" in parsed
|
||
|
||
@pytest.mark.xfail(reason="Vision models do not work with structured outputs.")
@pytest.mark.api_call
def test_gemini_simple_vision_pydantic():
    """A prompt with an image and a `Json` output type yields matching JSON."""
    # Import the submodule explicitly: `import PIL` alone does not guarantee
    # that `PIL.Image` has been loaded.
    from PIL import Image

    model = Gemini(MODEL_NAME)

    url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png"
    # Fail with a clear HTTP error instead of a NameError on an unbound
    # `image` when the download does not succeed; never wait indefinitely.
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    image = Image.open(io.BytesIO(r.content))

    # NOTE(review): `name` is declared as `int`; a logo name is presumably a
    # string -- confirm the intended schema.
    class Logo(BaseModel):
        name: int

    # Wrap the model in `Json`, as the other structured-output tests do, so
    # the expected failure comes from the API call and not from output-type
    # dispatch on a bare model class.
    result = model.generate(
        Vision("What does this logo represent?", image), Json(Logo)
    )
    assert isinstance(result, str)
    assert "name" in json.loads(result)
|
||
|
||
@pytest.mark.xfail(reason="Gemini seems to be unable to follow nested schemas.")
@pytest.mark.api_call
def test_gemini_nested_pydantic():
    """A nested Pydantic model yields JSON with outer and inner fields."""

    class Bar(BaseModel):
        fu: str

    class Foo(BaseModel):
        sna: int
        bar: Bar

    client = Gemini(MODEL_NAME)
    raw_answer = client.generate("foo?", Json(Foo))

    assert isinstance(raw_answer, str)
    parsed = json.loads(raw_answer)
    assert "sna" in parsed
    assert "bar" in parsed
    assert "fu" in parsed["bar"]
|
||
|
||
@pytest.mark.xfail(
    reason="The Gemini SDK's serialization method does not support Json Schema dictionaries."
)
@pytest.mark.api_call
def test_gemini_simple_json_schema_dict():
    """A JSON Schema dictionary wrapped in `Json` yields matching JSON."""
    bar_property = {"title": "Bar", "type": "integer"}
    schema = {
        "properties": {"bar": bar_property},
        "required": ["bar"],
        "title": "Foo",
        "type": "object",
    }

    client = Gemini(MODEL_NAME)
    raw_answer = client.generate("foo?", Json(schema))

    assert isinstance(raw_answer, str)
    parsed = json.loads(raw_answer)
    assert "bar" in parsed
|
||
|
||
@pytest.mark.xfail(
    reason="The Gemini SDK's serialization method does not support Json Schema strings."
)
@pytest.mark.api_call
def test_gemini_simple_json_schema_string():
    """A JSON Schema string wrapped in `Json` yields matching JSON."""
    model = Gemini(MODEL_NAME)

    # Serialize with `json.dumps` so the schema is valid JSON; the previous
    # single-quoted literal was a Python dict repr, not a JSON document.
    schema = json.dumps(
        {
            "properties": {"bar": {"title": "Bar", "type": "integer"}},
            "required": ["bar"],
            "title": "Foo",
            "type": "object",
        }
    )
    result = model.generate("foo?", Json(schema))
    assert isinstance(result, str)
    assert "bar" in json.loads(result)
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_typed_dict():
    """A `TypedDict` wrapped in `Json` yields JSON with the declared key."""

    class Foo(TypedDict):
        bar: int

    client = Gemini(MODEL_NAME)
    raw_answer = client.generate("foo?", Json(Foo))

    assert isinstance(raw_answer, str)
    parsed = json.loads(raw_answer)
    assert "bar" in parsed
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_enum():
    """An enum wrapped in `Choice` yields one of the enum's values."""

    class Foo(Enum):
        bar = "Bar"
        foor = "Foo"

    client = Gemini(MODEL_NAME)
    answer = client.generate("foo?", Choice(Foo))

    assert isinstance(answer, str)
    assert answer in ("Foo", "Bar")
|
||
|
||
@pytest.mark.api_call
def test_gemini_simple_list_pydantic():
    """A `list[Json(...)]` output type yields a JSON list of objects."""

    class Foo(BaseModel):
        bar: int

    client = Gemini(MODEL_NAME)
    raw_answer = client.generate("foo?", list[Json(Foo)])

    parsed = json.loads(raw_answer)
    assert isinstance(parsed, list)
    first_item = parsed[0]
    assert isinstance(first_item, dict)
    assert "bar" in first_item