Skip to content

Commit 1a74663

Browse files
committed
basic case
Change-Id: I8fe1310f45c86aa01ea1b589410cbaacad1a2b63
1 parent 19c5560 commit 1a74663

File tree

3 files changed

+46
-8
lines changed

3 files changed

+46
-8
lines changed

google/generativeai/generative_models.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def _prepare_request(
123123
contents: content_types.ContentsType,
124124
generation_config: generation_types.GenerationConfigType | None = None,
125125
safety_settings: safety_types.SafetySettingOptions | None = None,
126-
tools: content_types.FunctionLibraryType | None,
127-
tool_config: content_types.ToolConfigType | None,
126+
tools: content_types.FunctionLibraryType | None = None,
127+
tool_config: content_types.ToolConfigType | None = None,
128128
) -> protos.GenerateContentRequest:
129129
"""Creates a `protos.GenerateContentRequest` from raw inputs."""
130130
tools_lib = self._get_tools_lib(tools)
@@ -435,9 +435,14 @@ def __init__(
435435
self._last_received: generation_types.BaseGenerateContentResponse | None = None
436436
self.enable_automatic_function_calling = enable_automatic_function_calling
437437

438-
def to_dict(self):
439-
request = self.model._prepare_request(contents = self.history)
440-
return type(request).to_dict(use_integers_for_enums=False, including_default_value_fields=False)
438+
def to_dict(self, tools=True):
439+
if tools == True and self.model._tools is not None:
440+
pass # raise ValueError("")
441+
442+
request = self.model._prepare_request(contents=self.history)
443+
return type(request).to_dict(
444+
request, use_integers_for_enums=False, including_default_value_fields=False
445+
)
441446

442447
@classmethod
443448
def from_dict(cls, obj):
@@ -448,7 +453,8 @@ def from_dict(cls, obj):
448453
tool_config=request.tool_config,
449454
tools=request.tools,
450455
safety_settings=request.safety_settings,
451-
system_instruction=request.system_instruction)
456+
system_instruction=request.system_instruction,
457+
)
452458

453459
return model.start_chat(history=request.contents)
454460

tests/test_chat.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def setUp(self):
9797
self.client = MockGenerativeServiceClient(self)
9898
client_lib._client_manager.clients["generative"] = self.client
9999

100-
101100
def test_chat(self):
102101
# Multi turn chat
103102
model = generative_models.GenerativeModel("gemini-pro")
@@ -706,6 +705,40 @@ def test_chat_with_request_options(self):
706705
request_options["retry"] = None
707706
self.assertEqual(request_options, self.observed_kwargs[0])
708707

708+
def test_serialize(self):
709+
def add(x, y):
710+
return x + y
711+
712+
model = generative_models.GenerativeModel(
713+
model_name=",models/gemini-1.5-flash",
714+
generation_config={"max_output_tokens": 65},
715+
tools=[add],
716+
safety_settings="block_none",
717+
system_instruction="you are a cat",
718+
)
719+
720+
chat = model.start_chat(
721+
history=[
722+
{"role": "user", "parts": "hello"},
723+
{"role": "model", "parts": "meow!"},
724+
{
725+
"role": "user",
726+
"parts": [
727+
"what's this picture?",
728+
{"mime_type": "image/png", "data": b"PNG!"},
729+
],
730+
},
731+
]
732+
)
733+
734+
chat_json = chat.to_dict()
735+
736+
new_chat = generative_models.ChatSession.from_dict(chat_json)
737+
738+
new_chat_json = chat.to_dict()
739+
740+
self.assertEqual(chat_json, new_chat_json)
741+
709742

710743
if __name__ == "__main__":
711744
absltest.main()

tests/test_generative_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,5 @@ def test_count_tokens_called_with_request_options(self):
667667
self.assertEqual(request_options, self.observed_kwargs[0])
668668

669669

670-
671670
if __name__ == "__main__":
672671
absltest.main()

0 commit comments

Comments
 (0)