Skip to content

Commit 7cf6278

Browse files
committed
Update notebook magics to work with Gemini
Adds some aliases for the commands too:
* `%%palm` still works, for backwards compat
* `%%llm` as a generic magic
* `%%gemini` as an analogue to `%%palm`
1 parent 0988543 commit 7cf6278

File tree

6 files changed

+62
-39
lines changed

6 files changed

+62
-39
lines changed

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,14 @@ Checkout the full [API docs](https://developers.generativeai.google/api), the [g
4343

4444
## Colab magics
4545

46-
Once installed, use the Python client via the `%%palm` Colab magic. Read the full guide [here](https://developers.generativeai.google/tools/notebook_magic).
46+
```
47+
%pip install -q google-generativeai
48+
%load_ext google.generativeai.notebook
49+
```
50+
51+
Once installed, use the Python client via the `%%llm` Colab magic. Read the full guide [here](https://developers.generativeai.google/tools/notebook_magic).
4752

4853
```python
49-
%%palm
54+
%%llm
5055
The best thing since sliced bread is
5156
```

google/generativeai/notebook/cmd_line_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _create_parser(
368368
placeholders: AbstractSet[str] | None,
369369
) -> argparse.ArgumentParser:
370370
"""Create the full parser."""
371-
system_name = "palm"
371+
system_name = "llm"
372372
description = "A system for interacting with LLMs."
373373
epilog = ""
374374

google/generativeai/notebook/magics.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
# limitations under the License.
1515
"""Colab Magics class.
1616
17-
Installs %%palm magics.
17+
Installs %%llm magics.
1818
"""
1919
from __future__ import annotations
2020

2121
import abc
2222

2323
from google.auth import credentials
24-
from google.generativeai import client as palm
24+
from google.generativeai import client as genai
2525
from google.generativeai.notebook import gspread_client
2626
from google.generativeai.notebook import ipython_env
2727
from google.generativeai.notebook import ipython_env_impl
@@ -34,9 +34,9 @@
3434

3535

3636
# Set the UA to distinguish the magic from the client. Do this at import-time
37-
# so that a user can still call `palm.configure()`, and both their settings
37+
# so that a user can still call `genai.configure()`, and both their settings
3838
# and this are honored.
39-
palm.USER_AGENT = "genai-py-magic"
39+
genai.USER_AGENT = "genai-py-magic"
4040

4141
SheetsInputs = sheets_utils.SheetsInputs
4242
SheetsOutputs = sheets_utils.SheetsOutputs
@@ -72,7 +72,7 @@ class AbstractMagics(abc.ABC):
7272
"""Defines interface to Magics class."""
7373

7474
@abc.abstractmethod
75-
def palm(self, cell_line: str | None, cell_body: str | None):
75+
def llm(self, cell_line: str | None, cell_body: str | None):
7676
"""Perform various LLM-related operations.
7777
7878
Args:
@@ -92,7 +92,7 @@ class MagicsImpl(AbstractMagics):
9292
def __init__(self):
9393
self._engine = magics_engine.MagicsEngine(env=_get_ipython_env())
9494

95-
def palm(self, cell_line: str | None, cell_body: str | None):
95+
def llm(self, cell_line: str | None, cell_body: str | None):
9696
"""Perform various LLM-related operations.
9797
9898
Args:
@@ -126,7 +126,7 @@ def get_instance(cls) -> AbstractMagics:
126126
return cls._instance
127127

128128
@magic.line_cell_magic
129-
def palm(self, cell_line: str | None, cell_body: str | None):
129+
def llm(self, cell_line: str | None, cell_body: str | None):
130130
"""Perform various LLM-related operations.
131131
132132
Args:
@@ -136,7 +136,15 @@ def palm(self, cell_line: str | None, cell_body: str | None):
136136
Returns:
137137
Results from running MagicsEngine.
138138
"""
139-
return Magics.get_instance().palm(cell_line=cell_line, cell_body=cell_body)
139+
return Magics.get_instance().llm(cell_line=cell_line, cell_body=cell_body)
140+
141+
@magic.line_cell_magic
142+
def palm(self, cell_line: str | None, cell_body: str | None):
143+
return self.llm(cell_line, cell_body)
144+
145+
@magic.line_cell_magic
146+
def gemini(self, cell_line: str | None, cell_body: str | None):
147+
return self.llm(cell_line, cell_body)
140148

141149

142150
IPython.get_ipython().register_magics(Magics)

google/generativeai/notebook/magics_engine_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_run_help(self):
248248

249249
# Should not raise an exception.
250250
results = engine.execute_cell(magic_line, "ignored")
251-
self.assertRegex(str(results), "usage: palm run")
251+
self.assertRegex(str(results), "usage: llm run")
252252

253253
def test_error(self):
254254
mock_registry = EchoModelRegistry()

google/generativeai/notebook/text_model.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,32 @@
1616
from __future__ import annotations
1717

1818
from google.api_core import retry
19-
from google.generativeai import text
20-
from google.generativeai.types import text_types
19+
import google.generativeai as genai
20+
from google.generativeai.types import generation_types
2121
from google.generativeai.notebook.lib import model as model_lib
2222

23+
_DEFAULT_MODEL = "models/gemini-pro"
24+
2325

2426
class TextModel(model_lib.AbstractModel):
25-
"""Concrete model that uses the Text service."""
27+
"""Concrete model that uses the generate_content service."""
2628

2729
def _generate_text(
2830
self,
2931
prompt: str,
3032
model: str | None = None,
3133
temperature: float | None = None,
3234
candidate_count: int | None = None,
33-
**kwargs,
34-
) -> text_types.Completion:
35-
if model is not None:
36-
kwargs["model"] = model
35+
) -> generation_types.GenerateContentResponse:
36+
gen_config = {}
3737
if temperature is not None:
38-
kwargs["temperature"] = temperature
38+
gen_config["temperature"] = temperature
3939
if candidate_count is not None:
40-
kwargs["candidate_count"] = candidate_count
41-
return text.generate_text(prompt=prompt, **kwargs)
40+
gen_config["candidate_count"] = candidate_count
41+
42+
model_name = model or _DEFAULT_MODEL
43+
gen_model = genai.GenerativeModel(model_name=model_name)
44+
return gen_model.generate_content(prompt, generation_config=gen_config)
4245

4346
def call_model(
4447
self,
@@ -58,7 +61,11 @@ def call_model(
5861
candidate_count=model_args.candidate_count,
5962
)
6063

64+
text_outputs = []
65+
for c in response.candidates:
66+
text_outputs.append("".join(p.text for p in c.content.parts))
67+
6168
return model_lib.ModelResults(
6269
model_input=model_input,
63-
text_results=[x["output"] for x in response.candidates],
70+
text_results=text_outputs,
6471
)

google/generativeai/notebook/text_model_test.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from absl.testing import absltest
2020

2121
from google.api_core import exceptions
22-
from google.generativeai import text
22+
from google.generativeai.types import generation_types
2323
from google.generativeai.notebook import text_model
2424
from google.generativeai.notebook.lib import model as model_lib
2525

@@ -29,20 +29,23 @@ def _fake_generator(
2929
model: str | None = None,
3030
temperature: float | None = None,
3131
candidate_count: int | None = None,
32-
) -> text.Completion:
33-
return text.Completion(
34-
prompt=prompt,
35-
model=model,
36-
temperature=temperature,
37-
candidate_count=candidate_count,
38-
# Smuggle the parameters as text output, so we can make assertions.
39-
candidates=[
40-
{"output": f"{prompt}_1"},
41-
{"output": model},
42-
{"output": temperature},
43-
{"output": candidate_count},
44-
],
45-
)
32+
):
33+
def make_candidate(txt):
34+
c = mock.Mock()
35+
p = mock.Mock()
36+
p.text = str(txt)
37+
c.content.parts = [p]
38+
return c
39+
40+
response = mock.Mock()
41+
# Smuggle the parameters as text output, so we can make assertions.
42+
response.candidates = [
43+
make_candidate(f"{prompt}_1"),
44+
make_candidate(model),
45+
make_candidate(temperature),
46+
make_candidate(candidate_count),
47+
]
48+
return response
4649

4750

4851
class TestModel(text_model.TextModel):
@@ -55,7 +58,7 @@ def _generate_text(
5558
temperature: float | None = None,
5659
candidate_count: int | None = None,
5760
**kwargs,
58-
) -> text.Completion:
61+
) -> generation_types.GenerateContentResponse:
5962
return _fake_generator(
6063
prompt=prompt,
6164
model=model,

0 commit comments

Comments (0)