Skip to content

Commit

Permalink
Increasing code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
coggsflod committed Jun 12, 2023
1 parent 1b4c82c commit 04769a9
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 42 deletions.
59 changes: 33 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
utils "github.com/sashabaranov/go-openai/internal"
)

var (
ErrClientEmptyCallbackURL = errors.New("Error retrieving callback URL (Operation-Location) for image request") //nolint:lll
ErrClientRetievingCallbackResponse = errors.New("Error retrieving callback response")
)

// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig
Expand All @@ -23,6 +28,21 @@ type Client struct {
createFormBuilder func(io.Writer) utils.FormBuilder
}

// Azure image request callback response struct.
type CBData []struct {
URL string `json:"url"`
}
type CBResult struct {
Data CBData `json:"data"`
}
type CallBackResponse struct {
Created int64 `json:"created"`
Expires int64 `json:"expires"`
ID string `json:"id"`
Result CBResult `json:"result"`
Status string `json:"status"`
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
Expand Down Expand Up @@ -71,15 +91,16 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
if isFailureStatusCode(res) {
return c.handleErrorResp(res)
}
// Special handling for initial call to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/images/generations") &&
(c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD) {
return c.requestImage(res, v)
}
// Special handling for callBack to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/operations/images") &&
(c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD) {
return c.imageRequestCallback(req, v, res)

if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
// Special handling for initial call to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/images/generations") {
return c.requestImage(res, v)
}
// Special handling for callBack to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/operations/images") {
return c.imageRequestCallback(req, v, res)
}
}

return decodeResponse(res.Body, v)
Expand Down Expand Up @@ -110,7 +131,7 @@ func (c *Client) requestImage(res *http.Response, v any) error {
}
callBackURL := res.Header.Get("Operation-Location")
if callBackURL == "" {
return errors.New("Error retrieving call back URL (Operation-Location) for image request")
return ErrClientEmptyCallbackURL
}
newReq, err := http.NewRequest("GET", callBackURL, nil)
if err != nil {
Expand All @@ -124,28 +145,14 @@ func (c *Client) imageRequestCallback(req *http.Request, v any, res *http.Respon
// Retry Sleep seconds for Azure DALL-E 2 callback URL.
var callBackWaitTime = 3

type Data []struct {
URL string `json:"url"`
}
type Result struct {
Data Data `json:"data"`
}
type callBackResponse struct {
Created int64 `json:"created"`
Expires int64 `json:"expires"`
ID string `json:"id"`
Result Result `json:"result"`
Status string `json:"status"`
}

// Wait for the callBack to complete
var result *callBackResponse
var result *CallBackResponse
err := json.NewDecoder(res.Body).Decode(&result)
if err != nil {
return err
}
if result.Status == "" {
return errors.New("Error retrieving callBack response")
return ErrClientRetievingCallbackResponse
}
if result.Status == "notRunning" || result.Status == "running" {
time.Sleep(time.Duration(callBackWaitTime) * time.Second)
Expand Down
57 changes: 57 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
"time"

"github.com/sashabaranov/go-openai/internal/test"
)
Expand Down Expand Up @@ -284,3 +287,57 @@ func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestImageRequestCallbackErrors(t *testing.T) {

var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

// Test requestImage callback URL empty.
TestCase := "Callback URL is empty"
res := &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBufferString("")),
}
v := &ImageRequest{}
err = client.requestImage(res, v)

if !errors.Is(err, ErrClientEmptyCallbackURL) {
t.Fatalf("%s did not return error. requestImage failed: %v", TestCase, err)
}

// Test imageRequestCallback status response empty.
TestCase = "imageRequestCallback status response empty"
var request ImageRequest
ctx := context.Background()
cbResponse := CallBackResponse{
Created: time.Now().Unix(),
Status: "",
Result: CBResult{
Data: CBData{
{URL: "http://example.com/image1"},
{URL: "http://example.com/image2"},
},
},
}
cbResponseBytes := new(bytes.Buffer)
json.NewEncoder(cbResponseBytes).Encode(cbResponse)
req, _ := client.requestBuilder.Build(ctx, http.MethodPost, client.fullURL(""), request)
res = &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBufferString(cbResponseBytes.String())),
}
v = &ImageRequest{}
err = client.imageRequestCallback(req, v, res)
fmt.Println(err)
if !errors.Is(err, ErrClientRetievingCallbackResponse) {
t.Fatalf("%s did not return error. imageRequestCallback failed: %v", TestCase, err)
}
}
19 changes: 3 additions & 16 deletions image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,6 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
// handleImageCallbackEndpoint Handles the callback endpoint by the test server.
func handleImageCallbackEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
type Data []struct {
URL string `json:"url"`
}
type Result struct {
Data Data `json:"data"`
}
type callBackResponse struct {
Created int64 `json:"created"`
Expires int64 `json:"expires"`
ID string `json:"id"`
Result Result `json:"result"`
Status string `json:"status"`
}

// image callback only accepts GET requests
if r.Method != "GET" {
Expand All @@ -139,11 +126,11 @@ func handleImageCallbackEndpoint(w http.ResponseWriter, r *http.Request) {
status = "notRunning"
}

cbResponse := callBackResponse{
cbResponse := CallBackResponse{
Created: time.Now().Unix(),
Status: status,
Result: Result{
Data: Data{
Result: CBResult{
Data: CBData{
{URL: "http://example.com/image1"},
{URL: "http://example.com/image2"},
},
Expand Down

0 comments on commit 04769a9

Please sign in to comment.