Skip to content

Commit

Permalink
feat: add /images support
Browse files Browse the repository at this point in the history
  • Loading branch information
0x9ef committed Dec 30, 2022
1 parent c204a59 commit 2b39904
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 21 deletions.
8 changes: 6 additions & 2 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ func (e *Engine) Completion(ctx context.Context, opts *CompletionOptions) (*Comp
if err := e.validate.StructCtx(ctx, opts); err != nil {
return nil, err
}
url := e.apiBaseURL + "/completions"
req, err := e.newReq(ctx, http.MethodPost, url, opts)
uri := e.apiBaseURL + "/completions"
r, err := marshalJson(opts)
if err != nil {
return nil, err
}
req, err := e.newReq(ctx, http.MethodPost, uri, "json", r)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestCompletion(t *testing.T) {
e := New(os.Getenv("OPENAPI_KEY"))
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.Completion(context.Background(), &CompletionOptions{
Model: ModelTextDavinci001,
Prompt: []string{"Write a little bit of Wikipedia. What is that?"},
Expand Down
6 changes: 5 additions & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ func (e *Engine) Edit(ctx context.Context, opts *EditOptions) (*EditResponse, er
return nil, err
}
url := e.apiBaseURL + "/edits"
req, err := e.newReq(ctx, http.MethodPost, url, opts)
r, err := marshalJson(opts)
if err != nil {
return nil, err
}
req, err := e.newReq(ctx, http.MethodPost, url, "json", r)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestEdits(t *testing.T) {
e := New(os.Getenv("OPENAPI_KEY"))
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.Edit(context.Background(), &EditOptions{
Model: ModelTextDavinci001,
Input: "Write a little bit of Wikipedia. What is that?",
Expand Down
169 changes: 169 additions & 0 deletions images.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright (c) 2022 0x9ef. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
package openai

import (
"context"
"net/http"
"net/url"
"strconv"
"strings"
)

type ImageCreateOptions struct {
Prompt string `json:"prompt" binding:"required"`
// The number of images to generate.
// Must be between 1 and 10.
N int `json:"n,omitempty" binding:"omitempty,min=1,max=10"`
// The size of the generated images.
// Must be one of 256x256, 512x512, or 1024x1024.
Size string `json:"size,omitempty" binding:"omitempty,oneof=256x256 512x512 1024x1024"`
// The format in which the generated images are returned.
// Must be one of url or b64_json
ResponseFormat string `json:"response_format,omitempty" binding:"omitempty,oneof=url b64_json"`
}

type ImageCreateResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}

// ImageCreate given a prompt and/or an input image, the model will generate a new image.
//
// Docs: https://beta.openai.com/docs/api-reference/images/create
func (e *Engine) ImageCreate(ctx context.Context, opts *ImageCreateOptions) (*ImageCreateResponse, error) {
if err := e.validate.StructCtx(ctx, opts); err != nil {
return nil, err
}
url := e.apiBaseURL + "/images/generations"
r, err := marshalJson(opts)
if err != nil {
return nil, err
}
req, err := e.newReq(ctx, http.MethodPost, url, "json", r)
if err != nil {
return nil, err
}
resp, err := e.doReq(req)
if err != nil {
return nil, err
}
var jsonResp ImageCreateResponse
if err := unmarshal(resp, &jsonResp); err != nil {
return nil, err
}
return &jsonResp, nil
}

type ImageEditOptions struct {
// The image to edit. Must be a valid PNG file, less than 4MB, and square.
// If mask is not provided, image must have transparency, which will be used as the mask.
Image string `binding:"required"`
// An additional image whose fully transparent areas (e.g. where alpha is zero)
// indicate where image should be edited. Must be a valid PNG file, less than 4MB,
// and have the same dimensions as image.
Mask string
// A text description of the desired image(s). The maximum length is 1000 characters.
Prompt string `binding:"required,max=1000"`
// The number of images to generate.
// Must be between 1 and 10.
N int `binding:"omitempty,min=1,max=10"`
// The size of the generated images.
// Must be one of 256x256, 512x512, or 1024x1024.
Size string `binding:"omitempty,oneof=256x256 512x512 1024x1024"`
// The format in which the generated images are returned.
// Must be one of url or b64_json
ResponseFormat string `binding:"omitempty,oneof=url b64_json"`
}

type ImageEditResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}

// ImageEdit creates an edited or extended image given an original image and a prompt.
//
// Docs: https://beta.openai.com/docs/api-reference/images/create-edit
func (e *Engine) ImageEdit(ctx context.Context, opts *ImageEditOptions) (*ImageEditResponse, error) {
if err := e.validate.StructCtx(ctx, opts); err != nil {
return nil, err
}
uri := e.apiBaseURL + "/images/edits"
postValues := url.Values{
"image": []string{opts.Image},
"mask": []string{opts.Mask},
"prompt": []string{opts.Prompt},
"n": []string{strconv.Itoa(opts.N)},
"size": []string{opts.Size},
"response_format": []string{opts.ResponseFormat},
}
req, err := e.newReq(ctx, http.MethodPost, uri, "formData", strings.NewReader(postValues.Encode()))
if err != nil {
return nil, err
}
resp, err := e.doReq(req)
if err != nil {
return nil, err
}
var jsonResp ImageEditResponse
if err := unmarshal(resp, &jsonResp); err != nil {
return nil, err
}
return &jsonResp, nil
}

type ImageVariationOptions struct {
// The image to edit. Must be a valid PNG file, less than 4MB, and square.
// If mask is not provided, image must have transparency, which will be used as the mask.
Image string `binding:"required"`
// The number of images to generate.
// Must be between 1 and 10.
N int `binding:"omitempty,min=1,max=10"`
// The size of the generated images.
// Must be one of 256x256, 512x512, or 1024x1024.
Size string `binding:"omitempty,oneof=256x256 512x512 1024x1024"`
// The format in which the generated images are returned.
// Must be one of url or b64_json
ResponseFormat string `binding:"omitempty,oneof=url b64_json"`
}

type ImageVariationResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}

// ImageVariation creates a variation of a given image.
//
// Docs: https://beta.openai.com/docs/api-reference/images/create-variation
func (e *Engine) ImageVariation(ctx context.Context, opts *ImageVariationOptions) (*ImageCreateResponse, error) {
if err := e.validate.StructCtx(ctx, opts); err != nil {
return nil, err
}
uri := e.apiBaseURL + "/images/variations"
postValues := url.Values{
"model": []string{opts.Image},
"n": []string{strconv.Itoa(opts.N)},
"size": []string{opts.Size},
"response_format": []string{opts.ResponseFormat},
}
req, err := e.newReq(ctx, http.MethodPost, uri, "formData", strings.NewReader(postValues.Encode()))
if err != nil {
return nil, err
}
resp, err := e.doReq(req)
if err != nil {
return nil, err
}
var jsonResp ImageCreateResponse
if err := unmarshal(resp, &jsonResp); err != nil {
return nil, err
}
return &jsonResp, nil
}
58 changes: 58 additions & 0 deletions images_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2022 0x9ef. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
package openai

import (
"context"
"encoding/json"
"log"
"os"
"testing"
)

func TestImageCreate(t *testing.T) {
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.ImageCreate(context.Background(), &ImageCreateOptions{
Prompt: "Write a little bit of Wikipedia. What is that?",
})
if err != nil {
log.Fatal(err)
}
if b, err := json.MarshalIndent(r, "", " "); err != nil {
log.Fatal(err)
} else {
log.Println(string(b))
}
}

func TestImageEdit(t *testing.T) {
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.ImageEdit(context.Background(), &ImageEditOptions{
Image: "000test.png",
Prompt: "Write a little bit of Wikipedia. What is that?",
})
if err != nil {
log.Fatal(err)
}
if b, err := json.MarshalIndent(r, "", " "); err != nil {
log.Fatal(err)
} else {
log.Println(string(b))
}
}

func TestImageVariation(t *testing.T) {
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.ImageVariation(context.Background(), &ImageVariationOptions{
Image: "000test.png",
})
if err != nil {
log.Fatal(err)
}
if b, err := json.MarshalIndent(r, "", " "); err != nil {
log.Fatal(err)
} else {
log.Println(string(b))
}
}
4 changes: 2 additions & 2 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type ListModelsResponse struct {
// Docs: https://beta.openai.com/docs/api-reference/models/list
func (e *Engine) ListModels(ctx context.Context) (*ListModelsResponse, error) {
url := e.apiBaseURL + "/models"
req, err := e.newReq(ctx, http.MethodGet, url, nil)
req, err := e.newReq(ctx, http.MethodGet, url, "", nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -76,7 +76,7 @@ func (e *Engine) RetrieveModel(ctx context.Context, opts *RetrieveModelOptions)
return nil, err
}
url := e.apiBaseURL + "/models/" + string(opts.ID)
req, err := e.newReq(ctx, http.MethodGet, url, nil)
req, err := e.newReq(ctx, http.MethodGet, url, "", nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestListModels(t *testing.T) {
e := New(os.Getenv("OPENAPI_KEY"))
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.ListModels(context.Background())
if err != nil {
log.Fatal(err)
Expand All @@ -25,7 +25,7 @@ func TestListModels(t *testing.T) {
}

func TestRetrieveModel(t *testing.T) {
e := New(os.Getenv("OPENAPI_KEY"))
e := New(os.Getenv("OPENAI_KEY"))
r, err := e.RetrieveModel(context.Background(), &RetrieveModelOptions{
ID: ModelDavinci,
})
Expand Down
31 changes: 19 additions & 12 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/go-playground/validator/v10"
Expand All @@ -34,27 +35,25 @@ func New(apiKey string) *Engine {
return e
}

func (e *Engine) newReq(ctx context.Context, method string, url string, body any) (*http.Request, error) {
func (e *Engine) newReq(ctx context.Context, method string, uri string, postType string, body io.Reader) (*http.Request, error) {
if ctx == nil {
ctx = context.Background() // prevent nil context error
}
r := new(bytes.Reader)
if body != nil {
jsonb, err := json.Marshal(body)
if err != nil {
return nil, err
}
r = bytes.NewReader(jsonb)
if body == nil {
body = new(bytes.Reader) // prevent nil body error
}
req, err := http.NewRequestWithContext(ctx, method, url, r)
req, err := http.NewRequestWithContext(ctx, method, uri, body)
if err != nil {
return nil, err
}
// Setup Content-Type=application/json header only on POST operation
if body != nil && method == http.MethodPost {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", e.apiKey))
// Setup Content-Type depends on postType
switch {
case body != nil && postType == "json":
req.Header.Set("Content-type", "application/json")
case body != nil && postType == "formData":
req.Header.Set("Content-type", "application/x-www-form-urlencoded")
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", e.apiKey))
return req, err
}

Expand Down Expand Up @@ -88,3 +87,11 @@ func unmarshal(resp *http.Response, v any) error {
}
return nil
}

func marshalJson(body any) (io.Reader, error) {
b, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(b), nil
}

0 comments on commit 2b39904

Please sign in to comment.