Commit
#173: Add Streaming Support for Azure OpenAI (#188)
* #173: add streaming

* #173: update header and test data

* #173: Update test and schema

* #173: lint

---------

Co-authored-by: Max <mkrueger190@gmail.com>
mkrueger12 authored Mar 24, 2024
1 parent bc1a665 commit 428c467
Showing 11 changed files with 518 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pkg/api/schemas/chat.go
@@ -15,7 +15,7 @@ type OverrideChatRequest struct {
 func NewChatFromStr(message string) *ChatRequest {
     return &ChatRequest{
         Message: ChatMessage{
-            "human",
+            "user",
             message,
             "glide",
         },
2 changes: 1 addition & 1 deletion pkg/api/schemas/chat_stream.go
@@ -19,7 +19,7 @@ type ChatStreamRequest struct {
 func NewChatStreamFromStr(message string) *ChatStreamRequest {
     return &ChatStreamRequest{
         Message: ChatMessage{
-            "human",
+            "user",
             message,
             "glide",
         },
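Both schema constructors now label the inbound message with the OpenAI-style "user" role rather than the bespoke "human" value, so requests built from a bare string match the role names OpenAI-compatible backends expect. A minimal sketch of the effect, assuming the positional fields of ChatMessage are role, content, and name as the initializers above suggest:

    package main

    import (
        "fmt"

        "glide/pkg/api/schemas"
    )

    func main() {
        // After this commit the constructor tags the message as "user";
        // before it, the same call produced the role "human".
        req := schemas.NewChatFromStr("Hello!")
        fmt.Println(req.Message) // {user Hello! glide}
    }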
64 changes: 18 additions & 46 deletions pkg/providers/azureopenai/chat.go
@@ -14,30 +14,6 @@ import (
     "go.uber.org/zap"
 )
 
-type ChatMessage struct {
-    Role    string `json:"role"`
-    Content string `json:"content"`
-}
-
-// ChatRequest is an Azure openai-specific request schema
-type ChatRequest struct {
-    Messages         []ChatMessage    `json:"messages"`
-    Temperature      float64          `json:"temperature,omitempty"`
-    TopP             float64          `json:"top_p,omitempty"`
-    MaxTokens        int              `json:"max_tokens,omitempty"`
-    N                int              `json:"n,omitempty"`
-    StopWords        []string         `json:"stop,omitempty"`
-    Stream           bool             `json:"stream,omitempty"`
-    FrequencyPenalty int              `json:"frequency_penalty,omitempty"`
-    PresencePenalty  int              `json:"presence_penalty,omitempty"`
-    LogitBias        *map[int]float64 `json:"logit_bias,omitempty"`
-    User             *string          `json:"user,omitempty"`
-    Seed             *int             `json:"seed,omitempty"`
-    Tools            []string         `json:"tools,omitempty"`
-    ToolChoice       interface{}      `json:"tool_choice,omitempty"`
-    ResponseFormat   interface{}      `json:"response_format,omitempty"`
-}
-
 // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
     return &ChatRequest{
@@ -46,7 +22,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
         MaxTokens:        cfg.DefaultParams.MaxTokens,
         N:                cfg.DefaultParams.N,
         StopWords:        cfg.DefaultParams.StopWords,
-        Stream:           false, // unsupported right now
+        Stream:           false,
         FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
         PresencePenalty:  cfg.DefaultParams.PresencePenalty,
         LogitBias:        cfg.DefaultParams.LogitBias,
@@ -58,23 +34,10 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
     }
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-    messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-    // Add items from messageHistory first and the new chat message last
-    for _, message := range request.MessageHistory {
-        messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-    }
-
-    messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-    return messages
-}
-
 // Chat sends a chat request to the specified azure openai model.
 func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
     // Create a new chat request
-    chatRequest := c.createChatRequestSchema(request)
+    chatRequest := c.createRequestSchema(request)
 
     chatResponse, err := c.doChatRequest(ctx, chatRequest)
     if err != nil {
@@ -88,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
     return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+// createRequestSchema creates a new ChatRequest object based on the given request.
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
     // TODO: consider using objectpool to optimize memory allocation
-    chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-    chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+    chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+    chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+    // Add items from messageHistory first and the new chat message last
+    for _, message := range request.MessageHistory {
+        chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+    }
+
+    chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
 
-    return chatRequest
+    return &chatRequest
 }
 
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
@@ -112,7 +84,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
     req.Header.Set("Content-Type", "application/json")
 
     // TODO: this could leak information from messages which may not be a desired thing to have
-    c.telemetry.Logger.Debug(
+    c.tel.Logger.Debug(
         "azure openai chat request",
         zap.String("chat_url", c.chatURL),
         zap.Any("payload", payload),
@@ -132,7 +104,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
     // Read the response body into a byte slice
     bodyBytes, err := io.ReadAll(resp.Body)
     if err != nil {
-        c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err))
+        c.tel.Logger.Error("failed to read azure openai chat response", zap.Error(err))
         return nil, err
     }
 
@@ -141,7 +113,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 
     err = json.Unmarshal(bodyBytes, &openAICompletion)
     if err != nil {
-        c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
+        c.tel.Logger.Error("failed to parse openai chat response", zap.Error(err))
         return nil, err
     }
 
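The substantive fix in createRequestSchema is the added dereference: chatRequest := *c.chatRequestTemplate copies the template struct by value, whereas the removed chatRequest := c.chatRequestTemplate only copied the pointer, so writing Messages would have mutated the shared template across requests. A self-contained sketch of the distinction, using an illustrative type rather than the provider's actual ChatRequest:

    package main

    import "fmt"

    type chatRequest struct {
        Messages []string
    }

    func main() {
        template := &chatRequest{}

        aliased := template // pointer copy: both names refer to one struct
        aliased.Messages = []string{"hi"}
        fmt.Println(len(template.Messages)) // 1: the shared template was mutated

        template = &chatRequest{}
        copied := *template // value copy: a fresh struct
        copied.Messages = []string{"hi"}
        fmt.Println(len(template.Messages)) // 0: the template is untouched
    }

The copy is shallow, which is safe here because the method always assigns a brand-new Messages slice before appending to it.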
223 changes: 219 additions & 4 deletions pkg/providers/azureopenai/chat_stream.go
@@ -1,16 +1,231 @@
 package azureopenai
 
 import (
+    "bytes"
     "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
 
-    "glide/pkg/api/schemas"
+    "github.com/r3labs/sse/v2"
     "glide/pkg/providers/clients"
+    "glide/pkg/telemetry"
+
+    "go.uber.org/zap"
+
+    "glide/pkg/api/schemas"
 )
 
+var (
+    StopReason       = "stop"
+    streamDoneMarker = []byte("[DONE]")
+)
+
+// ChatStream represents chat stream for a specific request
+type ChatStream struct {
+    tel         *telemetry.Telemetry
+    client      *http.Client
+    req         *http.Request
+    reqID       string
+    reqMetadata *schemas.Metadata
+    resp        *http.Response
+    reader      *sse.EventStreamReader
+    errMapper   *ErrorMapper
+}
+
+func NewChatStream(
+    tel *telemetry.Telemetry,
+    client *http.Client,
+    req *http.Request,
+    reqID string,
+    reqMetadata *schemas.Metadata,
+    errMapper *ErrorMapper,
+) *ChatStream {
+    return &ChatStream{
+        tel:         tel,
+        client:      client,
+        req:         req,
+        reqID:       reqID,
+        reqMetadata: reqMetadata,
+        errMapper:   errMapper,
+    }
+}
+
+// Open initializes and opens a ChatStream.
+func (s *ChatStream) Open() error {
+    resp, err := s.client.Do(s.req) //nolint:bodyclose
+    if err != nil {
+        return err
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        return s.errMapper.Map(resp)
+    }
+
+    s.resp = resp
+    s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+    return nil
+}
+
+// Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object.
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+    var completionChunk ChatCompletionChunk
+
+    for {
+        rawEvent, err := s.reader.ReadEvent()
+        if err != nil {
+            s.tel.L().Warn(
+                "Chat stream is unexpectedly disconnected",
+                zap.String("provider", providerName),
+                zap.Error(err),
+            )
+
+            // if err is io.EOF, this still means that the stream is interrupted unexpectedly
+            // because the normal stream termination is done via finding out streamDoneMarker
+
+            return nil, clients.ErrProviderUnavailable
+        }
+
+        s.tel.L().Debug(
+            "Raw chat stream chunk",
+            zap.String("provider", providerName),
+            zap.ByteString("rawChunk", rawEvent),
+        )
+
+        event, err := clients.ParseSSEvent(rawEvent)
+
+        if bytes.Equal(event.Data, streamDoneMarker) {
+            s.tel.L().Info(
+                "EOF: [DONE] marker found in chat stream",
+                zap.String("provider", providerName),
+            )
+
+            return nil, io.EOF
+        }
+
+        if err != nil {
+            return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+        }
+
+        if !event.HasContent() {
+            s.tel.L().Debug(
+                "Received an empty message in chat stream, skipping it",
+                zap.String("provider", providerName),
+                zap.Any("msg", event),
+            )
+
+            continue
+        }
+
+        err = json.Unmarshal(event.Data, &completionChunk)
+        if err != nil {
+            return nil, fmt.Errorf("failed to unmarshal AzureOpenAI chat stream chunk: %v", err)
+        }
+
+        responseChunk := completionChunk.Choices[0]
+
+        var finishReason *schemas.FinishReason
+
+        if responseChunk.FinishReason == StopReason {
+            finishReason = &schemas.Complete
+        }
+
+        // TODO: use objectpool here
+        return &schemas.ChatStreamChunk{
+            ID:        s.reqID,
+            Provider:  providerName,
+            Cached:    false,
+            ModelName: completionChunk.ModelName,
+            Metadata:  s.reqMetadata,
+            ModelResponse: schemas.ModelChunkResponse{
+                Metadata: &schemas.Metadata{
+                    "response_id":        completionChunk.ID,
+                    "system_fingerprint": completionChunk.SystemFingerprint,
+                },
+                Message: schemas.ChatMessage{
+                    Role:    responseChunk.Delta.Role,
+                    Content: responseChunk.Delta.Content,
+                },
+                FinishReason: finishReason,
+            },
+        }, nil
+    }
+}
+
+func (s *ChatStream) Close() error {
+    if s.resp != nil {
+        return s.resp.Body.Close()
+    }
+
+    return nil
+}
+
 func (c *Client) SupportChatStream() bool {
-    return false
+    return true
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
-    return nil, clients.ErrChatStreamNotImplemented
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+    // Create a new chat request
+    httpRequest, err := c.makeStreamReq(ctx, req)
+    if err != nil {
+        return nil, err
+    }
+
+    return NewChatStream(
+        c.tel,
+        c.httpClient,
+        httpRequest,
+        req.ID,
+        req.Metadata,
+        c.errMapper,
+    ), nil
 }
+
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+    // TODO: consider using objectpool to optimize memory allocation
+    chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+    chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+    // Add items from messageHistory first and the new chat message last
+    for _, message := range request.MessageHistory {
+        chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+    }
+
+    chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+    return &chatRequest
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+    chatRequest := c.createRequestFromStream(req)
+
+    chatRequest.Stream = true
+
+    rawPayload, err := json.Marshal(chatRequest)
+    if err != nil {
+        return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
+    }
+
+    request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+    if err != nil {
+        return nil, fmt.Errorf("unable to create AzureOpenAI stream chat request: %w", err)
+    }
+
+    request.Header.Set("Content-Type", "application/json")
+    request.Header.Set("api-key", string(c.config.APIKey))
+    request.Header.Set("Cache-Control", "no-cache")
+    request.Header.Set("Accept", "text/event-stream")
+    request.Header.Set("Connection", "keep-alive")
+
+    // TODO: this could leak information from messages which may not be a desired thing to have
+    c.tel.L().Debug(
+        "Stream chat request",
+        zap.String("chatURL", c.chatURL),
+        zap.Any("payload", chatRequest),
+    )
+
+    return request, nil
+}
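Taken together, Open, Recv, and Close give callers a pull-based interface over the SSE connection: Open issues the HTTP request and wraps the body in an event-stream reader, Recv blocks until the next non-empty chunk, and the [DONE] marker surfaces as io.EOF. A hedged sketch of a consumer, assuming the clients.ChatStream interface exposes exactly the three methods implemented above (the drainStream helper and its package are hypothetical):

    package example

    import (
        "errors"
        "fmt"
        "io"

        "glide/pkg/providers/clients"
    )

    // drainStream opens the stream, prints each chunk's content, and stops
    // cleanly when the provider sends the [DONE] marker (surfaced as io.EOF).
    func drainStream(stream clients.ChatStream) error {
        if err := stream.Open(); err != nil {
            return err
        }
        defer stream.Close()

        for {
            chunk, err := stream.Recv()

            if errors.Is(err, io.EOF) {
                return nil // normal termination
            }

            if err != nil {
                return err // e.g. clients.ErrProviderUnavailable on unexpected disconnects
            }

            fmt.Print(chunk.ModelResponse.Message.Content)
        }
    }

A caller would obtain the stream from client.ChatStream(ctx, req) and hand it to this helper.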