
Commit

update
luoshengheng committed Dec 1, 2023
1 parent b460238 commit e8020c6
Showing 3 changed files with 16 additions and 30 deletions.
17 changes: 9 additions & 8 deletions chat.go
@@ -1,6 +1,7 @@
package openai

import (
"bufio"
"context"
"errors"
"net/http"
@@ -84,14 +85,14 @@ type ChatCompletionRequest struct {
// LogitBias must be a token id string (specified by their token ID in the tokenizer), not a word string.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Functions []FunctionDefinition `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
AllowFallback bool `json:"allow_fallback,omitempty"`
CaptchaToken string `json:"captchaToken,omitempty"` // extra parameter for easychat
Token string `json:"token,omitempty"` // extra parameter for ylokh
PlainText bool `json:"-"` // marks whether the response content is plain text
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Functions []FunctionDefinition `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
AllowFallback bool `json:"allow_fallback,omitempty"`
CaptchaToken string `json:"captchaToken,omitempty"` // extra parameter for easychat
Token string `json:"token,omitempty"` // extra parameter for ylokh
ContentProcessor func(*bufio.Reader) (content string, err error) `json:"-"` // handler for non-standard response results
}

type FunctionDefinition struct {
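The chat.go change replaces the old PlainText flag with a caller-supplied ContentProcessor hook on ChatCompletionRequest. Below is a minimal usage sketch, not part of this commit: the import path, API token, and model name are placeholders, and the hook simply drains the whole response body into one string.

```go
package main

import (
	"bufio"
	"context"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai" // assumed import path of this fork
)

func main() {
	client := openai.NewClient("your-token") // placeholder token

	req := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo, // placeholder model
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "hello"},
		},
		// ContentProcessor receives the raw response body and returns the text
		// content; this sketch reads everything to EOF in one call.
		ContentProcessor: func(r *bufio.Reader) (string, error) {
			b, err := io.ReadAll(r)
			if err != nil {
				return "", err
			}
			if len(b) == 0 {
				return "", io.EOF // signal end of body on subsequent calls
			}
			return string(b), nil
		},
	}

	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		panic(err)
	}
	defer stream.Close()

	resp, err := stream.Recv()
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Delta.Content)
}
```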
4 changes: 1 addition & 3 deletions chat_stream.go
@@ -71,9 +71,7 @@ func (c *Client) CreateChatCompletionStream(
if err != nil {
return
}
if request.PlainText {
resp.ResponsePlainText = true
}
resp.ContentProcessor = request.ContentProcessor
stream = &ChatCompletionStream{
StreamReader: resp,
}
25 changes: 6 additions & 19 deletions stream_reader.go
@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"unicode/utf8"

utils "github.com/sashabaranov/go-openai/internal"
)
@@ -24,7 +23,7 @@ type Streamable interface {
type StreamReader[T Streamable] struct {
EmptyMessagesLimit uint
IsFinished bool
ResponsePlainText bool
ContentProcessor func(*bufio.Reader) (content string, err error)
Reader *bufio.Reader
Response *http.Response
ErrAccumulator utils.ErrorAccumulator
@@ -47,26 +46,15 @@ func (stream *StreamReader[T]) processLines() (T, error) {
emptyMessagesCount uint
hasErrorPrefix bool
)
if stream.ResponsePlainText {
totalBytes := []byte{}
if stream.ContentProcessor != nil {
for {
respBytes := make([]byte, 1)
_, readErr := stream.Reader.Read(respBytes)
if readErr != nil {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
}
return *new(T), readErr
}
totalBytes = append(totalBytes, respBytes...)
r, _ := utf8.DecodeRune(totalBytes)
if r == utf8.RuneError {
continue
content, err := stream.ContentProcessor(stream.Reader)
if err != nil {
return *new(T), err
}

var response ChatCompletionStreamResponse
response.Choices = []ChatCompletionStreamChoice{{Index: 0, Delta: ChatCompletionStreamChoiceDelta{Content: string(totalBytes)}, FinishReason: "PlainText"}}
response.Choices = []ChatCompletionStreamChoice{{Index: 0, Delta: ChatCompletionStreamChoiceDelta{Content: content}, FinishReason: "ContentProcessor"}}
bytes, _ := json.Marshal(response)

var t T
@@ -75,7 +63,6 @@ func (stream *StreamReader[T]) processLines() (T, error) {
return *new(T), unmarshalErr
}
return t, nil

}
} else {
for {
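The removed PlainText branch in stream_reader.go assembled the response one UTF-8 rune at a time; with the new hook, a caller can recover that behavior itself. A standalone sketch, not part of this commit, of a ContentProcessor built on bufio.Reader.ReadRune:

```go
package main

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"strings"
)

// runeProcessor matches the ContentProcessor signature
// func(*bufio.Reader) (content string, err error) and returns
// one decoded rune per call, never splitting a multi-byte rune.
func runeProcessor(r *bufio.Reader) (string, error) {
	ch, _, err := r.ReadRune()
	if err != nil {
		return "", err
	}
	return string(ch), nil
}

func main() {
	// Stand-in for a response body; a real stream would come from *http.Response.
	r := bufio.NewReader(strings.NewReader("héllo, 世界"))
	for {
		content, err := runeProcessor(r)
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			panic(err)
		}
		fmt.Print(content)
	}
	fmt.Println()
}
```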
