Skip to content

Commit

Permalink
Add Gemini integration
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 5, 2024
1 parent 6a19929 commit 9ed80a2
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 0 deletions.
1 change: 1 addition & 0 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Union

from .exllamav2 import ExLlamaV2Model, exl2
from .gemini import Gemini
from .llamacpp import LlamaCpp, llamacpp
from .mlxlm import MLXLM, mlxlm
from .openai import AzureOpenAI, OpenAI
Expand Down
121 changes: 121 additions & 0 deletions outlines/models/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Integration with Gemini's API."""
from enum import EnumMeta
from functools import singledispatchmethod
from types import NoneType
from typing import Optional, Union

from pydantic import BaseModel
from typing_extensions import _TypedDictMeta # type: ignore

from outlines.prompts import Vision
from outlines.types import Choice, Json

__all__ = ["Gemini"]


class GeminiBase:
    """Base class for the Gemini clients.

    `GeminiBase` is responsible for preparing the arguments to Gemini's
    `generate_contents` methods: the input (prompt and possibly image), as well
    as the output type (only JSON).
    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the `contents` argument to pass to the client.

        Argument
        --------
        model_input
            The input passed by the user.

        Raises
        ------
        NotImplementedError
            Fallback when `model_input` is neither a `str` nor a `Vision`
            instance.
        """
        # BUG FIX: the original message interpolated `{input}` — the Python
        # builtin — instead of the argument, yielding
        # "<built-in function input>" in the error text.
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with Gemini. The only available types are `str` and `Vision`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Generate the `contents` argument to pass to the client when the user
        only passes a prompt.
        """
        return {"contents": [model_input]}

    @format_input.register(Vision)
    def format_vision_input(self, model_input: Vision):
        """Generate the `contents` argument to pass to the client when the user
        passes a prompt and an image.
        """
        return {"contents": [model_input.prompt, model_input.image]}

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Generate the `GenerationConfig` keyword arguments for `output_type`.

        This generic fallback handles `list[Json(...)]`: a parameterized
        generic is a `types.GenericAlias`, which cannot be registered with
        `singledispatch`, so it is detected here via `__origin__`.

        Raises
        ------
        TypeError
            If a `list` output type is not parameterized by exactly one
            `Json` instance.
        NotImplementedError
            For any other unregistered output type.
        """
        if output_type.__origin__ == list:
            if len(output_type.__args__) == 1 and isinstance(
                output_type.__args__[0], Json
            ):
                return {
                    "response_mime_type": "application/json",
                    "response_schema": list[
                        output_type.__args__[0].original_definition
                    ],
                }
            else:
                raise TypeError(
                    "The only supported list output type is `list[Json(...)]` "
                    "with a single `Json` parameter."
                )
        else:
            raise NotImplementedError(
                f"The output type {output_type} is not available with Gemini."
            )

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type):
        """No structured output was requested: pass no generation constraints."""
        return {}

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type):
        """Constrain the output to the JSON schema wrapped by `output_type`.

        The original (pydantic model or `TypedDict`) definition is forwarded
        to the SDK; raw schema dicts/strings fall through to the error below
        because the SDK cannot serialize them.
        """
        if issubclass(output_type.original_definition, BaseModel):
            return {
                "response_mime_type": "application/json",
                "response_schema": output_type.original_definition,
            }
        elif isinstance(output_type.original_definition, _TypedDictMeta):
            return {
                "response_mime_type": "application/json",
                "response_schema": output_type.original_definition,
            }
        else:
            raise NotImplementedError

    @format_output_type.register(Choice)
    def format_enum_output_type(self, output_type):
        """Constrain the output to one member of the wrapped enum."""
        return {
            "response_mime_type": "text/x.enum",
            "response_schema": output_type.definition,
        }


class Gemini(GeminiBase):
    """Client for Google's Gemini models via the `google.generativeai` SDK."""

    def __init__(self, model_name: str, *args, **kwargs):
        # Imported lazily so `outlines` does not require the Gemini SDK
        # unless this model is actually instantiated.
        import google.generativeai as genai

        self.client = genai.GenerativeModel(model_name, *args, **kwargs)

    def generate(
        self,
        model_input: Union[str, Vision],
        output_type: Optional[Union[Json, EnumMeta]] = None,
        **inference_kwargs,
    ):
        """Call the Gemini API and return the text of the completion.

        Argument
        --------
        model_input
            A prompt (`str`) or a prompt-plus-image (`Vision`) input.
        output_type
            An optional `Json` or enum type used to constrain the output.
        inference_kwargs
            Extra keyword arguments forwarded to `generate_content`.
        """
        import google.generativeai as genai

        prepared_input = self.format_input(model_input)
        config = genai.GenerationConfig(**self.format_output_type(output_type))
        response = self.client.generate_content(
            generation_config=config, **prepared_input, **inference_kwargs
        )

        return response.text
2 changes: 2 additions & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Json:
"""

def __init__(self, definition: Union[str, dict, BaseModel]):
self.original_definition = definition

if isinstance(definition, type(BaseModel)):
definition = definition.model_json_schema()
if isinstance(definition, str):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ test = [
"mlx-lm; platform_machine == 'arm64' and sys_platform == 'darwin'",
"huggingface_hub",
"openai>=1.0.0",
"google-generativeai",
"vllm; sys_platform != 'darwin'",
"transformers",
"pillow",
Expand Down Expand Up @@ -112,6 +113,7 @@ module = [
"exllamav2.*",
"jinja2",
"jsonschema.*",
"google.*",
"mamba_ssm.*",
"mlx_lm.*",
"mlx.*",
Expand Down
164 changes: 164 additions & 0 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import io
import json
from enum import Enum

import PIL
import pytest
import requests
from pydantic import BaseModel
from typing_extensions import TypedDict

from outlines.models.gemini import Gemini
from outlines.prompts import Vision
from outlines.types import Choice, Json

MODEL_NAME = "gemini-1.5-flash-latest"


def test_gemini_wrong_init_parameters():
    """An unknown constructor keyword argument is rejected with a TypeError."""
    with pytest.raises(TypeError, match="got an unexpected"):
        _ = Gemini(MODEL_NAME, foo=10)


def test_gemini_wrong_inference_parameters():
    """`generate` rejects unknown inference keyword arguments."""
    # Construct the model OUTSIDE the `pytest.raises` block: the original put
    # it inside, so a matching TypeError raised by the constructor would have
    # passed the test without ever exercising `generate`.
    model = Gemini(MODEL_NAME)
    with pytest.raises(TypeError, match="got an unexpected"):
        model.generate("prompt", foo=10)


@pytest.mark.api_call
def test_gemini_simple_call():
    """A plain string prompt produces a string completion."""
    answer = Gemini(MODEL_NAME).generate("Respond with one word. Not more.")
    assert isinstance(answer, str)


@pytest.mark.api_call
def test_gemini_simple_vision():
    """A `Vision` (prompt + image) input produces a string completion."""
    model = Gemini(MODEL_NAME)

    url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png"
    r = requests.get(url, stream=True)
    # Fail the download explicitly: the original `if r.status_code == 200:`
    # guard left `image` unbound on any other status, turning a network
    # failure into an unrelated NameError.
    r.raise_for_status()
    image = PIL.Image.open(io.BytesIO(r.content))

    result = model.generate(Vision("What does this logo represent?", image))
    assert isinstance(result, str)


@pytest.mark.api_call
def test_gemini_simple_pydantic():
    """A pydantic model wrapped in `Json` constrains the output schema."""

    class Foo(BaseModel):
        bar: int

    model = Gemini(MODEL_NAME)
    result = model.generate("foo?", Json(Foo))

    assert isinstance(result, str)
    parsed = json.loads(result)
    assert "bar" in parsed


@pytest.mark.xfail(reason="Vision models do not work with structured outputs.")
@pytest.mark.api_call
def test_gemini_simple_vision_pydantic():
    """Structured output combined with a vision input (fails upstream)."""
    model = Gemini(MODEL_NAME)

    url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png"
    r = requests.get(url, stream=True)
    # Fail the download explicitly instead of leaving `image` unbound when
    # the status is not 200.
    r.raise_for_status()
    image = PIL.Image.open(io.BytesIO(r.content))

    class Logo(BaseModel):
        name: int

    # Wrap the model in `Json(...)` like every other structured-output test:
    # the original passed the bare `Logo` class, which is not a supported
    # `output_type` and would fail for the wrong reason.
    result = model.generate(
        Vision("What does this logo represent?", image), Json(Logo)
    )
    assert isinstance(result, str)
    assert "name" in json.loads(result)


@pytest.mark.xfail(reason="Gemini seems to be unable to follow nested schemas.")
@pytest.mark.api_call
def test_gemini_nested_pydantic():
    """A nested pydantic schema (not currently followed by Gemini)."""

    class Bar(BaseModel):
        fu: str

    class Foo(BaseModel):
        sna: int
        bar: Bar

    result = Gemini(MODEL_NAME).generate("foo?", Json(Foo))

    assert isinstance(result, str)
    parsed = json.loads(result)
    assert "sna" in parsed
    assert "bar" in parsed
    assert "fu" in parsed["bar"]


@pytest.mark.xfail(
    reason="The Gemini SDK's serialization method does not support Json Schema dictionaries."
)
@pytest.mark.api_call
def test_gemini_simple_json_schema_dict():
    """A raw JSON Schema dictionary (not currently supported by the SDK)."""
    model = Gemini(MODEL_NAME)

    schema = {
        "properties": {"bar": {"title": "Bar", "type": "integer"}},
        "required": ["bar"],
        "title": "Foo",
        "type": "object",
    }

    result = model.generate("foo?", Json(schema))
    assert isinstance(result, str)
    parsed = json.loads(result)
    assert "bar" in parsed


@pytest.mark.xfail(
    reason="The Gemini SDK's serialization method does not support Json Schema strings."
)
@pytest.mark.api_call
def test_gemini_simple_json_schema_string():
    """A raw JSON Schema string (not currently supported by the SDK)."""
    model = Gemini(MODEL_NAME)

    # Use valid JSON (double quotes): the original string used single quotes,
    # which is not parseable JSON, so the test would have failed for the
    # wrong reason even once the SDK supports schema strings.
    schema = (
        '{"properties": {"bar": {"title": "Bar", "type": "integer"}}, '
        '"required": ["bar"], "title": "Foo", "type": "object"}'
    )
    result = model.generate("foo?", Json(schema))
    assert isinstance(result, str)
    assert "bar" in json.loads(result)


@pytest.mark.api_call
def test_gemini_simple_typed_dict():
    """A `TypedDict` wrapped in `Json` constrains the output schema."""

    class Foo(TypedDict):
        bar: int

    model = Gemini(MODEL_NAME)
    result = model.generate("foo?", Json(Foo))

    assert isinstance(result, str)
    parsed = json.loads(result)
    assert "bar" in parsed


@pytest.mark.api_call
def test_gemini_simple_enum():
    """A `Choice` over an enum restricts the output to one of its values."""

    class Foo(Enum):
        bar = "Bar"
        foor = "Foo"

    model = Gemini(MODEL_NAME)
    result = model.generate("foo?", Choice(Foo))

    assert isinstance(result, str)
    assert result in {"Foo", "Bar"}


@pytest.mark.api_call
def test_gemini_simple_list_pydantic():
    """`list[Json(...)]` yields a JSON array of schema-conforming objects."""

    class Foo(BaseModel):
        bar: int

    result = Gemini(MODEL_NAME).generate("foo?", list[Json(Foo)])

    parsed = json.loads(result)
    assert isinstance(parsed, list)
    assert isinstance(parsed[0], dict)
    assert "bar" in parsed[0]

0 comments on commit 9ed80a2

Please sign in to comment.