Skip to content

Commit ab6f765

Browse files
committed
counting
Signed-off-by: yxia216 <yxia216@bloomberg.net>
1 parent a7ded26 commit ab6f765

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

internal/extproc/translator/openai_awsbedrock.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type openAIToAWSBedrockTranslatorV1ChatCompletion struct {
4242
// Translator is created for each request/response stream inside external processor, accordingly the role is not reused by multiple streams.
4343
role string
4444
requestModel internalapi.RequestModel
45+
toolIndex *int64
4546
}
4647

4748
// RequestBody implements [OpenAIChatCompletionTranslator.RequestBody].
@@ -598,6 +599,8 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
598599
}
599600
o.bufferedBody = append(o.bufferedBody, buf...)
600601
o.extractAmazonEventStreamEvents()
602+
toolIndex := int64(0)
603+
o.toolIndex = &toolIndex
601604

602605
for i := range o.events {
603606
event := &o.events[i]
@@ -771,6 +774,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
771774
},
772775
})
773776
case event.Delta.ToolUse != nil:
777+
774778
chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{
775779
Index: 0,
776780
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
@@ -781,7 +785,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
781785
Arguments: event.Delta.ToolUse.Input,
782786
},
783787
Type: openai.ChatCompletionMessageToolCallTypeFunction,
784-
Index: 0,
788+
Index: *o.toolIndex,
785789
},
786790
},
787791
},
@@ -819,11 +823,12 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
819823
Name: event.Start.ToolUse.Name,
820824
},
821825
Type: openai.ChatCompletionMessageToolCallTypeFunction,
822-
Index: 0,
826+
Index: *o.toolIndex,
823827
},
824828
},
825829
},
826830
})
831+
*o.toolIndex++
827832
}
828833
case event.StopReason != nil:
829834
chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{

internal/extproc/translator/openai_gcpanthropic_stream.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ type streamingToolCall struct {
4040
type anthropicStreamParser struct {
4141
buffer bytes.Buffer
4242
activeMessageID string
43-
activeToolCalls map[int]*streamingToolCall
43+
activeToolCalls map[int64]*streamingToolCall
44+
toolIndex *int64
4445
tokenUsage LLMTokenUsage
4546
stopReason anthropic.StopReason
4647
requestModel internalapi.RequestModel
@@ -51,7 +52,7 @@ type anthropicStreamParser struct {
5152
func newAnthropicStreamParser(requestModel string) *anthropicStreamParser {
5253
return &anthropicStreamParser{
5354
requestModel: requestModel,
54-
activeToolCalls: make(map[int]*streamingToolCall),
55+
activeToolCalls: make(map[int64]*streamingToolCall),
5556
}
5657
}
5758

@@ -124,15 +125,15 @@ func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span t
124125

125126
// Add active tool calls to the final chunk.
126127
var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall
127-
for _, tool := range p.activeToolCalls {
128+
for toolIndex, tool := range p.activeToolCalls {
128129
toolCalls = append(toolCalls, openai.ChatCompletionChunkChoiceDeltaToolCall{
129130
ID: &tool.id,
130131
Type: openai.ChatCompletionMessageToolCallTypeFunction,
131132
Function: openai.ChatCompletionMessageToolCallFunctionParam{
132133
Name: tool.name,
133134
Arguments: tool.inputJSON,
134135
},
135-
Index: 0,
136+
Index: toolIndex,
136137
})
137138
}
138139

@@ -197,6 +198,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
197198
p.activeMessageID = event.Message.ID
198199
p.tokenUsage.InputTokens = uint32(event.Message.Usage.InputTokens) //nolint:gosec
199200
p.tokenUsage.CachedInputTokens += uint32(event.Message.Usage.CacheReadInputTokens) //nolint:gosec
201+
*p.toolIndex = 0
200202
return nil, nil
201203

202204
case string(constant.ValueOf[constant.ContentBlockStart]()):
@@ -205,7 +207,6 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
205207
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
206208
}
207209
if event.ContentBlock.Type == string(constant.ValueOf[constant.ToolUse]()) || event.ContentBlock.Type == string(constant.ValueOf[constant.ServerToolUse]()) {
208-
toolIdx := int(event.Index)
209210
var argsJSON string
210211
// Check if the input field is provided directly in the start event.
211212
if event.ContentBlock.Input != nil {
@@ -228,15 +229,15 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
228229
}
229230

230231
// Store the complete input JSON in our state.
231-
p.activeToolCalls[toolIdx] = &streamingToolCall{
232+
p.activeToolCalls[*p.toolIndex] = &streamingToolCall{
232233
id: event.ContentBlock.ID,
233234
name: event.ContentBlock.Name,
234235
inputJSON: argsJSON,
235236
}
236237
delta := openai.ChatCompletionResponseChunkChoiceDelta{
237238
ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{
238239
{
239-
Index: 0,
240+
Index: *p.toolIndex,
240241
ID: &event.ContentBlock.ID,
241242
Type: openai.ChatCompletionMessageToolCallTypeFunction,
242243
Function: openai.ChatCompletionMessageToolCallFunctionParam{
@@ -247,6 +248,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
247248
},
248249
},
249250
}
251+
*p.toolIndex++
250252
return p.constructOpenAIChatCompletionChunk(delta, ""), nil
251253
}
252254
if event.ContentBlock.Type == string(constant.ValueOf[constant.Thinking]()) {
@@ -284,14 +286,14 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
284286
delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text}
285287
return p.constructOpenAIChatCompletionChunk(delta, ""), nil
286288
case string(constant.ValueOf[constant.InputJSONDelta]()):
287-
tool, ok := p.activeToolCalls[int(event.Index)]
289+
tool, ok := p.activeToolCalls[*p.toolIndex]
288290
if !ok {
289291
return nil, fmt.Errorf("received input_json_delta for unknown tool at index %d", event.Index)
290292
}
291293
delta := openai.ChatCompletionResponseChunkChoiceDelta{
292294
ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{
293295
{
294-
Index: 0,
296+
Index: *p.toolIndex,
295297
Function: openai.ChatCompletionMessageToolCallFunctionParam{
296298
Arguments: event.Delta.PartialJSON,
297299
},
@@ -308,7 +310,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
308310
if err := json.Unmarshal(data, &event); err != nil {
309311
return nil, fmt.Errorf("unmarshal content_block_stop: %w", err)
310312
}
311-
delete(p.activeToolCalls, int(event.Index))
313+
delete(p.activeToolCalls, *p.toolIndex)
312314
return nil, nil
313315

314316
case string(constant.ValueOf[constant.MessageStop]()):

0 commit comments

Comments
 (0)