From 6e716308594768821e0ea3aab29338ee2d818970 Mon Sep 17 00:00:00 2001
From: Roman Glushko
Date: Wed, 29 May 2024 23:01:10 +0300
Subject: [PATCH] #153: Aligning chat & chatStream methods with new ChatParams

---
 pkg/providers/anthropic/chat.go          |  2 +
 pkg/providers/anthropic/chat_stream.go   |  2 +-
 pkg/providers/azureopenai/chat.go        |  2 +
 pkg/providers/azureopenai/chat_stream.go | 32 +++--------
 pkg/providers/bedrock/chat.go            | 58 ++++++++++----------
 pkg/providers/bedrock/chat_stream.go     |  2 +-
 pkg/providers/octoml/chat.go             | 70 ++++++++++--------------
 pkg/providers/octoml/chat_stream.go      |  2 +-
 pkg/providers/ollama/chat.go             | 46 ++++++----------
 pkg/providers/ollama/chat_stream.go      |  2 +-
 10 files changed, 94 insertions(+), 124 deletions(-)

diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go
index f58014fa..eb52aafe 100644
--- a/pkg/providers/anthropic/chat.go
+++ b/pkg/providers/anthropic/chat.go
@@ -56,6 +56,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
 	chatReq := *c.chatRequestTemplate
 	chatReq.ApplyParams(params)
 
+	chatReq.Stream = false
+
 	chatResponse, err := c.doChatRequest(ctx, &chatReq)
 
 	if err != nil {
diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go
index 6a6b6c01..5a6f2112 100644
--- a/pkg/providers/anthropic/chat_stream.go
+++ b/pkg/providers/anthropic/chat_stream.go
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go
index 08a46ffe..5350a295 100644
--- a/pkg/providers/azureopenai/chat.go
+++ b/pkg/providers/azureopenai/chat.go
@@ -43,6 +43,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
 	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
 	chatReq.ApplyParams(params)
 
+	chatReq.Stream = false
+
 	chatResponse, err := c.doChatRequest(ctx, &chatReq)
 
 	if err != nil {
diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go
index c6bbbd56..24f2b193 100644
--- a/pkg/providers/azureopenai/chat_stream.go
+++ b/pkg/providers/azureopenai/chat_stream.go
@@ -155,9 +155,10 @@ func (c *Client) SupportChatStream() bool {
 	return true
 }
 
-func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) {
 	// Create a new chat request
-	httpRequest, err := c.makeStreamReq(ctx, req)
+	httpRequest, err := c.makeStreamReq(ctx, params)
+
 	if err != nil {
 		return nil, err
 	}
@@ -171,28 +172,13 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
 	), nil
 }
 
-func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
-
-	chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return &chatRequest
-}
-
-func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
-	chatRequest := c.createRequestFromStream(req)
+func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) {
+	chatReq := *c.chatRequestTemplate
+	chatReq.ApplyParams(params)
 
-	chatRequest.Stream = true
+	chatReq.Stream = true
 
-	rawPayload, err := json.Marshal(chatRequest)
+	rawPayload, err := json.Marshal(chatReq)
 	if err != nil {
 		return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
 	}
@@ -212,7 +198,7 @@ func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamReque
 	c.tel.L().Debug(
 		"Stream chat request",
 		zap.String("chatURL", c.chatURL),
-		zap.Any("payload", chatRequest),
+		zap.Any("payload", chatReq),
 	)
 
 	return request, nil
diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go
index 00498fcc..b391cd07 100644
--- a/pkg/providers/bedrock/chat.go
+++ b/pkg/providers/bedrock/chat.go
@@ -16,12 +16,21 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 )
 
-// ChatRequest is an Bedrock-specific request schema
+// ChatRequest is a Bedrock-specific request schema
type ChatRequest struct {
 	Messages             string               `json:"inputText"`
 	TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"`
 }
 
+func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
+	// message history is not yet supported for AWS models
+	// TODO: do something about the lack of message history. Maybe just concatenate all messages?
+	//  In any case, silently dropping the message history is not a viable approach.
+	message := params.Messages[len(params.Messages)-1]
+
+	r.Messages = fmt.Sprintf("Role: %s, Content: %s", message.Role, message.Content)
+}
+
 type TextGenerationConfig struct {
 	Temperature float64 `json:"temperature"`
 	TopP        float64 `json:"topP"`
@@ -41,38 +50,22 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	}
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) string {
-	// message history not yet supported for AWS models
-	message := fmt.Sprintf("Role: %s, Content: %s", request.Message.Role, request.Message.Content)
-
-	return message
-}
-
 // Chat sends a chat request to the specified bedrock model.
-func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	// TODO: consider using objectpool to optimize memory allocation
+	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
+	chatReq.ApplyParams(params)
+
+	chatResponse, err := c.doChatRequest(ctx, &chatReq)
 
-	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
 		return nil, err
 	}
 
-	if len(chatResponse.ModelResponse.Message.Content) == 0 {
-		return nil, ErrEmptyResponse
-	}
-
 	return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
-
-	return chatRequest
-}
-
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	rawPayload, err := json.Marshal(payload)
 	if err != nil {
@@ -84,6 +77,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 		ContentType: aws.String("application/json"),
 		Body:        rawPayload,
 	})
+
 	if err != nil {
 		c.telemetry.Logger.Error("Error: Couldn't invoke model. Here's why: %v\n", zap.Error(err))
 		return nil, err
@@ -92,12 +86,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 
 	var bedrockCompletion ChatCompletion
 	err = json.Unmarshal(result.Body, &bedrockCompletion)
+
 	if err != nil {
 		c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err))
 		return nil, err
 	}
 
+	modelResult := bedrockCompletion.Results[0]
+
+	if len(modelResult.OutputText) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
 	response := schemas.ChatResponse{
 		ID:      uuid.NewString(),
 		Created: int(time.Now().Unix()),
@@ -105,17 +106,16 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 		ModelName: c.config.Model,
 		Cached:    false,
 		ModelResponse: schemas.ModelResponse{
-			Metadata: map[string]string{
-				"system_fingerprint": "none",
-			},
+			Metadata: map[string]string{},
 			Message: schemas.ChatMessage{
 				Role:    "assistant",
-				Content: bedrockCompletion.Results[0].OutputText,
+				Content: modelResult.OutputText,
 			},
 			TokenUsage: schemas.TokenUsage{
-				PromptTokens:   bedrockCompletion.Results[0].TokenCount,
+				// TODO: what happens if there are multiple results? We need to sum them up
+				PromptTokens:   modelResult.TokenCount,
 				ResponseTokens: -1,
-				TotalTokens:    bedrockCompletion.Results[0].TokenCount,
+				TotalTokens:    modelResult.TokenCount,
 			},
 		},
 	}
diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go
index bb922860..bb07da7d 100644
--- a/pkg/providers/bedrock/chat_stream.go
+++ b/pkg/providers/bedrock/chat_stream.go
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go
index 09ee61bc..8d9b97e3 100644
--- a/pkg/providers/octoml/chat.go
+++ b/pkg/providers/octoml/chat.go
@@ -28,6 +28,11 @@ type ChatRequest struct {
 	PresencePenalty int           `json:"presence_penalty,omitempty"`
 }
 
+func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
+	r.Messages = params.Messages
+	// TODO(185): set other params
+}
+
 // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	return &ChatRequest{
@@ -36,50 +41,29 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 		TopP:             cfg.DefaultParams.TopP,
 		MaxTokens:        cfg.DefaultParams.MaxTokens,
 		StopWords:        cfg.DefaultParams.StopWords,
-		Stream:           false, // unsupported right now
 		FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
 		PresencePenalty:  cfg.DefaultParams.PresencePenalty,
 	}
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return messages
-}
-
 // Chat sends a chat request to the specified octoml model.
-func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	// TODO: consider using objectpool to optimize memory allocation
+	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
+	chatReq.ApplyParams(params)
+
+	chatReq.Stream = false
+
+	chatResponse, err := c.doChatRequest(ctx, &chatReq)
 
-	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
 		return nil, err
 	}
 
-	if len(chatResponse.ModelResponse.Message.Content) == 0 {
-		return nil, ErrEmptyResponse
-	}
-
 	return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
-
-	return chatRequest
-}
-
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	// Build request payload
 	rawPayload, err := json.Marshal(payload)
@@ -121,33 +105,39 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	}
 
 	// Parse the response JSON
-	var openAICompletion openai.ChatCompletion // Octo uses the same response schema as OpenAI
+	var completion openai.ChatCompletion // Octo uses the same response schema as OpenAI
 
-	err = json.Unmarshal(bodyBytes, &openAICompletion)
+	err = json.Unmarshal(bodyBytes, &completion)
 	if err != nil {
 		c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
 		return nil, err
 	}
 
+	modelChoice := completion.Choices[0]
+
+	if len(modelChoice.Message.Content) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
 	// Map response to UnifiedChatResponse schema
 	response := schemas.ChatResponse{
-		ID:        openAICompletion.ID,
-		Created:   openAICompletion.Created,
+		ID:        completion.ID,
+		Created:   completion.Created,
 		Provider:  providerName,
-		ModelName: openAICompletion.ModelName,
+		ModelName: completion.ModelName,
 		Cached:    false,
 		ModelResponse: schemas.ModelResponse{
 			Metadata: map[string]string{
-				"system_fingerprint": openAICompletion.SystemFingerprint,
+				"system_fingerprint": completion.SystemFingerprint,
 			},
 			Message: schemas.ChatMessage{
-				Role:    openAICompletion.Choices[0].Message.Role,
-				Content: openAICompletion.Choices[0].Message.Content,
+				Role:    modelChoice.Message.Role,
+				Content: modelChoice.Message.Content,
 			},
 			TokenUsage: schemas.TokenUsage{
-				PromptTokens:   openAICompletion.Usage.PromptTokens,
-				ResponseTokens: openAICompletion.Usage.CompletionTokens,
-				TotalTokens:    openAICompletion.Usage.TotalTokens,
+				PromptTokens:   completion.Usage.PromptTokens,
+				ResponseTokens: completion.Usage.CompletionTokens,
+				TotalTokens:    completion.Usage.TotalTokens,
 			},
 		},
 	}
diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go
index 8580263a..d0e33420 100644
--- a/pkg/providers/octoml/chat_stream.go
+++ b/pkg/providers/octoml/chat_stream.go
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go
index 107407d0..c0c74315 100644
--- a/pkg/providers/ollama/chat.go
+++ b/pkg/providers/ollama/chat.go
@@ -40,6 +40,11 @@ type ChatRequest struct {
 	Stream bool `json:"stream"`
 }
 
+func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
+	r.Messages = params.Messages
+	// TODO(185): set other params
+}
+
 // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	return &ChatRequest{
@@ -62,44 +67,24 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	}
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return messages
-}
-
 // Chat sends a chat request to the specified ollama model.
-func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	// TODO: consider using objectpool to optimize memory allocation
+	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
+	chatReq.ApplyParams(params)
+
+	chatReq.Stream = false
+
+	chatResponse, err := c.doChatRequest(ctx, &chatReq)
 
-	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
 		return nil, fmt.Errorf("chat request failed: %w", err)
 	}
 
-	if len(chatResponse.ModelResponse.Message.Content) == 0 {
-		return nil, ErrEmptyResponse
-	}
-
 	return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
-
-	return chatRequest
-}
-
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	// Build request payload
 	rawPayload, err := json.Marshal(payload)
@@ -147,6 +132,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 
 			// Parse the value to get the duration
 			cooldownDelay, err := time.ParseDuration(retryAfter)
+
 			if err != nil {
 				return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
 			}
@@ -174,6 +160,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 		return nil, err
 	}
 
+	if len(ollamaCompletion.Message.Content) == 0 {
+		return nil, clients.ErrEmptyResponse
+	}
+
 	// Map response to UnifiedChatResponse schema
 	response := schemas.ChatResponse{
 		ID: uuid.NewString(),
diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go
index 7da5b292..31075ca1 100644
--- a/pkg/providers/ollama/chat_stream.go
+++ b/pkg/providers/ollama/chat_stream.go
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
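
Note for reviewers: every provider now follows the same request-building pattern: keep a template request built once from config defaults, copy it by value per call, apply the unified ChatParams via ApplyParams, and then pin Stream for the method (false in Chat, true in makeStreamReq). The following is a minimal standalone sketch of that pattern, not code from the repository: the ChatParams, ChatMessage, and ApplyParams names come from the diff above, while the trimmed-down struct shapes and the main wiring are assumptions made so the example compiles on its own.

package main

import "fmt"

// Trimmed-down stand-ins for the schemas types referenced in the patch
// (assumed shapes; the real definitions live in the glide schemas package).
type ChatMessage struct {
	Role    string
	Content string
}

type ChatParams struct {
	Messages []ChatMessage
}

// ChatRequest mirrors an OpenAI-style provider request
// (compare the octoml/ollama ChatRequest in the diff).
type ChatRequest struct {
	Messages []ChatMessage
	Stream   bool
}

// ApplyParams copies the unified params onto the provider request,
// matching the ApplyParams methods added by this patch.
func (r *ChatRequest) ApplyParams(params *ChatParams) {
	r.Messages = params.Messages
	// TODO(185): set other params
}

func main() {
	// The template carries provider defaults built once from config.
	template := &ChatRequest{}

	params := &ChatParams{
		Messages: []ChatMessage{{Role: "user", Content: "Hello!"}},
	}

	chatReq := *template // value copy, so the shared template stays untouched
	chatReq.ApplyParams(params)
	chatReq.Stream = false // Chat pins this to false; the stream path sets true

	fmt.Printf("%+v\n", chatReq)
}

One caveat worth noting in review: the copy is shallow. It is safe here only because ApplyParams replaces the Messages slice header outright rather than appending to it, so concurrent calls never mutate the template's backing array.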