fix: eliminate double API calls when using tools with streaming #753
Changes from all commits
@@ -276,6 +276,41 @@ def _needs_system_message_skip(self) -> bool:
         ]

         return self.model in legacy_o1_models

+    def _supports_streaming_tools(self) -> bool:
+        """
+        Check if the current provider supports streaming with tools.
+
+        Most providers that support tool calling also support streaming with tools,
+        but some providers (like Ollama and certain local models) require non-streaming
+        calls when tools are involved.
+
+        Returns:
+            bool: True if provider supports streaming with tools, False otherwise
+        """
+        if not self.model:
+            return False
+
+        # Ollama doesn't reliably support streaming with tools
+        if self._is_ollama_provider():
+            return False
+
+        # OpenAI models support streaming with tools
+        if any(self.model.startswith(prefix) for prefix in ["gpt-", "o1-", "o3-"]):
+            return True
+
+        # Anthropic Claude models support streaming with tools
+        if self.model.startswith("claude-"):
+            return True
+
+        # Google Gemini models support streaming with tools
+        if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
+            return True
+
+        # For other providers, default to False to be safe
+        # This ensures we make a single non-streaming call rather than risk
+        # missing tool calls or making duplicate calls
+        return False
+
     def get_response(
         self,
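In plain terms, the new helper is a prefix allow-list with a conservative default. A standalone sketch of the decision it drives (illustration only, not part of the patch; the model name below is just an example):

# Standalone sketch that mirrors the prefix checks above; illustration only.
def supports_streaming_tools(model: str, is_ollama: bool = False) -> bool:
    if not model:
        return False
    if is_ollama:
        # Ollama doesn't reliably stream tool-call deltas
        return False
    if any(model.startswith(p) for p in ("gpt-", "o1-", "o3-", "claude-", "gemini-", "gemini/")):
        return True
    # Unknown providers: prefer one non-streaming call over risking lost tool calls
    return False

# The fallback decision made before calling litellm:
stream_requested = True
has_tools = True
use_streaming = stream_requested and (not has_tools or supports_streaming_tools("gpt-4o-mini"))
print(use_streaming)  # True; with an Ollama model this would be False -> single non-streaming call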
@@ -480,49 +515,110 @@ def get_response(
             # Otherwise do the existing streaming approach
             else:
-                if verbose:
-                    with Live(display_generating("", current_time), console=console, refresh_per_second=4) as live:
-                        response_text = ""
+                # Determine if we should use streaming based on tool support
+                use_streaming = stream
+                if formatted_tools and not self._supports_streaming_tools():
+                    # Provider doesn't support streaming with tools, use non-streaming
+                    use_streaming = False
+
+                if use_streaming:
+                    # Streaming approach (with or without tools)
+                    tool_calls = []
+                    response_text = ""
+
+                    if verbose:
+                        with Live(display_generating("", current_time), console=console, refresh_per_second=4) as live:
+                            for chunk in litellm.completion(
+                                **self._build_completion_params(
+                                    messages=messages,
+                                    tools=formatted_tools,
+                                    temperature=temperature,
+                                    stream=True,
+                                    **kwargs
+                                )
+                            ):
+                                if chunk and chunk.choices and chunk.choices[0].delta:
+                                    delta = chunk.choices[0].delta
+                                    if delta.content:
+                                        response_text += delta.content
+                                        live.update(display_generating(response_text, current_time))
+
+                                    # Capture tool calls from streaming chunks if provider supports it
+                                    if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                        for tc in delta.tool_calls:
+                                            if tc.index >= len(tool_calls):
+                                                tool_calls.append({
+                                                    "id": tc.id,
+                                                    "type": "function",
+                                                    "function": {"name": "", "arguments": ""}
+                                                })
+                                            if tc.function.name:
+                                                tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                            if tc.function.arguments:
+                                                tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+                    else:
+                        # Non-verbose streaming
                         for chunk in litellm.completion(
                             **self._build_completion_params(
                                 messages=messages,
                                 tools=formatted_tools,
                                 temperature=temperature,
-                                stream=stream,
+                                stream=True,
                                 **kwargs
                             )
                         ):
-                            if chunk and chunk.choices and chunk.choices[0].delta.content:
-                                content = chunk.choices[0].delta.content
-                                response_text += content
-                                live.update(display_generating(response_text, current_time))
+                            if chunk and chunk.choices and chunk.choices[0].delta:
+                                delta = chunk.choices[0].delta
+                                if delta.content:
+                                    response_text += delta.content
+
+                                # Capture tool calls from streaming chunks if provider supports it
+                                if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                    for tc in delta.tool_calls:
+                                        if tc.index >= len(tool_calls):
+                                            tool_calls.append({
+                                                "id": tc.id,
+                                                "type": "function",
+                                                "function": {"name": "", "arguments": ""}
+                                            })
+                                        if tc.function.name:
+                                            tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                        if tc.function.arguments:
+                                            tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+
+                    response_text = response_text.strip()
+
+                    # Create a mock final_response with the captured data
+                    final_response = {
+                        "choices": [{
+                            "message": {
+                                "content": response_text,
+                                "tool_calls": tool_calls if tool_calls else None
+                            }
+                        }]
+                    }
                 else:
-                    # Non-verbose mode, just collect the response
-                    response_text = ""
-                    for chunk in litellm.completion(
+                    # Non-streaming approach (when tools require it or streaming is disabled)
+                    final_response = litellm.completion(
                         **self._build_completion_params(
                             messages=messages,
                             tools=formatted_tools,
                             temperature=temperature,
-                            stream=stream,
+                            stream=False,
                             **kwargs
                         )
-                    ):
-                        if chunk and chunk.choices and chunk.choices[0].delta.content:
-                            response_text += chunk.choices[0].delta.content
-
-                response_text = response_text.strip()
-
-                # Get final completion to check for tool calls
-                final_response = litellm.completion(
-                    **self._build_completion_params(
-                        messages=messages,
-                        tools=formatted_tools,
-                        temperature=temperature,
-                        stream=False,  # No streaming for tool call check
-                        **kwargs
                     )
-                )
+                    response_text = final_response["choices"][0]["message"]["content"]

                 if verbose:
                     # Display the complete response at once
                     display_interaction(
                         original_prompt,
                         response_text,
                         markdown=markdown,
                         generation_time=time.time() - current_time,
                         console=console
                     )

                 tool_calls = final_response["choices"][0]["message"].get("tool_calls")
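The mock final_response built in the streaming branch is what lets the rest of the method treat both paths uniformly: whichever branch ran, the result answers the same lookups. A minimal sketch with hand-made values (illustration only; the tool name and arguments are invented):

# Illustration only: the streaming branch hand-builds a response-shaped dict,
# so downstream code reads it exactly like a non-streaming completion result.
streamed_final_response = {
    "choices": [{
        "message": {
            "content": "Checking the weather...",
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }],
        }
    }]
}

message = streamed_final_response["choices"][0]["message"]
print(message["content"])         # "Checking the weather..."
print(message.get("tool_calls"))  # the captured calls, or None if nothing was streamed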
@@ -1198,53 +1294,106 @@ async def get_response_async(
                     console=console
                 )
             else:
-                if verbose:
-                    # ----------------------------------------------------
-                    # 1) Make the streaming call WITHOUT tools
-                    # ----------------------------------------------------
-                    async for chunk in await litellm.acompletion(
-                        **self._build_completion_params(
-                            messages=messages,
-                            temperature=temperature,
-                            stream=stream,
-                            **kwargs
-                        )
-                    ):
-                        if chunk and chunk.choices and chunk.choices[0].delta.content:
-                            response_text += chunk.choices[0].delta.content
-                            print("\033[K", end="\r")
-                            print(f"Generating... {time.time() - start_time:.1f}s", end="\r")
+                # Determine if we should use streaming based on tool support
+                use_streaming = stream
+                if formatted_tools and not self._supports_streaming_tools():
+                    # Provider doesn't support streaming with tools, use non-streaming
+                    use_streaming = False
+
+                if use_streaming:
+                    # Streaming approach (with or without tools)
+                    tool_calls = []
+
+                    if verbose:
+                        async for chunk in await litellm.acompletion(
+                            **self._build_completion_params(
+                                messages=messages,
+                                temperature=temperature,
+                                stream=True,
+                                tools=formatted_tools,
+                                **kwargs
+                            )
+                        ):
+                            if chunk and chunk.choices and chunk.choices[0].delta:
+                                delta = chunk.choices[0].delta
+                                if delta.content:
+                                    response_text += delta.content
+                                    print("\033[K", end="\r")
+                                    print(f"Generating... {time.time() - start_time:.1f}s", end="\r")
+
+                                # Capture tool calls from streaming chunks if provider supports it
+                                if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                    for tc in delta.tool_calls:
+                                        if tc.index >= len(tool_calls):
+                                            tool_calls.append({
+                                                "id": tc.id,
+                                                "type": "function",
+                                                "function": {"name": "", "arguments": ""}
+                                            })
+                                        if tc.function.name:
+                                            tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                        if tc.function.arguments:
+                                            tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+                    else:
+                        # Non-verbose streaming
+                        async for chunk in await litellm.acompletion(
+                            **self._build_completion_params(
+                                messages=messages,
+                                temperature=temperature,
+                                stream=True,
+                                tools=formatted_tools,
+                                **kwargs
+                            )
+                        ):
+                            if chunk and chunk.choices and chunk.choices[0].delta:
+                                delta = chunk.choices[0].delta
+                                if delta.content:
+                                    response_text += delta.content
+
+                                # Capture tool calls from streaming chunks if provider supports it
+                                if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                    for tc in delta.tool_calls:
+                                        if tc.index >= len(tool_calls):
+                                            tool_calls.append({
+                                                "id": tc.id,
+                                                "type": "function",
+                                                "function": {"name": "", "arguments": ""}
+                                            })
+                                        if tc.function.name:
+                                            tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                        if tc.function.arguments:
+                                            tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
Comment on lines +1307 to +1365

Contributor

Similar to the synchronous version, there's significant code duplication here for handling streaming in verbose and non-verbose modes. The logic for processing chunks and capturing tool calls is identical. This can be refactored into a single loop, for example:

tool_calls = []

async def process_chunk(chunk):
    nonlocal response_text
    if chunk and chunk.choices and chunk.choices[0].delta:
        delta = chunk.choices[0].delta
        if delta.content:
            response_text += delta.content
            if verbose:
                print("\033[K", end="\r")
                print(f"Generating... {time.time() - start_time:.1f}s", end="\r")
        if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
            for tc in delta.tool_calls:
                if tc.index >= len(tool_calls):
                    tool_calls.append({
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": "", "arguments": ""}
                    })
                if tc.function.name:
                    tool_calls[tc.index]["function"]["name"] = tc.function.name
                if tc.function.arguments:
                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments

stream_iterator = await litellm.acompletion(
    **self._build_completion_params(
        messages=messages,
        temperature=temperature,
        stream=True,
        tools=formatted_tools,
        **kwargs
    )
)
async for chunk in stream_iterator:
    await process_chunk(chunk)
+                    response_text = response_text.strip()
+
+                    # We already have tool_calls from streaming if supported
+                    # No need for a second API call!
                 else:
-                    # Non-verbose streaming call, still no tools
-                    async for chunk in await litellm.acompletion(
+                    # Non-streaming approach (when tools require it or streaming is disabled)
+                    tool_response = await litellm.acompletion(
                         **self._build_completion_params(
                             messages=messages,
                             temperature=temperature,
-                            stream=stream,
-                            **kwargs
+                            stream=False,
+                            tools=formatted_tools,
+                            **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                         )
-                    ):
-                        if chunk and chunk.choices and chunk.choices[0].delta.content:
-                            response_text += chunk.choices[0].delta.content
-
-                response_text = response_text.strip()
-
-                # ----------------------------------------------------
-                # 2) If tool calls are needed, do a non-streaming call
-                # ----------------------------------------------------
-                if tools and execute_tool_fn:
-                    # Next call with tools if needed
-                    tool_response = await litellm.acompletion(
-                        **self._build_completion_params(
-                            messages=messages,
-                            temperature=temperature,
-                            stream=False,
-                            tools=formatted_tools,  # We safely pass tools here
-                            **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
-                        )
-                    )
-                    # handle tool_calls from tool_response as usual...
-                    tool_calls = tool_response.choices[0].message.get("tool_calls")
+                    )
+                    response_text = tool_response.choices[0].message.get("content", "")
+                    tool_calls = tool_response.choices[0].message.get("tool_calls", [])

                 if verbose:
                     # Display the complete response at once
                     display_interaction(
                         original_prompt,
                         response_text,
                         markdown=markdown,
                         generation_time=time.time() - start_time,
                         console=console
                     )

+                # Now handle tools if we have them (either from streaming or non-streaming)
+                if tools and execute_tool_fn and tool_calls:
+
                     if tool_calls:
                         # Convert tool_calls to a serializable format for all providers
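Seen end to end, the point of the change is that a single streaming pass now yields both the text and the assembled tool calls, so no follow-up request is needed. A self-contained sketch with stand-in chunks in place of the real litellm stream (illustration only; the chunk contents are invented):

# Illustration only: fake chunks stand in for the litellm stream; one pass
# collects both the text and the tool calls.
import asyncio
from types import SimpleNamespace as NS

async def fake_stream():
    yield NS(choices=[NS(delta=NS(content="Checking the weather... ", tool_calls=None))])
    yield NS(choices=[NS(delta=NS(content=None, tool_calls=[
        NS(index=0, id="call_1",
           function=NS(name="get_weather", arguments='{"city": "Paris"}'))
    ]))])

async def main():
    response_text, tool_calls = "", []
    async for chunk in fake_stream():
        delta = chunk.choices[0].delta
        if delta.content:
            response_text += delta.content
        if delta.tool_calls:
            for tc in delta.tool_calls:
                if tc.index >= len(tool_calls):
                    tool_calls.append({"id": tc.id, "type": "function",
                                       "function": {"name": "", "arguments": ""}})
                if tc.function.name:
                    tool_calls[tc.index]["function"]["name"] = tc.function.name
                if tc.function.arguments:
                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
    print(response_text.strip())
    print(tool_calls)  # one pass produced both -> no second API call

asyncio.run(main())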
There's significant code duplication between the verbose and non-verbose streaming logic. The loop that processes chunks and captures tool calls is nearly identical in both branches. This can be refactored into a single loop to reduce duplication and improve maintainability.
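For the synchronous version, that refactor could look roughly like the sketch below. It is an illustration of the reviewer's suggestion, not code from the PR, and it assumes it sits inside get_response() where self, verbose, formatted_tools, messages, temperature, current_time, console, kwargs and response_text are already in scope:

# Sketch only: one chunk-processing helper shared by the verbose and
# non-verbose synchronous branches (mirrors the async suggestion above).
tool_calls = []

def process_chunk(chunk, live=None):
    nonlocal response_text  # assumes this lives inside get_response()
    if chunk and chunk.choices and chunk.choices[0].delta:
        delta = chunk.choices[0].delta
        if delta.content:
            response_text += delta.content
            if live is not None:
                live.update(display_generating(response_text, current_time))
        if formatted_tools and self._supports_streaming_tools() and getattr(delta, 'tool_calls', None):
            for tc in delta.tool_calls:
                if tc.index >= len(tool_calls):
                    tool_calls.append({
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": "", "arguments": ""}
                    })
                if tc.function.name:
                    tool_calls[tc.index]["function"]["name"] = tc.function.name
                if tc.function.arguments:
                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments

completion_kwargs = self._build_completion_params(
    messages=messages,
    tools=formatted_tools,
    temperature=temperature,
    stream=True,
    **kwargs
)
if verbose:
    with Live(display_generating("", current_time), console=console, refresh_per_second=4) as live:
        for chunk in litellm.completion(**completion_kwargs):
            process_chunk(chunk, live)
else:
    for chunk in litellm.completion(**completion_kwargs):
        process_chunk(chunk)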