Skip to content

Commit a1da36f

Browse files
committed
openai - catch usage payload
1 parent b4ffd29 commit a1da36f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/strands/models/openai.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
113113

114114
yield {"chunk_type": "message_stop", "data": choice.finish_reason}
115115

116-
event = next(response)
117-
if hasattr(event, "usage"):
118-
yield {"chunk_type": "metadata", "data": event.usage}
116+
# Skip remaining events as we don't have use for anything except the final usage payload
117+
for event in response:
118+
_ = event
119+
120+
yield {"chunk_type": "metadata", "data": event.usage}

tests/strands/models/test_openai.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,14 @@ def test_stream(openai_client, model):
109109

110110
def test_stream_empty(openai_client, model):
111111
mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
112+
mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
112113

113114
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
114115
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
115-
mock_event_3 = unittest.mock.Mock(spec=[])
116+
mock_event_3 = unittest.mock.Mock()
117+
mock_event_4 = unittest.mock.Mock(usage=mock_usage)
116118

117-
openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3])
119+
openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
118120

119121
request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
120122
response = model.stream(request)
@@ -125,6 +127,7 @@ def test_stream_empty(openai_client, model):
125127
{"chunk_type": "content_start", "data_type": "text"},
126128
{"chunk_type": "content_stop", "data_type": "text"},
127129
{"chunk_type": "message_stop", "data": "stop"},
130+
{"chunk_type": "metadata", "data": mock_usage},
128131
]
129132

130133
assert tru_events == exp_events

0 commit comments

Comments
 (0)