From e49d771fff3bc699bca7cf22c9d93b67316047e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sat, 17 Jun 2023 22:57:29 +0900 Subject: [PATCH] support for parsing error response message fields even if they are arrays (#381) (#384) --- api_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++---- error.go | 10 ++++- 2 files changed, 111 insertions(+), 8 deletions(-) diff --git a/api_test.go b/api_test.go index 083b67412..34173708f 100644 --- a/api_test.go +++ b/api_test.go @@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) { } } +func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFn func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo, bar, baz" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := json.Unmarshal([]byte(tc.response), &apiErr) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + return + } + if tc.checkFn != nil { + tc.checkFn(t, apiErr) + } + }) + } +} + func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { var apiErr APIError response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` @@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { checks.HasError(t, err, "Type should be a string") } -func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Message should be a string") -} - func TestRequestError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/error.go b/error.go index b789ed7d5..f68e92875 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,7 @@ package openai import ( "encoding/json" "fmt" + "strings" ) // APIError provides error information returned by the OpenAI API. @@ -41,7 +42,14 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { err = json.Unmarshal(rawMap["message"], &e.Message) if err != nil { - return + // If the parameter field of a function call is invalid as a JSON schema + // refs: https://github.com/sashabaranov/go-openai/issues/381 + var messages []string + err = json.Unmarshal(rawMap["message"], &messages) + if err != nil { + return + } + e.Message = strings.Join(messages, ", ") } // optional fields for azure openai