Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
luoshengheng committed Nov 30, 2023
1 parent ff04986 commit f1869a6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
4 changes: 4 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ func (c *Client) CreateChatCompletion(
}

urlSuffix := chatCompletionsSuffix
if len(extraHeaders) > 0 && extraHeaders[0] != nil && extraHeaders[0]["path"] != "" {
urlSuffix = extraHeaders[0]["path"]

}
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
err = ErrChatCompletionInvalidModel
return
Expand Down
4 changes: 4 additions & 0 deletions chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (c *Client) CreateChatCompletionStream(
extraHeaders ...map[string]string,
) (stream *ChatCompletionStream, err error) {
urlSuffix := chatCompletionsSuffix
if len(extraHeaders) > 0 && extraHeaders[0] != nil && extraHeaders[0]["path"] != "" {
urlSuffix = extraHeaders[0]["path"]

}
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
err = ErrChatCompletionInvalidModel
return
Expand Down
18 changes: 17 additions & 1 deletion stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -44,15 +45,30 @@ func (stream *streamReader[T]) processLines() (T, error) {
var (
emptyMessagesCount uint
hasErrorPrefix bool
messagesText string
)

for {
rawLine, readErr := stream.reader.ReadBytes('\n')
messagesText += string(rawLine)
if readErr != nil || hasErrorPrefix {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
}
if len(messagesText) > 0 {
var response ChatCompletionStreamResponse
response.Choices = []ChatCompletionStreamChoice{{Index: 0, Delta: ChatCompletionStreamChoiceDelta{Content: messagesText}, FinishReason: "Completed"}}
bytes, _ := json.Marshal(response)

var t T
unmarshalErr := stream.unmarshaler.Unmarshal(bytes, &t)
if unmarshalErr != nil {
return *new(T), unmarshalErr
}
messagesText = ""
return t, nil
}
return *new(T), readErr
}

Expand Down Expand Up @@ -87,7 +103,7 @@ func (stream *streamReader[T]) processLines() (T, error) {
if unmarshalErr != nil {
return *new(T), unmarshalErr
}

messagesText = ""
return response, nil
}
}
Expand Down

0 comments on commit f1869a6

Please sign in to comment.