
Commit 428c467

#173: Add Streaming Support for Azure OpenAI (#188)
* #173: add streaming
* #173: update header and test data
* #173: Update test and schema
* #173: lint

Co-authored-by: Max <mkrueger190@gmail.com>
1 parent bc1a665 commit 428c467

11 files changed: +518 −54 lines

pkg/api/schemas/chat.go

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ type OverrideChatRequest struct {
 func NewChatFromStr(message string) *ChatRequest {
 	return &ChatRequest{
 		Message: ChatMessage{
-			"human",
+			"user",
 			message,
 			"glide",
 		},

pkg/api/schemas/chat_stream.go

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ type ChatStreamRequest struct {
 func NewChatStreamFromStr(message string) *ChatStreamRequest {
 	return &ChatStreamRequest{
 		Message: ChatMessage{
-			"human",
+			"user",
 			message,
 			"glide",
 		},
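Both NewChatFromStr and NewChatStreamFromStr now label the wrapped message with the role "user" instead of "human", matching the role vocabulary OpenAI-compatible chat APIs accept (system, user, assistant). A minimal sketch of the wire-level effect; the struct below is a simplified stand-in for the unified schema, and the field names are assumed from the positional literals above:

package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-in for the unified ChatMessage; the real type lives in
// pkg/api/schemas, and the field names here are assumed for illustration.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}

func main() {
	// Before this commit the default role was "human", which OpenAI-style
	// providers reject; only system/user/assistant are valid roles.
	msg := ChatMessage{Role: "user", Content: "Hello!", Name: "glide"}

	out, _ := json.Marshal(msg)
	fmt.Println(string(out)) // {"role":"user","content":"Hello!","name":"glide"}
}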

pkg/providers/azureopenai/chat.go

Lines changed: 18 additions & 46 deletions
@@ -14,30 +14,6 @@ import (
 	"go.uber.org/zap"
 )
 
-type ChatMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-// ChatRequest is an Azure openai-specific request schema
-type ChatRequest struct {
-	Messages         []ChatMessage    `json:"messages"`
-	Temperature      float64          `json:"temperature,omitempty"`
-	TopP             float64          `json:"top_p,omitempty"`
-	MaxTokens        int              `json:"max_tokens,omitempty"`
-	N                int              `json:"n,omitempty"`
-	StopWords        []string         `json:"stop,omitempty"`
-	Stream           bool             `json:"stream,omitempty"`
-	FrequencyPenalty int              `json:"frequency_penalty,omitempty"`
-	PresencePenalty  int              `json:"presence_penalty,omitempty"`
-	LogitBias        *map[int]float64 `json:"logit_bias,omitempty"`
-	User             *string          `json:"user,omitempty"`
-	Seed             *int             `json:"seed,omitempty"`
-	Tools            []string         `json:"tools,omitempty"`
-	ToolChoice       interface{}      `json:"tool_choice,omitempty"`
-	ResponseFormat   interface{}      `json:"response_format,omitempty"`
-}
-
 // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	return &ChatRequest{
@@ -46,7 +22,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 		MaxTokens:        cfg.DefaultParams.MaxTokens,
 		N:                cfg.DefaultParams.N,
 		StopWords:        cfg.DefaultParams.StopWords,
-		Stream:           false, // unsupported right now
+		Stream:           false,
 		FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
 		PresencePenalty:  cfg.DefaultParams.PresencePenalty,
 		LogitBias:        cfg.DefaultParams.LogitBias,
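The dropped "unsupported right now" comment reflects that streaming now goes through a dedicated request path (the chat_stream.go diff below) rather than this default. Because the field is tagged `json:"stream,omitempty"` in the struct removed above, the false default never reaches the wire at all; a quick standalone check of that behaviour:

package main

import (
	"encoding/json"
	"fmt"
)

// Minimal check of the `stream,omitempty` tag seen in the request struct:
// a false value is omitted from the payload entirely, and only the
// streaming path serializes "stream":true.
type payload struct {
	Stream bool `json:"stream,omitempty"`
}

func main() {
	off, _ := json.Marshal(payload{Stream: false})
	on, _ := json.Marshal(payload{Stream: true})

	fmt.Println(string(off)) // {}
	fmt.Println(string(on))  // {"stream":true}
}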
@@ -58,23 +34,10 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	}
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return messages
-}
-
 // Chat sends a chat request to the specified azure openai model.
 func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	chatRequest := c.createRequestSchema(request)
 
 	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
@@ -88,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
 	return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+// createRequestSchema creates a new ChatRequest object based on the given request.
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
 	// TODO: consider using objectpool to optimize memory allocation
-	chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
-	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+	chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+	chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+	// Add items from messageHistory first and the new chat message last
+	for _, message := range request.MessageHistory {
+		chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+	}
+
+	chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
 
-	return chatRequest
+	return &chatRequest
 }
 
 func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
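The subtle fix in this hunk is `chatRequest := *c.chatRequestTemplate`: dereferencing copies the template struct by value, so appending per-request Messages no longer writes through the pointer shared by every request. A self-contained sketch of the difference, using a simplified stand-in type rather than the project's actual ChatRequest:

package main

import "fmt"

// Simplified stand-in for the provider's request template.
type Request struct {
	Temperature float64
	Messages    []string
}

func main() {
	template := &Request{Temperature: 0.8}

	// Before the fix: assigning the pointer aliases the shared template,
	// so setting Messages for one request mutates it for everyone.
	aliased := template
	aliased.Messages = []string{"request A"}
	fmt.Println(len(template.Messages)) // 1 — the template was polluted

	template.Messages = nil // reset for the second demonstration

	// After the fix: dereferencing copies the struct by value; scalar
	// defaults carry over, and Messages stays local to this request.
	copied := *template
	copied.Messages = []string{"request B"}
	fmt.Println(copied.Temperature, len(template.Messages)) // 0.8 0 — untouched
}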
@@ -112,7 +84,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	req.Header.Set("Content-Type", "application/json")
 
 	// TODO: this could leak information from messages which may not be a desired thing to have
-	c.telemetry.Logger.Debug(
+	c.tel.Logger.Debug(
 		"azure openai chat request",
 		zap.String("chat_url", c.chatURL),
 		zap.Any("payload", payload),
@@ -132,7 +104,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	// Read the response body into a byte slice
 	bodyBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
-		c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err))
+		c.tel.Logger.Error("failed to read azure openai chat response", zap.Error(err))
 		return nil, err
 	}

@@ -141,7 +113,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 
 	err = json.Unmarshal(bodyBytes, &openAICompletion)
 	if err != nil {
-		c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
+		c.tel.Logger.Error("failed to parse openai chat response", zap.Error(err))
 		return nil, err
 	}

pkg/providers/azureopenai/chat_stream.go

Lines changed: 219 additions & 4 deletions
@@ -1,16 +1,231 @@
 package azureopenai
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
 
-	"glide/pkg/api/schemas"
+	"github.com/r3labs/sse/v2"
 	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+
+	"go.uber.org/zap"
+
+	"glide/pkg/api/schemas"
 )
 
+var (
+	StopReason       = "stop"
+	streamDoneMarker = []byte("[DONE]")
+)
+
+// ChatStream represents chat stream for a specific request
+type ChatStream struct {
+	tel         *telemetry.Telemetry
+	client      *http.Client
+	req         *http.Request
+	reqID       string
+	reqMetadata *schemas.Metadata
+	resp        *http.Response
+	reader      *sse.EventStreamReader
+	errMapper   *ErrorMapper
+}
+
+func NewChatStream(
+	tel *telemetry.Telemetry,
+	client *http.Client,
+	req *http.Request,
+	reqID string,
+	reqMetadata *schemas.Metadata,
+	errMapper *ErrorMapper,
+) *ChatStream {
+	return &ChatStream{
+		tel:         tel,
+		client:      client,
+		req:         req,
+		reqID:       reqID,
+		reqMetadata: reqMetadata,
+		errMapper:   errMapper,
+	}
+}
+
+// Open initializes and opens a ChatStream.
+func (s *ChatStream) Open() error {
+	resp, err := s.client.Do(s.req) //nolint:bodyclose
+	if err != nil {
+		return err
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return s.errMapper.Map(resp)
+	}
+
+	s.resp = resp
+	s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+	return nil
+}
+
+// Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object.
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+	var completionChunk ChatCompletionChunk
+
+	for {
+		rawEvent, err := s.reader.ReadEvent()
+		if err != nil {
+			s.tel.L().Warn(
+				"Chat stream is unexpectedly disconnected",
+				zap.String("provider", providerName),
+				zap.Error(err),
+			)
+
+			// if err is io.EOF, this still means that the stream is interrupted unexpectedly
+			// because the normal stream termination is done via finding out streamDoneMarker
+			return nil, clients.ErrProviderUnavailable
+		}
+
+		s.tel.L().Debug(
+			"Raw chat stream chunk",
+			zap.String("provider", providerName),
+			zap.ByteString("rawChunk", rawEvent),
+		)
+
+		event, err := clients.ParseSSEvent(rawEvent)
+
+		if bytes.Equal(event.Data, streamDoneMarker) {
+			s.tel.L().Info(
+				"EOF: [DONE] marker found in chat stream",
+				zap.String("provider", providerName),
+			)
+
+			return nil, io.EOF
+		}
+
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+		}
+
+		if !event.HasContent() {
+			s.tel.L().Debug(
+				"Received an empty message in chat stream, skipping it",
+				zap.String("provider", providerName),
+				zap.Any("msg", event),
+			)
+
+			continue
+		}
+
+		err = json.Unmarshal(event.Data, &completionChunk)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal AzureOpenAI chat stream chunk: %v", err)
+		}
+
+		responseChunk := completionChunk.Choices[0]
+
+		var finishReason *schemas.FinishReason
+
+		if responseChunk.FinishReason == StopReason {
+			finishReason = &schemas.Complete
+		}
+
+		// TODO: use objectpool here
+		return &schemas.ChatStreamChunk{
+			ID:        s.reqID,
+			Provider:  providerName,
+			Cached:    false,
+			ModelName: completionChunk.ModelName,
+			Metadata:  s.reqMetadata,
+			ModelResponse: schemas.ModelChunkResponse{
+				Metadata: &schemas.Metadata{
+					"response_id":        completionChunk.ID,
+					"system_fingerprint": completionChunk.SystemFingerprint,
+				},
+				Message: schemas.ChatMessage{
+					Role:    responseChunk.Delta.Role,
+					Content: responseChunk.Delta.Content,
+				},
+				FinishReason: finishReason,
+			},
+		}, nil
+	}
+}
+
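Recv reads server-sent events off the response body until the event data equals [DONE], the marker OpenAI-style streaming APIs send to signal normal termination; a transport-level read error, including io.EOF before that marker, is treated as an unexpected disconnect. A standard-library-only sketch of the SSE framing being consumed here (the real code delegates parsing to r3labs/sse and clients.ParseSSEvent, and the sample payloads below are invented for illustration):

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"strings"
)

func main() {
	// Roughly what an Azure OpenAI streaming response body looks like on the wire.
	body := "data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n" +
		"data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"lo\"},\"finish_reason\":\"stop\"}]}\n\n" +
		"data: [DONE]\n\n"

	doneMarker := []byte("[DONE]")
	scanner := bufio.NewScanner(strings.NewReader(body))

	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, []byte("data: ")) {
			continue // skip the blank lines that separate events
		}

		data := bytes.TrimPrefix(line, []byte("data: "))
		if bytes.Equal(data, doneMarker) {
			fmt.Println("stream finished") // Recv() returns io.EOF here
			break
		}

		fmt.Printf("chunk: %s\n", data) // Recv() unmarshals this into ChatCompletionChunk
	}
}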
+func (s *ChatStream) Close() error {
+	if s.resp != nil {
+		return s.resp.Body.Close()
+	}
+
+	return nil
+}
+
 func (c *Client) SupportChatStream() bool {
-	return false
+	return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+	// Create a new chat request
+	httpRequest, err := c.makeStreamReq(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewChatStream(
+		c.tel,
+		c.httpClient,
+		httpRequest,
+		req.ID,
+		req.Metadata,
+		c.errMapper,
+	), nil
 }
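The returned clients.ChatStream follows the Open/Recv/Close lifecycle implemented above, with Recv returning io.EOF on normal completion. A hypothetical consumption loop; the real one lives in Glide's routing layer, which this diff doesn't show:

package azureopenai

import (
	"errors"
	"fmt"
	"io"

	"glide/pkg/providers/clients"
)

// consumeStream is a hypothetical helper illustrating how a caller drives
// a clients.ChatStream; names and error handling are simplified.
func consumeStream(stream clients.ChatStream) error {
	if err := stream.Open(); err != nil {
		return err
	}
	defer stream.Close()

	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil // the provider sent the [DONE] marker
		}

		if err != nil {
			return err // e.g. clients.ErrProviderUnavailable on disconnect
		}

		fmt.Print(chunk.ModelResponse.Message.Content)
	}
}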
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
-	return nil, clients.ErrChatStreamNotImplemented
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+	// TODO: consider using objectpool to optimize memory allocation
+	chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+	chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+	// Add items from messageHistory first and the new chat message last
+	for _, message := range request.MessageHistory {
+		chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+	}
+
+	chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+	return &chatRequest
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+	chatRequest := c.createRequestFromStream(req)
+
+	chatRequest.Stream = true
+
+	rawPayload, err := json.Marshal(chatRequest)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
+	}
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create AzureOpenAI stream chat request: %w", err)
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("api-key", string(c.config.APIKey))
+	request.Header.Set("Cache-Control", "no-cache")
+	request.Header.Set("Accept", "text/event-stream")
+	request.Header.Set("Connection", "keep-alive")
+
+	// TODO: this could leak information from messages which may not be a desired thing to have
+	c.tel.L().Debug(
+		"Stream chat request",
+		zap.String("chatURL", c.chatURL),
+		zap.Any("payload", chatRequest),
+	)
+
+	return request, nil
 }
