@@ -6,7 +6,12 @@
 
 from beartype import beartype
 from fastapi.background import BackgroundTasks
-from litellm.utils import CustomStreamWrapper, ModelResponse, ModelResponseStream
+from litellm.utils import (
+    ChatCompletionMessageToolCall,
+    CustomStreamWrapper,
+    ModelResponse,
+    ModelResponseStream,
+)
 
 from ...app import app
 from ...autogen.openapi_model import (

@@ -295,7 +300,7 @@ async def completion(
         response: ModelResponse | CustomStreamWrapper | None = None
         stream: bool = kwargs.get("stream", False)
         while True:
-            tool_calls = []
+            tool_calls: list[ChatCompletionMessageToolCall | dict] = []
 
             if not stream:
                 response: ModelResponse = await self._completion_func(**kwargs)

@@ -339,13 +344,23 @@ async def completion(
 
             for tool in tool_calls:
                 # call a tool
-                tool_name = tool.function.name
-                tool_args = json.loads(tool.function.arguments)
+                tool_name = (
+                    tool.function.name
+                    if isinstance(tool, ChatCompletionMessageToolCall)
+                    else tool["function"]["name"]
+                )
+                tool_args = json.loads(
+                    tool.function.arguments
+                    if isinstance(tool, ChatCompletionMessageToolCall)
+                    else tool["function"]["arguments"]
+                )
                 tool_response = await self._call_tool(developer_id, tool_name, tool_args)
 
                 # append result to messages from previous step
                 kwargs["messages"].append({
-                    "tool_call_id": tool.id,
+                    "tool_call_id": tool.id
+                    if isinstance(tool, ChatCompletionMessageToolCall)
+                    else tool["id"],
                     "role": "tool",
                     "name": tool_name,
                     "content": tool_response,
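
For reference, the isinstance branching introduced above could be factored into a small helper that normalizes either shape of tool call before dispatching it. This is a minimal sketch, not part of the diff: the names parse_tool_call and ParsedToolCall are illustrative, and it only assumes what the diff already shows, namely that a tool call is either a litellm ChatCompletionMessageToolCall object or a plain dict with the same "id"/"function" structure.

    # Hypothetical helper (not in the diff): normalize a tool call that may be
    # either a litellm ChatCompletionMessageToolCall or a plain dict (e.g. one
    # accumulated from streamed chunks) into (id, name, parsed arguments).
    import json
    from typing import NamedTuple

    from litellm.utils import ChatCompletionMessageToolCall


    class ParsedToolCall(NamedTuple):
        id: str
        name: str
        arguments: dict


    def parse_tool_call(tool: ChatCompletionMessageToolCall | dict) -> ParsedToolCall:
        if isinstance(tool, ChatCompletionMessageToolCall):
            # Object form: attributes, with JSON-encoded arguments string
            return ParsedToolCall(
                id=tool.id,
                name=tool.function.name,
                arguments=json.loads(tool.function.arguments),
            )
        # Dict form: same fields, accessed by key
        return ParsedToolCall(
            id=tool["id"],
            name=tool["function"]["name"],
            arguments=json.loads(tool["function"]["arguments"]),
        )

With such a helper, the loop body would reduce to something like parsed = parse_tool_call(tool) followed by await self._call_tool(developer_id, parsed.name, parsed.arguments), keeping the dual object/dict handling in one place instead of repeating the isinstance checks.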