diff --git a/client_test.go b/client_test.go index 1c9084585..664f9fb92 100644 --- a/client_test.go +++ b/client_test.go @@ -247,6 +247,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateImage", func() (any, error) { return client.CreateImage(ctx, ImageRequest{}) }}, + {"CreateFileBytes", func() (any, error) { + return client.CreateFileBytes(ctx, FileBytesRequest{}) + }}, {"DeleteFile", func() (any, error) { return nil, client.DeleteFile(ctx, "") }}, diff --git a/files.go b/files.go index 9e521fbbe..371d06c69 100644 --- a/files.go +++ b/files.go @@ -15,6 +15,24 @@ type FileRequest struct { Purpose string `json:"purpose"` } +// PurposeType represents the purpose of the file when uploading. +type PurposeType string + +const ( + PurposeFineTune PurposeType = "fine-tune" + PurposeAssistants PurposeType = "assistants" +) + +// FileBytesRequest represents a file upload request. +type FileBytesRequest struct { + // the name of the uploaded file in OpenAI + Name string + // the bytes of the file + Bytes []byte + // the purpose of the file + Purpose PurposeType +} + // File struct represents an OpenAPI file. type File struct { Bytes int `json:"bytes"` @@ -36,6 +54,37 @@ type FilesList struct { httpHeader } +// CreateFileBytes uploads bytes directly to OpenAI without requiring a local file. +func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { + var b bytes.Buffer + reader := bytes.NewReader(request.Bytes) + builder := c.createFormBuilder(&b) + + err = builder.WriteField("purpose", string(request.Purpose)) + if err != nil { + return + } + + err = builder.CreateFormFileReader("file", reader, request.Name) + if err != nil { + return + } + + err = builder.Close() + if err != nil { + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + // CreateFile uploads a jsonl file to GPT3 // FilePath must be a local file path. func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { diff --git a/files_api_test.go b/files_api_test.go index 330b88159..6f62a3fbc 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -16,6 +16,19 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: openai.PurposeFineTune, + } + _, err := client.CreateFileBytes(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/files_test.go b/files_test.go index f588b30dc..3c1b99fb4 100644 --- a/files_test.go +++ b/files_test.go @@ -11,6 +11,53 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return mockBuilder + } + + ctx := context.Background() + req := FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: PurposeAssistants, + } + + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return nil + } + mockBuilder.mockClose = func() error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") +} + func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -55,6 +102,9 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { return mockError } _, err = client.CreateFile(ctx, req) + if err == nil { + t.Fatal("CreateFile should return error if form builder fails") + } checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") }