Update notebook magics to work with Gemini #109

Merged
merged 1 commit on Jan 11, 2024
11 changes: 8 additions & 3 deletions README.md
@@ -142,10 +142,15 @@ response.reply("Can you tell me a joke?")

 ### Colab magics
 
-Once installed, use the Python client via the `%%palm` Colab magic. Read the [full guide](https://github.com/google/generative-ai-docs/blob/main/site/en/palm_docs/notebook_magic.ipynb).
-
 ```
-%%palm
+%pip install -q google-generativeai
+%load_ext google.generativeai.notebook
+```
+
+Once installed, use the Python client via the `%%llm` Colab magic. Read the full guide [here](https://developers.generativeai.google/tools/notebook_magic).
+
+```python
+%%llm
 The best thing since sliced bread is
 ```

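The magics module changed further down keeps `%%palm` and adds `%%gemini` as aliases that forward to `%%llm`, so an existing cell such as the following should keep working after the update (a minimal sketch, assuming the extension has been loaded as shown in the updated README):

```
%%palm
The best thing since sliced bread is
```
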
2 changes: 1 addition & 1 deletion google/generativeai/notebook/cmd_line_parser.py
@@ -368,7 +368,7 @@ def _create_parser(
     placeholders: AbstractSet[str] | None,
 ) -> argparse.ArgumentParser:
     """Create the full parser."""
-    system_name = "palm"
+    system_name = "llm"
     description = "A system for interacting with LLMs."
     epilog = ""
 
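The `system_name` rename above presumably feeds the argparse program name, which is why the help text asserted in the tests further down changes from `usage: palm run` to `usage: llm run`. A minimal sketch of that relationship, using a deliberately simplified stand-in parser (the real `_create_parser` defines more commands and options):

```python
import argparse


def create_parser(system_name: str = "llm") -> argparse.ArgumentParser:
    # Hypothetical, stripped-down stand-in: the program name comes from
    # system_name, so renaming "palm" -> "llm" changes the usage string.
    parser = argparse.ArgumentParser(
        prog=system_name,
        description="A system for interacting with LLMs.",
    )
    subparsers = parser.add_subparsers(dest="command")
    subparsers.add_parser("run", help="Run a prompt against a model.")
    return parser


print(create_parser().format_usage())  # e.g. "usage: llm [-h] {run} ..."
```
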
24 changes: 16 additions & 8 deletions google/generativeai/notebook/magics.py
@@ -14,14 +14,14 @@
 # limitations under the License.
 """Colab Magics class.
-Installs %%palm magics.
+Installs %%llm magics.
 """
 from __future__ import annotations
 
 import abc
 
 from google.auth import credentials
-from google.generativeai import client as palm
+from google.generativeai import client as genai
 from google.generativeai.notebook import gspread_client
 from google.generativeai.notebook import ipython_env
 from google.generativeai.notebook import ipython_env_impl
@@ -34,9 +34,9 @@


 # Set the UA to distinguish the magic from the client. Do this at import-time
-# so that a user can still call `palm.configure()`, and both their settings
+# so that a user can still call `genai.configure()`, and both their settings
 # and this are honored.
-palm.USER_AGENT = "genai-py-magic"
+genai.USER_AGENT = "genai-py-magic"
 
 SheetsInputs = sheets_utils.SheetsInputs
 SheetsOutputs = sheets_utils.SheetsOutputs
@@ -72,7 +72,7 @@ class AbstractMagics(abc.ABC):
"""Defines interface to Magics class."""

@abc.abstractmethod
def palm(self, cell_line: str | None, cell_body: str | None):
def llm(self, cell_line: str | None, cell_body: str | None):
"""Perform various LLM-related operations.
Args:
@@ -92,7 +92,7 @@ class MagicsImpl(AbstractMagics):
     def __init__(self):
         self._engine = magics_engine.MagicsEngine(env=_get_ipython_env())
 
-    def palm(self, cell_line: str | None, cell_body: str | None):
+    def llm(self, cell_line: str | None, cell_body: str | None):
         """Perform various LLM-related operations.
         Args:
@@ -126,7 +126,7 @@ def get_instance(cls) -> AbstractMagics:
         return cls._instance
 
     @magic.line_cell_magic
-    def palm(self, cell_line: str | None, cell_body: str | None):
+    def llm(self, cell_line: str | None, cell_body: str | None):
         """Perform various LLM-related operations.
         Args:
@@ -136,7 +136,15 @@ def palm(self, cell_line: str | None, cell_body: str | None):
         Returns:
             Results from running MagicsEngine.
         """
-        return Magics.get_instance().palm(cell_line=cell_line, cell_body=cell_body)
+        return Magics.get_instance().llm(cell_line=cell_line, cell_body=cell_body)
+
+    @magic.line_cell_magic
+    def palm(self, cell_line: str | None, cell_body: str | None):
+        return self.llm(cell_line, cell_body)
+
+    @magic.line_cell_magic
+    def gemini(self, cell_line: str | None, cell_body: str | None):
+        return self.llm(cell_line, cell_body)
 
 
 IPython.get_ipython().register_magics(Magics)
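
For context, the alias pattern above relies on how IPython registers magics: every method decorated with `@magic.line_cell_magic` becomes a magic named after the method, so thin `palm`/`gemini` wrappers that delegate to `llm` are enough to keep old notebooks working. A standalone sketch of the same idea (the class and magic names here are illustrative, not the library's):

```python
from IPython import get_ipython
from IPython.core import magic


@magic.magics_class
class DemoMagics(magic.Magics):
    """Registers %%demo plus an alias that forwards to it."""

    @magic.line_cell_magic
    def demo(self, line, cell=None):
        # Echo the cell body, or the line when used as a line magic.
        return cell if cell is not None else line

    @magic.line_cell_magic
    def demo_alias(self, line, cell=None):
        # Same behaviour under a second name, mirroring %%palm -> %%llm.
        return self.demo(line, cell)


ip = get_ipython()  # None when not running under IPython/Colab
if ip is not None:
    ip.register_magics(DemoMagics)
```
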
29 changes: 18 additions & 11 deletions google/generativeai/notebook/text_model.py
@@ -16,29 +16,32 @@
 from __future__ import annotations
 
 from google.api_core import retry
-from google.generativeai import text
-from google.generativeai.types import text_types
+import google.generativeai as genai
+from google.generativeai.types import generation_types
 from google.generativeai.notebook.lib import model as model_lib
 
+_DEFAULT_MODEL = "models/gemini-pro"
+
 
 class TextModel(model_lib.AbstractModel):
-    """Concrete model that uses the Text service."""
+    """Concrete model that uses the generate_content service."""
 
     def _generate_text(
         self,
         prompt: str,
         model: str | None = None,
         temperature: float | None = None,
         candidate_count: int | None = None,
         **kwargs,
-    ) -> text_types.Completion:
-        if model is not None:
-            kwargs["model"] = model
+    ) -> generation_types.GenerateContentResponse:
+        gen_config = {}
         if temperature is not None:
-            kwargs["temperature"] = temperature
+            gen_config["temperature"] = temperature
         if candidate_count is not None:
-            kwargs["candidate_count"] = candidate_count
-        return text.generate_text(prompt=prompt, **kwargs)
+            gen_config["candidate_count"] = candidate_count
+
+        model_name = model or _DEFAULT_MODEL
+        gen_model = genai.GenerativeModel(model_name=model_name)
+        return gen_model.generate_content(prompt, generation_config=gen_config)
 
     def call_model(
         self,
@@ -58,7 +61,11 @@ def call_model(
             candidate_count=model_args.candidate_count,
         )
 
+        text_outputs = []
+        for c in response.candidates:
+            text_outputs.append("".join(p.text for p in c.content.parts))
+
         return model_lib.ModelResults(
             model_input=model_input,
-            text_results=[x["output"] for x in response.candidates],
+            text_results=text_outputs,
         )
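
The call path the wrapper now goes through is roughly the following (a minimal sketch, assuming a `GOOGLE_API_KEY` environment variable; the parts-joining at the end mirrors the loop added to `call_model` above):

```python
import os

import google.generativeai as genai

# Assumption: the API key is provided via the environment.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

model = genai.GenerativeModel(model_name="models/gemini-pro")
response = model.generate_content(
    "The best thing since sliced bread is",
    generation_config={"temperature": 0.7, "candidate_count": 1},
)

# Each candidate's text is spread across content.parts; join them per candidate.
texts = ["".join(part.text for part in candidate.content.parts)
         for candidate in response.candidates]
print(texts[0])
```
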
2 changes: 1 addition & 1 deletion tests/notebook/test_magics_engine.py
@@ -248,7 +248,7 @@ def test_run_help(self):

         # Should not raise an exception.
         results = engine.execute_cell(magic_line, "ignored")
-        self.assertRegex(str(results), "usage: palm run")
+        self.assertRegex(str(results), "usage: llm run")
 
     def test_error(self):
         mock_registry = EchoModelRegistry()
35 changes: 19 additions & 16 deletions tests/notebook/text_model_test.py
@@ -19,7 +19,7 @@
 from absl.testing import absltest
 
 from google.api_core import exceptions
-from google.generativeai import text
+from google.generativeai.types import generation_types
 from google.generativeai.notebook import text_model
 from google.generativeai.notebook.lib import model as model_lib
 
@@ -29,20 +29,23 @@ def _fake_generator(
     model: str | None = None,
     temperature: float | None = None,
     candidate_count: int | None = None,
-) -> text.Completion:
-    return text.Completion(
-        prompt=prompt,
-        model=model,
-        temperature=temperature,
-        candidate_count=candidate_count,
-        # Smuggle the parameters as text output, so we can make assertions.
-        candidates=[
-            {"output": f"{prompt}_1"},
-            {"output": model},
-            {"output": temperature},
-            {"output": candidate_count},
-        ],
-    )
+):
+    def make_candidate(txt):
+        c = mock.Mock()
+        p = mock.Mock()
+        p.text = str(txt)
+        c.content.parts = [p]
+        return c
+
+    response = mock.Mock()
+    # Smuggle the parameters as text output, so we can make assertions.
+    response.candidates = [
+        make_candidate(f"{prompt}_1"),
+        make_candidate(model),
+        make_candidate(temperature),
+        make_candidate(candidate_count),
+    ]
+    return response
 
 
 class TestModel(text_model.TextModel):
@@ -55,7 +58,7 @@ def _generate_text(
         temperature: float | None = None,
         candidate_count: int | None = None,
         **kwargs,
-    ) -> text.Completion:
+    ) -> generation_types.GenerateContentResponse:
         return _fake_generator(
             prompt=prompt,
             model=model,
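
The rewritten `_fake_generator` leans on `unittest.mock.Mock` creating attributes on first access, which makes it easy to imitate the nested `candidates -> content -> parts -> text` shape of a real response without importing the response classes. A small standalone illustration of that pattern:

```python
from unittest import mock

# Mock() materialises attributes as they are touched, so only the leaves the
# code under test actually reads need to be assigned explicitly.
part = mock.Mock()
part.text = "hello"

candidate = mock.Mock()
candidate.content.parts = [part]

response = mock.Mock()
response.candidates = [candidate]

# The same extraction the updated call_model performs:
print(["".join(p.text for p in c.content.parts) for c in response.candidates])
# -> ['hello']
```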