Skip to content

Commit

Permalink
add ImageVariRequest/ImageEditRequest.ResponseFormat (sashabaranov#264)
Browse files Browse the repository at this point in the history
* add ImageEditRequest.ResponseFormat

* add ImageEditRequest/ImageVariRequest.ResponseFormat

* complete image_test

* delete var prompt param

---------

Co-authored-by: Aceld <liudanbing@tal.com>
  • Loading branch information
aceld and Aceld authored Apr 18, 2023
1 parent 061c97e commit 3b10c03
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
33 changes: 25 additions & 8 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons

// ImageEditRequest represents the request structure for the image API.
type ImageEditRequest struct {
Image *os.File `json:"image,omitempty"`
Mask *os.File `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Image *os.File `json:"image,omitempty"`
Mask *os.File `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}

// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
Expand All @@ -85,14 +86,22 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
if err != nil {
return
}

err = builder.writeField("n", strconv.Itoa(request.N))
if err != nil {
return
}

err = builder.writeField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
if err != nil {
return
Expand All @@ -111,9 +120,10 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)

// ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct {
Image *os.File `json:"image,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Image *os.File `json:"image,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}

// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
Expand All @@ -132,10 +142,17 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
if err != nil {
return
}

err = builder.writeField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
if err != nil {
return
Expand Down
35 changes: 23 additions & 12 deletions image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ func TestImageEdit(t *testing.T) {
}()

req := ImageEditRequest{
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateEditImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
Expand Down Expand Up @@ -154,10 +155,11 @@ func TestImageEditWithoutMask(t *testing.T) {
}()

req := ImageEditRequest{
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateEditImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
Expand Down Expand Up @@ -220,9 +222,10 @@ func TestImageVariation(t *testing.T) {
}()

req := ImageVariRequest{
Image: origin,
N: 3,
Size: CreateImageSize1024x1024,
Image: origin,
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateVariImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
Expand Down Expand Up @@ -336,6 +339,10 @@ func TestImageFormBuilderFailures(t *testing.T) {
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")

failForField = "response_format"
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")

failForField = ""
mockBuilder.mockClose = func() error {
return mockFailedErr
Expand Down Expand Up @@ -384,6 +391,10 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")

failForField = "response_format"
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")

failForField = ""
mockBuilder.mockClose = func() error {
return mockFailedErr
Expand Down

0 comments on commit 3b10c03

Please sign in to comment.