Skip to content

Commit

Permalink
Merge pull request #1441 from phidatahq/fix-logging-for-gemini-phi-1989
Browse files Browse the repository at this point in the history
fix-logging-for-gemini-phi-1989
  • Loading branch information
ashpreetbedi authored Nov 14, 2024
2 parents 8bcbf70 + 0c3bd02 commit 342d882
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions phi/model/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,11 @@ def response(self, messages: List[Message]) -> ModelResponse:
if assistant_message.content is not None:
model_response.content = assistant_message.get_content_string()

# -*- Remove parts from messages
for m in messages:
if hasattr(m, "parts"):
m.parts = None

logger.debug("---------- Gemini Response End ----------")
return model_response

Expand Down Expand Up @@ -551,13 +556,14 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
for response in self.invoke_stream(messages=messages):
message_data.response_block = response.candidates[0].content
message_data.response_role = message_data.response_block.role
message_data.response_parts = message_data.response_block.parts
if message_data.response_block.parts:
message_data.response_parts = message_data.response_block.parts

if message_data.response_parts is not None:
for part in message_data.response_parts:
part_dict = type(part).to_dict(part)

# Yield text if present
# -*- Yield text if present
if "text" in part_dict:
text = part_dict.get("text")
yield ModelResponse(content=text)
Expand All @@ -567,7 +573,10 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
stream_usage_data.time_to_first_token = response_timer.elapsed
logger.debug(f"Time to first token: {stream_usage_data.time_to_first_token:.4f}s")

# Parse function calls
# -*- Skip function calls if there are no parts
if not message_data.response_block.parts and message_data.response_parts:
continue
# -*- Parse function calls
if "function_call" in part_dict:
message_data.response_tool_calls.append(
{
Expand Down Expand Up @@ -610,4 +619,10 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0 and self.run_tools:
yield from self._handle_stream_tool_calls(assistant_message, messages)
yield from self.response_stream(messages=messages)

# -*- Remove parts from messages
for m in messages:
if hasattr(m, "parts"):
m.parts = None

logger.debug("---------- Gemini Response End ----------")

0 comments on commit 342d882

Please sign in to comment.