Skip to content

Commit

Permalink
Move form_builder into internal pkg. (sashabaranov#311)
Browse files Browse the repository at this point in the history
* Move form_uilder into internal pkg.

* Fix import of audio.go

* Reorganize.

* Fix import.

* Fix

---------

Co-authored-by: JoyShi <joy.shi@sap.com>
  • Loading branch information
JiayueShi and jiayueshi-work authored May 16, 2023
1 parent 83d03fc commit 21eef5b
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 90 deletions.
20 changes: 11 additions & 9 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"net/http"
"os"

utils "github.com/sashabaranov/go-openai/internal"
)

// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
Expand Down Expand Up @@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
if err != nil {
return AudioResponse{}, err
}
req.Header.Add("Content-Type", builder.formDataContentType())
req.Header.Add("Content-Type", builder.FormDataContentType())

if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
Expand All @@ -92,55 +94,55 @@ func (r AudioRequest) HasJSONResponse() bool {

// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error {
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.createFormFile("file", f)
err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}

err = b.writeField("model", request.Model)
err = b.WriteField("model", request.Model)
if err != nil {
return fmt.Errorf("writing model name: %w", err)
}

// Create a form field for the prompt (if provided)
if request.Prompt != "" {
err = b.writeField("prompt", request.Prompt)
err = b.WriteField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}

// Create a form field for the format (if provided)
if request.Format != "" {
err = b.writeField("response_format", string(request.Format))
err = b.WriteField("response_format", string(request.Format))
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}

// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
if err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}

// Create a form field for the language (if provided)
if request.Language != "" {
err = b.writeField("language", request.Language)
err = b.WriteField("language", request.Language)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}

// Close the multipart writer
return b.close()
return b.Close()
}
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"io"
"net/http"
"strings"

utils "github.com/sashabaranov/go-openai/internal"
)

// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig

requestBuilder requestBuilder
createFormBuilder func(io.Writer) formBuilder
createFormBuilder func(io.Writer) utils.FormBuilder
}

// NewClient creates new OpenAI API client.
Expand All @@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
createFormBuilder: func(body io.Writer) formBuilder {
return newFormBuilder(body)
createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body)
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
var b bytes.Buffer
builder := c.createFormBuilder(&b)

err = builder.writeField("purpose", request.Purpose)
err = builder.WriteField("purpose", request.Purpose)
if err != nil {
return
}
Expand All @@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

err = builder.createFormFile("file", fileData)
err = builder.CreateFormFile("file", fileData)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())

err = c.sendRequest(req, &file)

Expand Down
3 changes: 2 additions & 1 deletion files_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
. "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) FormBuilder {
return mockBuilder
}

Expand Down
49 changes: 0 additions & 49 deletions form_builder.go

This file was deleted.

28 changes: 14 additions & 14 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
builder := c.createFormBuilder(body)

// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}

// mask, it is optional
if request.Mask != nil {
err = builder.createFormFile("mask", request.Mask)
err = builder.CreateFormFile("mask", request.Mask)
if err != nil {
return
}
}

err = builder.writeField("prompt", request.Prompt)
err = builder.WriteField("prompt", request.Prompt)
if err != nil {
return
}

err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}

err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}
Expand All @@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
builder := c.createFormBuilder(body)

// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}

err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}

err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}
13 changes: 7 additions & 6 deletions image_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -268,19 +269,19 @@ type mockFormBuilder struct {
mockClose func() error
}

func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}

func (fb *mockFormBuilder) writeField(fieldname, value string) error {
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}

func (fb *mockFormBuilder) close() error {
func (fb *mockFormBuilder) Close() error {
return fb.mockClose()
}

func (fb *mockFormBuilder) formDataContentType() string {
func (fb *mockFormBuilder) FormDataContentType() string {
return ""
}

Expand All @@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)

mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
Expand Down Expand Up @@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)

mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
Expand Down
Loading

0 comments on commit 21eef5b

Please sign in to comment.