Skip to content

Commit

Permalink
Implement optional io.Reader in AudioRequest (sashabaranov#303) (sash…
Browse files Browse the repository at this point in the history
…abaranov#265) (sashabaranov#331)

* Implement optional io.Reader in AudioRequest (sashabaranov#303) (sashabaranov#265)

* Fix err shadowing

* Add test to cover AudioRequest io.Reader usage

* Add additional test cases to cover AudioRequest io.Reader usage

* Add test to cover opening the file specified in an AudioRequest
  • Loading branch information
mdarc authored Jun 5, 2023
1 parent 61ba5f3 commit fa694c6
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 18 deletions.
45 changes: 35 additions & 10 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"

Expand All @@ -27,8 +28,14 @@ const (
// AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct {
Model string
FilePath string
Model string

// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
FilePath string

// Reader is an optional io.Reader when you do not want to use an existing file.
Reader io.Reader

Prompt string // For translation, it should be in English
Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
Expand Down Expand Up @@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.CreateFormFile("file", f)
err := createFileField(request, b)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
return err
}

err = b.WriteField("model", request.Model)
Expand Down Expand Up @@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
// Close the multipart writer
return b.Close()
}

// createFileField creates the "file" form field from either an existing file or by using the reader.
func createFileField(request AudioRequest, b utils.FormBuilder) error {
if request.Reader != nil {
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
if err != nil {
return fmt.Errorf("creating form using reader: %w", err)
}
return nil
}

f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}

return nil
}
66 changes: 63 additions & 3 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand All @@ -11,12 +12,10 @@ import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"testing"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
Expand Down Expand Up @@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})

t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{
FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

Expand Down Expand Up @@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
}
}

func TestCreateFileField(t *testing.T) {
t.Run("createFileField failing file", func(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
}

mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFile: func(string, *os.File) error {
return mockFailedErr
},
}

err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
})

t.Run("createFileField failing reader", func(t *testing.T) {
req := AudioRequest{
FilePath: "test.wav",
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
}

mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFileReader: func(string, io.Reader, string) error {
return mockFailedErr
},
}

err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
})

t.Run("createFileField failing open", func(t *testing.T) {
req := AudioRequest{
FilePath: "non_existing_file.wav",
}

mockBuilder := &mockFormBuilder{}

err := createFileField(req, mockBuilder)
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}
11 changes: 8 additions & 3 deletions image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
}

type mockFormBuilder struct {
mockCreateFormFile func(string, *os.File) error
mockWriteField func(string, string) error
mockClose func() error
mockCreateFormFile func(string, *os.File) error
mockCreateFormFileReader func(string, io.Reader, string) error
mockWriteField func(string, string) error
mockClose func() error
}

func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}

func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.mockCreateFormFileReader(fieldname, r, filename)
}

func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}
Expand Down
20 changes: 18 additions & 2 deletions internal/form_builder.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package openai

import (
"fmt"
"io"
"mime/multipart"
"os"
"path"
)

type FormBuilder interface {
CreateFormFile(fieldname string, file *os.File) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
Expand All @@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
}

func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
return fb.createFormFile(fieldname, file, file.Name())
}

func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}

func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
if filename == "" {
return fmt.Errorf("filename cannot be empty")
}

fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
if err != nil {
return err
}

_, err = io.Copy(fieldWriter, file)
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}

return nil
}

Expand Down

0 comments on commit fa694c6

Please sign in to comment.