Skip to content

Commit

Permalink
Implement Unmarshaller interface. Resolves sashabaranov#244 (sashabar…
Browse files Browse the repository at this point in the history
  • Loading branch information
young-steveo authored Apr 14, 2023
1 parent d94c5e7 commit 061c97e
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 5 deletions.
102 changes: 99 additions & 3 deletions api_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package openai_test

import (
"encoding/json"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -110,7 +112,7 @@ func TestAPIError(t *testing.T) {
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
checks.NoError(t, err, "ListEngines did not fail")
checks.HasError(t, err, "ListEngines should fail with an invalid key")

var apiErr *APIError
if !errors.As(err, &apiErr) {
Expand All @@ -120,14 +122,108 @@ func TestAPIError(t *testing.T) {
if apiErr.StatusCode != 401 {
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
}
if *apiErr.Code != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)

switch v := apiErr.Code.(type) {
case string:
if v != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", v)
}
default:
t.Fatalf("Unexpected API error code type: %T", v)
}

if apiErr.Error() == "" {
t.Fatal("Empty error message occurred")
}
}

func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.NoError(t, err, "Unexpected Unmarshal API response error")

switch v := apiErr.Code.(type) {
case int:
if v != 418 {
t.Fatalf("Unexpected API code integer: %d; expected 418", v)
}
default:
t.Fatalf("Unexpected API error code type: %T", v)
}
}

func TestAPIErrorUnmarshalJSONString(t *testing.T) {
var apiErr APIError
response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.NoError(t, err, "Unexpected Unmarshal API response error")

switch v := apiErr.Code.(type) {
case string:
if v != "teapot" {
t.Fatalf("Unexpected API code string: %s; expected `teapot`", v)
}
default:
t.Fatalf("Unexpected API error code type: %T", v)
}
}

func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) {
// test integer code
response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
var apiErr APIError
err := json.Unmarshal([]byte(response), &apiErr)
checks.NoError(t, err, "Unexpected Unmarshal API response error")

switch v := apiErr.Code.(type) {
case nil:
default:
t.Fatalf("Unexpected API error code type: %T", v)
}
}

func TestAPIErrorUnmarshalInvalidData(t *testing.T) {
apiErr := APIError{}
data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`)
err := apiErr.UnmarshalJSON(data)
checks.HasError(t, err, "Expected error when unmarshaling invalid data")

if apiErr.Code != nil {
t.Fatalf("Expected nil code, got %q", apiErr.Code)
}
if apiErr.Message != "" {
t.Fatalf("Expected empty message, got %q", apiErr.Message)
}
if apiErr.Param != nil {
t.Fatalf("Expected nil param, got %q", *apiErr.Param)
}
if apiErr.Type != "" {
t.Fatalf("Expected empty type, got %q", apiErr.Type)
}
}

func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.HasError(t, err, "Param should be a string")
}

func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.HasError(t, err, "Type should be a string")
}

func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.HasError(t, err, "Message should be a string")
}

func TestRequestError(t *testing.T) {
var err error

Expand Down
48 changes: 46 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package openai

import "fmt"
import (
"encoding/json"
"fmt"
)

// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code *string `json:"code,omitempty"`
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
Expand All @@ -25,6 +28,47 @@ func (e *APIError) Error() string {
return e.Message
}

func (e *APIError) UnmarshalJSON(data []byte) (err error) {
var rawMap map[string]json.RawMessage
err = json.Unmarshal(data, &rawMap)
if err != nil {
return
}

err = json.Unmarshal(rawMap["message"], &e.Message)
if err != nil {
return
}

err = json.Unmarshal(rawMap["type"], &e.Type)
if err != nil {
return
}

// optional fields
if _, ok := rawMap["param"]; ok {
err = json.Unmarshal(rawMap["param"], &e.Param)
if err != nil {
return
}
}

if _, ok := rawMap["code"]; !ok {
return nil
}

// if the api returned a number, we need to force an integer
// since the json package defaults to float64
var intCode int
err = json.Unmarshal(rawMap["code"], &intCode)
if err == nil {
e.Code = intCode
return nil
}

return json.Unmarshal(rawMap["code"], &e.Code)
}

func (e *RequestError) Error() string {
if e.Err != nil {
return e.Err.Error()
Expand Down

0 comments on commit 061c97e

Please sign in to comment.