Skip to content

Commit

Permalink
refactor: refactoring http request creation and sending (sashabaranov…
Browse files Browse the repository at this point in the history
…#395)

* refactoring http request creation and sending

* fix lint error

* increase the test coverage of client.go

* refactor: Change the style of HTTPRequestBuilder.Build func to one-argument-per-line.
  • Loading branch information
vvatanabe committed Jun 22, 2023
1 parent 157de06 commit f1b6696
Show file tree
Hide file tree
Showing 20 changed files with 209 additions and 126 deletions.
2 changes: 1 addition & 1 deletion api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
az.OrgID = c.OrgID

cli := NewClientWithConfig(az)
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
req, err := cli.newRequest(context.Background(), "POST", "/chat/completions")
if err != nil {
t.Errorf("Failed to create request: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
if err != nil {
return AudioResponse{}, err
}
req.Header.Add("Content-Type", builder.FormDataContentType())

if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
Expand Down
2 changes: 1 addition & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
}
Expand Down
22 changes: 5 additions & 17 deletions chat_stream.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package openai

import (
"bufio"
"context"

utils "github.com/sashabaranov/go-openai/internal"
"net/http"
)

type ChatCompletionStreamChoiceDelta struct {
Expand Down Expand Up @@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
return nil, err
}

resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req)
if err != nil {
return
}
if isFailureStatusCode(resp) {
return nil, c.handleErrorResp(resp)
}

stream = &ChatCompletionStream{
streamReader: &streamReader[ChatCompletionStreamResponse]{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
streamReader: resp,
}
return
}
94 changes: 72 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client {
return NewClientWithConfig(config)
}

type requestOptions struct {
body any
header http.Header
}

type requestOption func(*requestOptions)

func withBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}

func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
if err != nil {
return nil, err
}
c.setCommonHeaders(req)
return req, nil
}

func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")

Expand All @@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
Expand All @@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
return decodeResponse(res.Body, v)
}

func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
resp, err := c.config.HTTPClient.Do(req)
if err != nil {
return
}

if isFailureStatusCode(resp) {
err = c.handleErrorResp(resp)
return
}
return resp.Body, nil
}

func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return new(streamReader[T]), err
}
if isFailureStatusCode(resp) {
return new(streamReader[T]), client.handleErrorResp(resp)
}
return &streamReader[T]{
emptyMessagesLimit: client.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
}, nil
}

func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
Expand Down Expand Up @@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string {
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body any,
model string) (*http.Request, error) {
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

c.setCommonHeaders(req)
return req, nil
}

func (c *Client) handleErrorResp(resp *http.Response) error {
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
Expand Down
25 changes: 20 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")

type failingRequestBuilder struct{}

func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

Expand All @@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) {
stringInput := ""

testCases := []struct {
name string
value interface{}
body io.Reader
name string
value interface{}
body io.Reader
hasError bool
}{
{
name: "nil input",
Expand All @@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) {
value: &map[string]interface{}{},
body: bytes.NewReader([]byte(`{"test": "test"}`)),
},
{
name: "reader return error",
value: &stringInput,
body: &errorReader{err: errors.New("dummy")},
hasError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := decodeResponse(tc.body, tc.value)
if err != nil {
if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err)
}
})
}
}

type errorReader struct {
err error
}

func (e *errorReader) Read(_ []byte) (n int, err error) {
return 0, e.err
}

func TestHandleErrorResp(t *testing.T) {
// var errRes *ErrorResponse
var errRes ErrorResponse
Expand Down
2 changes: 1 addition & 1 deletion completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type EditsResponse struct {

// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request))
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines"))
if err != nil {
return
}
Expand All @@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}
Expand Down
28 changes: 7 additions & 21 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"),
withBody(&b), withContentType(builder.FormDataContentType()))
if err != nil {
return
}

req.Header.Set("Content-Type", builder.FormDataContentType())

err = c.sendRequest(req, &file)

return
}

// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID))
if err != nil {
return
}
Expand All @@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files"))
if err != nil {
return
}
Expand All @@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}
Expand All @@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err

func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) {
urlSuffix := fmt.Sprintf("/files/%s/content", fileID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}

if isFailureStatusCode(res) {
err = c.handleErrorResp(res)
return
}

content = res.Body
content, err = c.sendRequestRaw(req)
return
}
Loading

0 comments on commit f1b6696

Please sign in to comment.