diff --git a/files.go b/files.go index ec441c3fb..b701b9454 100644 --- a/files.go +++ b/files.go @@ -4,10 +4,7 @@ import ( "bytes" "context" "fmt" - "io" - "mime/multipart" "net/http" - "net/url" "os" ) @@ -33,77 +30,38 @@ type FilesList struct { Files []File `json:"data"` } -// isUrl is a helper function that determines whether the given FilePath -// is a remote URL or a local file path. -func isURL(path string) bool { - _, err := url.ParseRequestURI(path) - if err != nil { - return false - } - - u, err := url.Parse(path) - if err != nil || u.Scheme == "" || u.Host == "" { - return false - } - - return true -} - // CreateFile uploads a jsonl file to GPT3 -// FilePath can be either a local file path or a URL. +// FilePath must be a local file path. func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { var b bytes.Buffer - w := multipart.NewWriter(&b) - - var fw io.Writer + builder := c.createFormBuilder(&b) - err = w.WriteField("purpose", request.Purpose) + err = builder.writeField("purpose", request.Purpose) if err != nil { return } - fw, err = w.CreateFormFile("file", request.FileName) + fileData, err := os.Open(request.FilePath) if err != nil { return } - var fileData io.ReadCloser - if isURL(request.FilePath) { - var remoteFile *http.Response - remoteFile, err = http.Get(request.FilePath) - if err != nil { - return - } - - defer remoteFile.Body.Close() - - // Check server response - if remoteFile.StatusCode != http.StatusOK { - err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) - return - } - - fileData = remoteFile.Body - } else { - fileData, err = os.Open(request.FilePath) - if err != nil { - return - } + err = builder.createFormFile("file", fileData) + if err != nil { + return } - _, err = io.Copy(fw, fileData) + err = builder.close() if err != nil { return } - w.Close() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b) if err != nil { return } - req.Header.Set("Content-Type", w.FormDataContentType()) + req.Header.Set("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &file) diff --git a/files_test.go b/files_test.go index fbfe11c87..45b3900dc 100644 --- a/files_test.go +++ b/files_test.go @@ -1,14 +1,15 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" "fmt" + "io" "net/http" + "os" "strconv" "testing" "time" @@ -34,7 +35,7 @@ func TestFileUpload(t *testing.T) { Purpose: "fine-tune", } _, err = client.CreateFile(ctx, req) - checks.NoError(t, err, "CreateFile erro") + checks.NoError(t, err, "CreateFile error") } // handleCreateFile Handles the images endpoint by the test server. @@ -78,3 +79,50 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { resBytes, _ = json.Marshal(fileReq) fmt.Fprint(w, string(resBytes)) } + +func TestFileUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) formBuilder { + return mockBuilder + } + + ctx := context.Background() + req := FileRequest{ + FileName: "test.go", + FilePath: "client.go", + Purpose: "fine-tune", + } + + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFile(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return mockError + } + _, err = client.CreateFile(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return nil + } + mockBuilder.mockClose = func() error { + return mockError + } + _, err = client.CreateFile(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") +}