Skip to content

✨🐛 #183: Fix Anthropic API key header and start counting its token usage #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,15 @@ const docTemplate = `{
"anthropic.Config": {
"type": "object",
"required": [
"apiVersion",
"baseUrl",
"chatEndpoint",
"model"
],
"properties": {
"apiVersion": {
"type": "string"
},
"baseUrl": {
"type": "string"
},
Expand Down Expand Up @@ -910,13 +914,13 @@ const docTemplate = `{
"type": "object",
"properties": {
"promptTokens": {
"type": "number"
"type": "integer"
},
"responseTokens": {
"type": "number"
"type": "integer"
},
"totalTokens": {
"type": "number"
"type": "integer"
}
}
}
Expand Down
10 changes: 7 additions & 3 deletions docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,15 @@
"anthropic.Config": {
"type": "object",
"required": [
"apiVersion",
"baseUrl",
"chatEndpoint",
"model"
],
"properties": {
"apiVersion": {
"type": "string"
},
"baseUrl": {
"type": "string"
},
Expand Down Expand Up @@ -907,13 +911,13 @@
"type": "object",
"properties": {
"promptTokens": {
"type": "number"
"type": "integer"
},
"responseTokens": {
"type": "number"
"type": "integer"
},
"totalTokens": {
"type": "number"
"type": "integer"
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ basePath: /
definitions:
anthropic.Config:
properties:
apiVersion:
type: string
baseUrl:
type: string
chatEndpoint:
Expand All @@ -11,6 +13,7 @@ definitions:
model:
type: string
required:
- apiVersion
- baseUrl
- chatEndpoint
- model
Expand Down Expand Up @@ -488,11 +491,11 @@ definitions:
schemas.TokenUsage:
properties:
promptTokens:
type: number
type: integer
responseTokens:
type: number
type: integer
totalTokens:
type: number
type: integer
type: object
externalDocs:
description: Documentation
Expand Down
6 changes: 3 additions & 3 deletions pkg/api/schemas/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ type ModelResponse struct {
}

type TokenUsage struct {
PromptTokens float64 `json:"promptTokens"`
ResponseTokens float64 `json:"responseTokens"`
TotalTokens float64 `json:"totalTokens"`
PromptTokens int `json:"promptTokens"`
ResponseTokens int `json:"responseTokens"`
TotalTokens int `json:"totalTokens"`
}

// ChatMessage is a message in a chat request.
Expand Down
46 changes: 25 additions & 21 deletions pkg/providers/anthropic/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessa
}

// Chat sends a chat request to the specified anthropic model.
//
// Ref: https://docs.anthropic.com/claude/reference/messages_post
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
chatRequest := c.createChatRequestSchema(request)
Expand All @@ -70,10 +72,6 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return nil, err
}

if len(chatResponse.ModelResponse.Message.Content) == 0 {
return nil, ErrEmptyResponse
}

return chatResponse, nil
}

Expand All @@ -97,12 +95,13 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, fmt.Errorf("unable to create anthropic chat request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
req.Header.Set("x-api-key", string(c.config.APIKey)) // must be in lower case
req.Header.Set("anthropic-version", c.apiVersion)
req.Header.Set("Content-Type", "application/json")

// TODO: this could leak information from messages which may not be a desired thing to have
c.telemetry.Logger.Debug(
"anthropic chat request",
c.tel.L().Debug(
"Anthropic chat request",
zap.String("chat_url", c.chatURL),
zap.Any("payload", payload),
)
Expand All @@ -121,38 +120,43 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err))
c.tel.L().Error("Failed to read anthropic chat response", zap.Error(err))
return nil, err
}

// Parse the response JSON
var anthropicCompletion ChatCompletion
var anthropicResponse ChatCompletion

err = json.Unmarshal(bodyBytes, &anthropicCompletion)
err = json.Unmarshal(bodyBytes, &anthropicResponse)
if err != nil {
c.telemetry.Logger.Error("failed to parse anthropic chat response", zap.Error(err))
c.tel.L().Error("Failed to parse anthropic chat response", zap.Error(err))
return nil, err
}

if len(anthropicResponse.Content) == 0 {
return nil, ErrEmptyResponse
}

completion := anthropicResponse.Content[0]
usage := anthropicResponse.Usage

// Map response to ChatResponse schema
response := schemas.ChatResponse{
ID: anthropicCompletion.ID,
ID: anthropicResponse.ID,
Created: int(time.Now().UTC().Unix()), // not provided by anthropic
Provider: providerName,
ModelName: anthropicCompletion.Model,
ModelName: anthropicResponse.Model,
Cached: false,
ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"system_fingerprint": anthropicCompletion.ID,
},
SystemID: map[string]string{},
Message: schemas.ChatMessage{
Role: anthropicCompletion.Content[0].Type,
Content: anthropicCompletion.Content[0].Text,
Role: completion.Type,
Content: completion.Text,
},
TokenUsage: schemas.TokenUsage{
PromptTokens: 0, // Anthropic doesn't send prompt tokens
ResponseTokens: 0,
TotalTokens: 0,
PromptTokens: usage.InputTokens,
ResponseTokens: usage.OutputTokens,
TotalTokens: usage.InputTokens + usage.OutputTokens,
},
},
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/providers/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ var (
type Client struct {
baseURL string
chatURL string
apiVersion string
chatRequestTemplate *ChatRequest
errMapper *ErrorMapper
config *Config
httpClient *http.Client
telemetry *telemetry.Telemetry
tel *telemetry.Telemetry
}

// NewClient creates a new OpenAI client for the OpenAI API.
Expand All @@ -39,6 +40,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
c := &Client{
baseURL: providerConfig.BaseURL,
chatURL: chatURL,
apiVersion: providerConfig.APIVersion,
config: providerConfig,
chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
errMapper: NewErrorMapper(tel),
Expand All @@ -50,7 +52,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
telemetry: tel,
tel: tel,
}

return c, nil
Expand Down
3 changes: 2 additions & 1 deletion pkg/providers/anthropic/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ type Params struct {
MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"`
StopSequences []string `yaml:"stop,omitempty" json:"stop"`
Metadata *string `yaml:"metadata,omitempty" json:"metadata"`
// Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment
}

func DefaultParams() Params {
Expand All @@ -38,6 +37,7 @@ func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error {

type Config struct {
BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"`
APIVersion string `yaml:"apiVersion" json:"apiVersion" validate:"required"`
ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"`
Model string `yaml:"model" json:"model" validate:"required"`
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
Expand All @@ -50,6 +50,7 @@ func DefaultConfig() *Config {

return &Config{
BaseURL: "https://api.anthropic.com/v1",
APIVersion: "2023-06-01",
ChatEndpoint: "/messages",
Model: "claude-instant-1.2",
DefaultParams: &defaultParams,
Expand Down
18 changes: 12 additions & 6 deletions pkg/providers/anthropic/schamas.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
package anthropic

// Anthropic Chat Response
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}

type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

// ChatCompletion is an Anthropic Chat Response
type ChatCompletion struct {
ID string `json:"id"`
Type string `json:"type"`
Expand All @@ -9,9 +19,5 @@ type ChatCompletion struct {
Content []Content `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
}

type Content struct {
Type string `json:"type"`
Text string `json:"text"`
Usage Usage `json:"usage"`
}
10 changes: 7 additions & 3 deletions pkg/providers/anthropic/testdata/chat.success.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"type": "message",
"model": "claude-2.1",
"model": "claude-instant-1.2",
"role": "assistant",
"content": [
{
Expand All @@ -10,5 +10,9 @@
}
],
"stop_reason": "end_turn",
"stop_sequence": null
}
"stop_sequence": null,
"usage":{
"input_tokens": 24,
"output_tokens": 13
}
}
5 changes: 3 additions & 2 deletions pkg/providers/bedrock/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
err = json.Unmarshal(result.Body, &bedrockCompletion)
if err != nil {
c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err))

return nil, err
}

Expand All @@ -118,9 +119,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
Name: "",
},
TokenUsage: schemas.TokenUsage{
PromptTokens: float64(bedrockCompletion.Results[0].TokenCount),
PromptTokens: bedrockCompletion.Results[0].TokenCount,
ResponseTokens: -1,
TotalTokens: float64(bedrockCompletion.Results[0].TokenCount),
TotalTokens: bedrockCompletion.Results[0].TokenCount,
},
},
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/providers/cohere/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ type ChatCompletion struct {
}

type TokenCount struct {
PromptTokens float64 `json:"prompt_tokens"`
ResponseTokens float64 `json:"response_tokens"`
TotalTokens float64 `json:"total_tokens"`
BilledTokens float64 `json:"billed_tokens"`
PromptTokens int `json:"prompt_tokens"`
ResponseTokens int `json:"response_tokens"`
TotalTokens int `json:"total_tokens"`
BilledTokens int `json:"billed_tokens"`
}

type Meta struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)

if err == nil {
// record latency per token to normalize measurements
m.chatLatency.Add(float64(time.Since(startedAt)) / resp.ModelResponse.TokenUsage.ResponseTokens)
m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))

// successful response
resp.ModelID = m.modelID
Expand Down
6 changes: 3 additions & 3 deletions pkg/providers/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
Content: ollamaCompletion.Message.Content,
},
TokenUsage: schemas.TokenUsage{
PromptTokens: float64(ollamaCompletion.EvalCount),
ResponseTokens: float64(ollamaCompletion.EvalCount),
TotalTokens: float64(ollamaCompletion.EvalCount),
PromptTokens: ollamaCompletion.EvalCount,
ResponseTokens: ollamaCompletion.EvalCount,
TotalTokens: ollamaCompletion.EvalCount,
},
},
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/providers/openai/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ type Choice struct {
}

type Usage struct {
PromptTokens float64 `json:"prompt_tokens"`
CompletionTokens float64 `json:"completion_tokens"`
TotalTokens float64 `json:"total_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming
Expand Down