Support getting HTTP headers and x-ratelimit-* headers (sashabaranov#507)
* feat: add headers to http response

* feat: support rate limit headers

* fix: go lint

* fix: test coverage

* refactor streamReader

* refactor streamReader

* refactor: NewRateLimitHeaders to newRateLimitHeaders

* refactor: RateLimitHeaders Resets field

* refactor: move RateLimitHeaders struct
liushuangls authored Oct 10, 2023
1 parent 8e165dc commit b77d01e
Showing 5 changed files with 191 additions and 5 deletions.
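For context, a minimal sketch of how a caller might consume the API surface added in this commit. The client setup, environment-variable key, and model choice are assumptions for illustration, not part of the diff:

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Hypothetical setup; any authenticated client works the same way.
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// New in this commit: x-ratelimit-* headers parsed into a struct.
	headers := resp.GetRateLimitHeaders()
	fmt.Println(headers.RemainingRequests)      // e.g. 59
	fmt.Println(headers.ResetRequests.String()) // e.g. "1s"
	fmt.Println(headers.ResetRequests.Time())   // approximate wall-clock reset time
}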
89 changes: 86 additions & 3 deletions chat_stream_test.go
@@ -1,15 +1,17 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"testing"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set(xCustomHeader, xCustomHeaderValue)

// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})

stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

value := stream.Header().Get(xCustomHeader)
if value != xCustomHeaderValue {
t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
}
}

func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}

// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})

stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

headers := stream.GetRateLimitHeaders()
bs1, _ := json.Marshal(headers)
bs2, _ := json.Marshal(rateLimitHeaders)
if string(bs1) != string(bs2) {
t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
}
}

func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
53 changes: 53 additions & 0 deletions chat_test.go
@@ -21,6 +21,17 @@ const (
xCustomHeaderValue = "test"
)

var (
rateLimitHeaders = map[string]any{
"x-ratelimit-limit-requests": 60,
"x-ratelimit-limit-tokens": 150000,
"x-ratelimit-remaining-requests": 59,
"x-ratelimit-remaining-tokens": 149984,
"x-ratelimit-reset-requests": "1s",
"x-ratelimit-reset-tokens": "6m0s",
}
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
}
}

// TestChatCompletionsWithRateLimitHeaders tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")

headers := resp.GetRateLimitHeaders()
resetRequests := headers.ResetRequests.String()
if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
}
resetRequestsTime := headers.ResetRequests.Time()
if resetRequestsTime.Before(time.Now()) {
t.Errorf("unexpected reset requetsts: %v", resetRequestsTime)
}

bs1, _ := json.Marshal(headers)
bs2, _ := json.Marshal(rateLimitHeaders)
if string(bs1) != string(bs2) {
t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
@@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}

9 changes: 7 additions & 2 deletions client.go
Expand Up @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) {
*h = httpHeader(header)
}

func (h httpHeader) Header() http.Header {
return http.Header(h)
func (h *httpHeader) Header() http.Header {
return http.Header(*h)
}

func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
return newRateLimitHeaders(h.Header())
}

// NewClient creates new OpenAI API client.
@@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
httpHeader: httpHeader(resp.Header),
}, nil
}

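The reason both plain responses and the stream reader gain these methods is Go's struct embedding: any type that embeds httpHeader gets Header() and GetRateLimitHeaders() promoted onto it. A standalone sketch of the mechanism — the response type below is a stand-in for illustration, not the library's real struct:

package main

import (
	"fmt"
	"net/http"
)

// httpHeader mirrors the library's unexported alias over http.Header.
type httpHeader http.Header

func (h *httpHeader) SetHeader(header http.Header) { *h = httpHeader(header) }
func (h *httpHeader) Header() http.Header          { return http.Header(*h) }

// response stands in for types like ChatCompletionResponse or streamReader,
// which embed httpHeader in the real library.
type response struct {
	httpHeader
}

func main() {
	var r response
	r.SetHeader(http.Header{"X-Ratelimit-Limit-Requests": {"60"}})
	// Header is promoted from the embedded httpHeader, so it is callable
	// directly on the outer struct.
	fmt.Println(r.Header().Get("x-ratelimit-limit-requests")) // 60
}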
43 changes: 43 additions & 0 deletions ratelimit.go
@@ -0,0 +1,43 @@
package openai

import (
"net/http"
"strconv"
"time"
)

// RateLimitHeaders represents OpenAI rate limit headers.
type RateLimitHeaders struct {
LimitRequests int `json:"x-ratelimit-limit-requests"`
LimitTokens int `json:"x-ratelimit-limit-tokens"`
RemainingRequests int `json:"x-ratelimit-remaining-requests"`
RemainingTokens int `json:"x-ratelimit-remaining-tokens"`
ResetRequests ResetTime `json:"x-ratelimit-reset-requests"`
ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"`
}

type ResetTime string

func (r ResetTime) String() string {
return string(r)
}

func (r ResetTime) Time() time.Time {
d, _ := time.ParseDuration(string(r))
return time.Now().Add(d)
}

func newRateLimitHeaders(h http.Header) RateLimitHeaders {
limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
return RateLimitHeaders{
LimitRequests: limitReq,
LimitTokens: limitTokens,
RemainingRequests: remainingReq,
RemainingTokens: remainingTokens,
ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")),
ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")),
}
}
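Note that Time() swallows the time.ParseDuration error, so a malformed header value yields a zero duration and Time() collapses to roughly time.Now(). A quick standalone illustration of the semantics; the sample values mirror the test fixtures above:

package main

import (
	"fmt"
	"time"
)

type ResetTime string

func (r ResetTime) String() string { return string(r) }

func (r ResetTime) Time() time.Time {
	// Parse errors are ignored, matching the library code above.
	d, _ := time.ParseDuration(string(r))
	return time.Now().Add(d)
}

func main() {
	reset := ResetTime("6m0s")
	fmt.Println(reset.String())                 // 6m0s
	fmt.Println(reset.Time().After(time.Now())) // true: about six minutes out

	bad := ResetTime("not-a-duration")
	fmt.Println(bad.Time()) // effectively time.Now(): the parse error is dropped
}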
2 changes: 2 additions & 0 deletions stream_reader.go
Expand Up @@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
response *http.Response
errAccumulator utils.ErrorAccumulator
unmarshaler utils.Unmarshaler

httpHeader
}

func (stream *streamReader[T]) Recv() (response T, err error) {
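With httpHeader embedded in streamReader, the streaming path exposes the same accessors as plain responses — exactly what the new stream tests above exercise. A hypothetical end-to-end sketch (client setup, key handling, and the model choice are assumptions):

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:  openai.GPT3Dot5Turbo,
		Stream: true,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	// Headers are available as soon as the stream opens, before any Recv call.
	fmt.Println(stream.Header().Get("x-ratelimit-remaining-requests"))
	fmt.Println(stream.GetRateLimitHeaders().RemainingTokens)

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}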