package azureopenai

import (
+	"bytes"
	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"

-	"glide/pkg/api/schemas"
+	"github.com/r3labs/sse/v2"
	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+
+	"go.uber.org/zap"
+
+	"glide/pkg/api/schemas"
)

+var (
+	StopReason       = "stop"
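+	// streamDoneMarker is the final SSE payload ("[DONE]") that OpenAI-compatible APIs send to signal the end of a stream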
+	streamDoneMarker = []byte("[DONE]")
+)
+
+// ChatStream represents a chat stream for a specific request.
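+//
+// A rough usage sketch from the caller's side (error handling elided):
+//
+//	stream, _ := client.ChatStream(ctx, req)
+//	_ = stream.Open()
+//	defer stream.Close()
+//
+//	for {
+//		chunk, err := stream.Recv()
+//		if errors.Is(err, io.EOF) {
+//			break // the stream finished normally
+//		}
+//
+//		// process chunk
+//	}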
+type ChatStream struct {
+	tel         *telemetry.Telemetry
+	client      *http.Client
+	req         *http.Request
+	reqID       string
+	reqMetadata *schemas.Metadata
+	resp        *http.Response
+	reader      *sse.EventStreamReader
+	errMapper   *ErrorMapper
+}
+
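+// NewChatStream creates a ChatStream for the given request; the stream stays disconnected until Open is called.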
+func NewChatStream(
+	tel *telemetry.Telemetry,
+	client *http.Client,
+	req *http.Request,
+	reqID string,
+	reqMetadata *schemas.Metadata,
+	errMapper *ErrorMapper,
+) *ChatStream {
+	return &ChatStream{
+		tel:         tel,
+		client:      client,
+		req:         req,
+		reqID:       reqID,
+		reqMetadata: reqMetadata,
+		errMapper:   errMapper,
+	}
+}
+
+// Open initializes and opens a ChatStream.
+func (s *ChatStream) Open() error {
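+	// the response body is deliberately left open here: it backs the SSE reader
+	// created below and is closed later in Close()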
+	resp, err := s.client.Do(s.req) //nolint:bodyclose
+	if err != nil {
+		return err
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		defer resp.Body.Close()
+
+		return s.errMapper.Map(resp)
+	}
+
+	s.resp = resp
+	s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+	return nil
+}
+
+// Recv reads the next chunk from the chat stream and maps it to a unified ChatStreamChunk.
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+	var completionChunk ChatCompletionChunk
+
+	for {
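+		// keep reading events until we hit the [DONE] marker or get a chunk with content;
+		// empty heartbeat messages are skipped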
+		rawEvent, err := s.reader.ReadEvent()
+		if err != nil {
+			s.tel.L().Warn(
+				"Chat stream is unexpectedly disconnected",
+				zap.String("provider", providerName),
+				zap.Error(err),
+			)
+
+			// even if err is io.EOF, the stream was still interrupted unexpectedly:
+			// normal termination is signaled by the streamDoneMarker, not by EOF
+			return nil, clients.ErrProviderUnavailable
+		}
+
+		s.tel.L().Debug(
+			"Raw chat stream chunk",
+			zap.String("provider", providerName),
+			zap.ByteString("rawChunk", rawEvent),
+		)
+
+		event, err := clients.ParseSSEvent(rawEvent)
+		// check the parse error before touching the event payload
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse chat stream message: %w", err)
+		}
+
+		if bytes.Equal(event.Data, streamDoneMarker) {
+			s.tel.L().Info(
+				"EOF: [DONE] marker found in chat stream",
+				zap.String("provider", providerName),
+			)
+
+			return nil, io.EOF
+		}
+
+		if !event.HasContent() {
+			s.tel.L().Debug(
+				"Received an empty message in chat stream, skipping it",
+				zap.String("provider", providerName),
+				zap.Any("msg", event),
+			)
+
+			continue
+		}
+
+		err = json.Unmarshal(event.Data, &completionChunk)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal AzureOpenAI chat stream chunk: %w", err)
+		}
+
+		if len(completionChunk.Choices) == 0 {
+			// guard against chunks with an empty choices array
+			// (e.g. Azure OpenAI sends those for prompt content filter results)
+			continue
+		}
+
+		responseChunk := completionChunk.Choices[0]
+
+		var finishReason *schemas.FinishReason
+
+		if responseChunk.FinishReason == StopReason {
+			finishReason = &schemas.Complete
+		}
+
+		// TODO: use objectpool here
+		return &schemas.ChatStreamChunk{
+			ID:        s.reqID,
+			Provider:  providerName,
+			Cached:    false,
+			ModelName: completionChunk.ModelName,
+			Metadata:  s.reqMetadata,
+			ModelResponse: schemas.ModelChunkResponse{
+				Metadata: &schemas.Metadata{
+					"response_id":        completionChunk.ID,
+					"system_fingerprint": completionChunk.SystemFingerprint,
+				},
+				Message: schemas.ChatMessage{
+					Role:    responseChunk.Delta.Role,
+					Content: responseChunk.Delta.Content,
+				},
+				FinishReason: finishReason,
+			},
+		}, nil
+	}
+}
+
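+// Close releases the HTTP response backing the stream; it is safe to call on a stream that never opened.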
+func (s *ChatStream) Close() error {
+	if s.resp != nil {
+		return s.resp.Body.Close()
+	}
+
+	return nil
+}
+
func (c *Client) SupportChatStream() bool {
-	return false
+	return true
+}
+
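+// ChatStream sets up a new chat stream for the given request; the caller drives it via Open/Recv/Close.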
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+	// create a new streaming chat HTTP request
+	httpRequest, err := c.makeStreamReq(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewChatStream(
+		c.tel,
+		c.httpClient,
+		httpRequest,
+		req.ID,
+		req.Metadata,
+		c.errMapper,
+	), nil
}

-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
-	return nil, clients.ErrChatStreamNotImplemented
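+// createRequestFromStream maps a unified chat stream request onto the provider-specific ChatRequest schema.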
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+	// TODO: consider using objectpool to optimize memory allocation
+	chatRequest := *c.chatRequestTemplate // takes a shallow copy of the request template
+
+	chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+	// add messages from the history first and the new chat message last
+	for _, message := range request.MessageHistory {
+		chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+	}
+
+	chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+	return &chatRequest
+}
+
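+// makeStreamReq builds the streaming chat HTTP request, marshaling the payload and setting the SSE headers.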
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+	chatRequest := c.createRequestFromStream(req)
+
+	chatRequest.Stream = true
+
+	rawPayload, err := json.Marshal(chatRequest)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
+	}
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create AzureOpenAI stream chat request: %w", err)
+	}
+
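+	// Azure OpenAI authenticates via the "api-key" header rather than the usual "Authorization: Bearer" scheme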
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("api-key", string(c.config.APIKey))
+	request.Header.Set("Cache-Control", "no-cache")
+	request.Header.Set("Accept", "text/event-stream")
+	request.Header.Set("Connection", "keep-alive")
+
+	// TODO: this could leak information from messages, which may not be desirable
+	c.tel.L().Debug(
+		"Stream chat request",
+		zap.String("chatURL", c.chatURL),
+		zap.Any("payload", chatRequest),
+	)
+
+	return request, nil
}