diff --git a/image.go b/image.go index b9fde5ac6..107d1bb28 100644 --- a/image.go +++ b/image.go @@ -86,20 +86,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - // mask - mask, err := writer.CreateFormFile("mask", request.Mask.Name()) + // mask, it is optional + if request.Mask != nil { + mask, err2 := writer.CreateFormFile("mask", request.Mask.Name()) + if err2 != nil { + return + } + _, err = io.Copy(mask, request.Mask) + if err != nil { + return + } + } + + err = writer.WriteField("prompt", request.Prompt) if err != nil { return } - _, err = io.Copy(mask, request.Mask) + err = writer.WriteField("n", strconv.Itoa(request.N)) + if err != nil { + return + } + err = writer.WriteField("size", request.Size) + if err != nil { + return + } + writer.Close() + urlSuffix := "/images/edits" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { return } - err = writer.WriteField("prompt", request.Prompt) + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = c.sendRequest(req, &response) + return +} + +// ImageVariRequest represents the request structure for the image API. +type ImageVariRequest struct { + Image *os.File `json:"image,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. +// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ... +func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // image + image, err := writer.CreateFormFile("image", request.Image.Name()) + if err != nil { + return + } + _, err = io.Copy(image, request.Image) if err != nil { return } + err = writer.WriteField("n", strconv.Itoa(request.N)) if err != nil { return @@ -109,7 +154,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } writer.Close() - urlSuffix := "/images/edits" + //https://platform.openai.com/docs/api-reference/images/create-variation + urlSuffix := "/images/variations" req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { return diff --git a/image_test.go b/image_test.go index db558e0b3..b7949c896 100644 --- a/image_test.go +++ b/image_test.go @@ -132,6 +132,43 @@ func TestImageEdit(t *testing.T) { } } +func TestImageEditWithoutMask(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + req := ImageEditRequest{ + Image: origin, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateEditImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + // handleEditImageEndpoint Handles the images endpoint by the test server. func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte @@ -162,3 +199,70 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { resBytes, _ = json.Marshal(responses) fmt.Fprintln(w, string(resBytes)) } + +func TestImageVariation(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + req := ImageVariRequest{ + Image: origin, + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateVariImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleVariateImageEndpoint Handles the images endpoint by the test server. +func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +}