Skip to content

Commit

Permalink
add testable json marshaller (sashabaranov#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
sashabaranov authored Mar 15, 2023
1 parent ba77a64 commit 53d195c
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 18 deletions.
13 changes: 9 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,22 @@ import (
// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig

marshaller marshaller
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
return &Client{config}
return NewClientWithConfig(config)
}

// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config}
return &Client{
config: config,
marshaller: &jsonMarshaller{},
}
}

// NewOrgClient creates new OpenAI API client for specified Organization ID.
Expand All @@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken)
config.OrgID = org
return &Client{config}
return NewClientWithConfig(config)
}

func (c *Client) sendRequest(req *http.Request, v interface{}) error {
Expand Down Expand Up @@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
var reqBody []byte
if body != nil {
var err error
reqBody, err = json.Marshal(body)
reqBody, err = c.marshaller.marshal(body)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)
Expand Down Expand Up @@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
}

var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)
Expand Down Expand Up @@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
}

var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)

Expand Down Expand Up @@ -34,7 +33,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)

Expand Down Expand Up @@ -135,7 +134,7 @@ type EmbeddingRequest struct {
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions fine_tunes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
Expand Down Expand Up @@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {

func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"io"
"mime/multipart"
"net/http"
Expand Down Expand Up @@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down
15 changes: 15 additions & 0 deletions marshaller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package openai

import (
"encoding/json"
)

type marshaller interface {
marshal(value any) ([]byte, error)
}

type jsonMarshaller struct{}

func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}
71 changes: 71 additions & 0 deletions marshaller_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package openai //nolint:testpackage // testing private field

import (
"github.com/sashabaranov/go-openai/internal/test"

"context"
"errors"
"testing"
)

type failingMarshaller struct{}

var errTestMarshallerFailed = errors.New("test marshaller failed")

func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

func TestClientReturnMarshallerErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.marshaller = &failingMarshaller{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}

_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}
3 changes: 1 addition & 2 deletions moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)

Expand Down Expand Up @@ -53,7 +52,7 @@ type ModerationResponse struct {
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
Expand Down

0 comments on commit 53d195c

Please sign in to comment.