#153: Aligning chat & chatStream methods with new ChatParams
roma-glushko committed May 29, 2024
1 parent 56eb570 commit 6e71630
Showing 10 changed files with 94 additions and 124 deletions.
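The thrust of the change is visible in every provider below: per-provider helpers like createChatRequestSchema and NewChatMessagesFromUnifiedRequest are replaced by a single ApplyParams method on each provider's ChatRequest, driven by the unified schemas.ChatParams, so Chat and ChatStream now share one preparation path and differ only in the Stream flag they set on a copied request template. A minimal self-contained sketch of the pattern follows; the stand-in types are inferred from how the diff uses them, and the real schemas.ChatParams likely carries more fields.

package main

import "fmt"

// Stand-ins for pkg/api/schemas types, inferred from this diff;
// the actual definitions may carry more fields (temperature, max tokens, ...).
type ChatMessage struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

type ChatParams struct {
    Messages []ChatMessage `json:"messages"`
}

// A provider-side request following the commit's pattern: a template is
// copied per call, then ApplyParams maps the unified params onto it.
type ChatRequest struct {
    Model    string        `json:"model"`
    Messages []ChatMessage `json:"messages"`
    Stream   bool          `json:"stream,omitempty"`
}

func (r *ChatRequest) ApplyParams(params *ChatParams) {
    r.Messages = params.Messages
}

func main() {
    template := ChatRequest{Model: "some-model"} // built once from config

    params := ChatParams{Messages: []ChatMessage{
        {Role: "system", Content: "You are helpful."},
        {Role: "user", Content: "Hello!"},
    }}

    chatReq := template // copy, so the template stays pristine
    chatReq.ApplyParams(&params)
    chatReq.Stream = false // Chat forces false; the stream path flips it to true

    fmt.Printf("%+v\n", chatReq)
}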
pkg/providers/anthropic/chat.go (2 additions, 0 deletions)

@@ -56,6 +56,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
 	chatReq := *c.chatRequestTemplate
 	chatReq.ApplyParams(params)

+	chatReq.Stream = false
+
 	chatResponse, err := c.doChatRequest(ctx, &chatReq)

 	if err != nil {
pkg/providers/anthropic/chat_stream.go (1 addition, 1 deletion)

@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }

-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
pkg/providers/azureopenai/chat.go (2 additions, 0 deletions)

@@ -43,6 +43,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
 	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
 	chatReq.ApplyParams(params)

+	chatReq.Stream = false
+
 	chatResponse, err := c.doChatRequest(ctx, &chatReq)

 	if err != nil {
pkg/providers/azureopenai/chat_stream.go (9 additions, 23 deletions)

@@ -155,9 +155,10 @@ func (c *Client) SupportChatStream() bool {
 	return true
 }

-func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) {
 	// Create a new chat request
-	httpRequest, err := c.makeStreamReq(ctx, req)
+	httpRequest, err := c.makeStreamReq(ctx, params)
+
 	if err != nil {
 		return nil, err
 	}
@@ -171,28 +172,13 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
 	), nil
 }

-func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
-
-	chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return &chatRequest
-}
-
-func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
-	chatRequest := c.createRequestFromStream(req)
+func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) {
+	chatReq := *c.chatRequestTemplate
+	chatReq.ApplyParams(params)

-	chatRequest.Stream = true
+	chatReq.Stream = true

-	rawPayload, err := json.Marshal(chatRequest)
+	rawPayload, err := json.Marshal(chatReq)
 	if err != nil {
 		return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
 	}
@@ -212,7 +198,7 @@ func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamReque
 	c.tel.L().Debug(
 		"Stream chat request",
 		zap.String("chatURL", c.chatURL),
-		zap.Any("payload", chatRequest),
+		zap.Any("payload", chatReq),
 	)

 	return request, nil
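Throughout these files, chatReq := *c.chatRequestTemplate with the recurring "hoping to get a copy of the template" comment relies on Go's struct copy semantics: dereferencing yields a shallow copy, so value fields are independent but slice and map fields still alias the template's backing storage. That is safe here because ApplyParams replaces Messages wholesale rather than appending in place. A standalone illustration of the caveat (generic types, not glide's own):

package main

import "fmt"

type Request struct {
    Model    string
    Messages []string
}

func main() {
    template := &Request{Model: "base", Messages: []string{"system prompt"}}

    req := *template      // shallow copy of the struct
    req.Model = "special" // safe: value field, template unaffected

    req.Messages[0] = "mutated"            // unsafe: writes through the shared backing array
    req.Messages = []string{"fresh start"} // safe: rebinds req's slice header only

    fmt.Println(template.Model)       // base
    fmt.Println(template.Messages[0]) // mutated
}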
pkg/providers/bedrock/chat.go (29 additions, 29 deletions)

@@ -16,12 +16,21 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 )

-// ChatRequest is an Bedrock-specific request schema
+// ChatRequest is a Bedrock-specific request schema
 type ChatRequest struct {
 	Messages             string               `json:"inputText"`
 	TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"`
 }

+func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
+	// message history not yet supported for AWS models
+	// TODO: do something about lack of message history. Maybe just concatenate all messages?
+	//  in any case, this is not a way to go to ignore message history
+	message := params.Messages[len(params.Messages)-1]
+
+	r.Messages = fmt.Sprintf("Role: %s, Content: %s", message.Role, message.Content)
+}
+
 type TextGenerationConfig struct {
 	Temperature float64 `json:"temperature"`
 	TopP        float64 `json:"topP"`
@@ -41,38 +50,22 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	}
 }

-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) string {
-	// message history not yet supported for AWS models
-	message := fmt.Sprintf("Role: %s, Content: %s", request.Message.Role, request.Message.Content)
-
-	return message
-}
-
 // Chat sends a chat request to the specified bedrock model.
-func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	// TODO: consider using objectpool to optimize memory allocation
+	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
+	chatReq.ApplyParams(params)
+
+	chatResponse, err := c.doChatRequest(ctx, &chatReq)

-	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
 		return nil, err
 	}

 	if len(chatResponse.ModelResponse.Message.Content) == 0 {
 		return nil, ErrEmptyResponse
 	}

 	return chatResponse, nil
 }

-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
-
-	return chatRequest
-}
-
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	rawPayload, err := json.Marshal(payload)
 	if err != nil {
@@ -84,6 +77,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 		ContentType: aws.String("application/json"),
 		Body:        rawPayload,
 	})
+
 	if err != nil {
 		c.telemetry.Logger.Error("Error: Couldn't invoke model. Here's why: %v\n", zap.Error(err))
 		return nil, err
@@ -92,30 +86,36 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	var bedrockCompletion ChatCompletion

 	err = json.Unmarshal(result.Body, &bedrockCompletion)

 	if err != nil {
 		c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err))
+
 		return nil, err
 	}

+	modelResult := bedrockCompletion.Results[0]
+
+	if len(modelResult.OutputText) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
 	response := schemas.ChatResponse{
 		ID:        uuid.NewString(),
 		Created:   int(time.Now().Unix()),
 		Provider:  "aws-bedrock",
 		ModelName: c.config.Model,
 		Cached:    false,
 		ModelResponse: schemas.ModelResponse{
-			Metadata: map[string]string{
-				"system_fingerprint": "none",
-			},
+			Metadata: map[string]string{},
 			Message: schemas.ChatMessage{
 				Role:    "assistant",
-				Content: bedrockCompletion.Results[0].OutputText,
+				Content: modelResult.OutputText,
 			},
 			TokenUsage: schemas.TokenUsage{
-				PromptTokens:   bedrockCompletion.Results[0].TokenCount,
+				// TODO: what would happen if there is a few responses? We need to sum that up
+				PromptTokens:   modelResult.TokenCount,
 				ResponseTokens: -1,
-				TotalTokens:    bedrockCompletion.Results[0].TokenCount,
+				TotalTokens:    modelResult.TokenCount,
 			},
 		},
 	}
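Two TODOs land in this file: ApplyParams keeps only the last message because Titan's inputText is a single string, and TokenUsage reads Results[0] even though Bedrock may return several results. One possible direction for both is sketched below; it is not what the commit ships, it assumes the fmt and strings imports plus the params.Messages shape used above, and Result stands in for whatever element type bedrockCompletion.Results actually has.

// Sketch: flatten the full history into Titan's single inputText field
// instead of dropping everything but the last message.
func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
    var sb strings.Builder

    for _, m := range params.Messages {
        fmt.Fprintf(&sb, "Role: %s, Content: %s\n", m.Role, m.Content)
    }

    r.Messages = sb.String()
}

// Sketch: sum token counts across all returned results rather than
// reading Results[0] alone.
func sumTokenCount(results []Result) int {
    total := 0

    for _, res := range results {
        total += res.TokenCount
    }

    return total
}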
pkg/providers/bedrock/chat_stream.go (1 addition, 1 deletion)

@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }

-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }
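Anthropic, Bedrock, and (below) OctoML all stub ChatStream out with ErrChatStreamNotImplemented while advertising the capability via SupportChatStream(). Presumably callers consult the flag before dialing a stream; a hypothetical caller-side guard might look like the following, where the Provider interface is illustrative rather than glide's actual routing type.

// Hypothetical interface and guard; glide's real routing types may differ.
type Provider interface {
    SupportChatStream() bool
    ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error)
}

func openStream(ctx context.Context, p Provider, params *schemas.ChatParams) (clients.ChatStream, error) {
    if !p.SupportChatStream() {
        // Fail fast instead of letting the stub error surface downstream.
        return nil, clients.ErrChatStreamNotImplemented
    }

    return p.ChatStream(ctx, params)
}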
pkg/providers/octoml/chat.go (30 additions, 40 deletions)

@@ -28,6 +28,11 @@ type ChatRequest struct {
 	PresencePenalty int `json:"presence_penalty,omitempty"`
 }

+func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
+	r.Messages = params.Messages
+	// TODO(185): set other params
+}
+
 // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	return &ChatRequest{
@@ -36,50 +41,29 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 		TopP:             cfg.DefaultParams.TopP,
 		MaxTokens:        cfg.DefaultParams.MaxTokens,
 		StopWords:        cfg.DefaultParams.StopWords,
-		Stream:           false, // unsupported right now
 		FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
 		PresencePenalty:  cfg.DefaultParams.PresencePenalty,
 	}
 }

-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return messages
-}
-
 // Chat sends a chat request to the specified octoml model.
-func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	// TODO: consider using objectpool to optimize memory allocation
+	chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
+	chatReq.ApplyParams(params)
+
+	chatReq.Stream = false
+
+	chatResponse, err := c.doChatRequest(ctx, &chatReq)

-	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
 		return nil, err
 	}

 	if len(chatResponse.ModelResponse.Message.Content) == 0 {
 		return nil, ErrEmptyResponse
 	}

 	return chatResponse, nil
 }

-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
-	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
-
-	return chatRequest
-}
-
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	// Build request payload
 	rawPayload, err := json.Marshal(payload)
@@ -121,33 +105,39 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	}

 	// Parse the response JSON
-	var openAICompletion openai.ChatCompletion // Octo uses the same response schema as OpenAI
+	var completion openai.ChatCompletion // Octo uses the same response schema as OpenAI

-	err = json.Unmarshal(bodyBytes, &openAICompletion)
+	err = json.Unmarshal(bodyBytes, &completion)
 	if err != nil {
 		c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
 		return nil, err
 	}

+	modelChoice := completion.Choices[0]
+
+	if len(modelChoice.Message.Content) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
 	// Map response to UnifiedChatResponse schema
 	response := schemas.ChatResponse{
-		ID:        openAICompletion.ID,
-		Created:   openAICompletion.Created,
+		ID:        completion.ID,
+		Created:   completion.Created,
 		Provider:  providerName,
-		ModelName: openAICompletion.ModelName,
+		ModelName: completion.ModelName,
 		Cached:    false,
 		ModelResponse: schemas.ModelResponse{
 			Metadata: map[string]string{
-				"system_fingerprint": openAICompletion.SystemFingerprint,
+				"system_fingerprint": completion.SystemFingerprint,
 			},
 			Message: schemas.ChatMessage{
-				Role:    openAICompletion.Choices[0].Message.Role,
-				Content: openAICompletion.Choices[0].Message.Content,
+				Role:    modelChoice.Message.Role,
+				Content: modelChoice.Message.Content,
 			},
 			TokenUsage: schemas.TokenUsage{
-				PromptTokens:   openAICompletion.Usage.PromptTokens,
-				ResponseTokens: openAICompletion.Usage.CompletionTokens,
-				TotalTokens:    openAICompletion.Usage.TotalTokens,
+				PromptTokens:   completion.Usage.PromptTokens,
+				ResponseTokens: completion.Usage.CompletionTokens,
+				TotalTokens:    completion.Usage.TotalTokens,
 			},
 		},
 	}
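One sharp edge in the new modelChoice := completion.Choices[0]: if OctoML ever returns an empty choices array, the index expression panics before the empty-content check runs. A defensive variant (a suggestion, not part of this commit) would guard the length first:

if len(completion.Choices) == 0 {
    return nil, ErrEmptyResponse
}

modelChoice := completion.Choices[0]

if len(modelChoice.Message.Content) == 0 {
    return nil, ErrEmptyResponse
}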
pkg/providers/octoml/chat_stream.go (1 addition, 1 deletion)

@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }

-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
 	return nil, clients.ErrChatStreamNotImplemented
 }