|
13 | 13 | import json |
14 | 14 | import logging |
15 | 15 | import mimetypes |
16 | | -from typing import Any, AsyncGenerator, Dict, Optional, Type, TypedDict, TypeVar, Union, cast |
| 16 | +from typing import ( |
| 17 | + Any, |
| 18 | + AsyncGenerator, |
| 19 | + Dict, |
| 20 | + Optional, |
| 21 | + Type, |
| 22 | + TypedDict, |
| 23 | + TypeVar, |
| 24 | + Union, |
| 25 | + cast, |
| 26 | +) |
17 | 27 |
|
18 | 28 | import httpx |
19 | 29 | from pydantic import BaseModel |
@@ -385,7 +395,10 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No |
385 | 395 | ] |
386 | 396 | formatted_tool_messages = [ |
387 | 397 | self._format_tool_message( |
388 | | - {"toolUseId": content["toolResult"]["toolUseId"], "content": content["toolResult"]["content"]} |
| 398 | + { |
| 399 | + "toolUseId": content["toolResult"]["toolUseId"], |
| 400 | + "content": content["toolResult"]["content"], |
| 401 | + } |
389 | 402 | ) |
390 | 403 | for content in contents |
391 | 404 | if "toolResult" in content |
@@ -605,6 +618,7 @@ async def stream( |
605 | 618 |
|
606 | 619 | tool_calls: Dict[int, list] = {} |
607 | 620 | usage_data = None |
| 621 | + finish_reason = None |
608 | 622 |
|
609 | 623 | async for line in response.aiter_lines(): |
610 | 624 | if not line.strip() or not line.startswith("data: "): |
@@ -650,6 +664,7 @@ async def stream( |
650 | 664 |
|
651 | 665 | # Check for finish reason |
652 | 666 | if choice.get("finish_reason"): |
| 667 | + finish_reason = choice.get("finish_reason") |
653 | 668 | break |
654 | 669 |
|
655 | 670 | yield self._format_chunk({"chunk_type": "content_stop"}) |
@@ -702,7 +717,12 @@ async def stream( |
702 | 717 | yield self._format_chunk({"chunk_type": "content_stop"}) |
703 | 718 |
|
704 | 719 | # Send stop reason |
705 | | - stop_reason = "tool_use" if tool_calls else getattr(choice, "finish_reason", "end_turn") |
| 720 | + logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls)) |
| 721 | + if finish_reason == "tool_calls" or tool_calls: |
| 722 | + stop_reason = "tool_calls" # Changed from "tool_use" to match format_chunk expectations |
| 723 | + else: |
| 724 | + stop_reason = finish_reason or "end_turn" |
| 725 | + logger.debug("stop_reason=%s", stop_reason) |
706 | 726 | yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) |
707 | 727 |
|
708 | 728 | # Send usage metadata if available |
|
0 commit comments