diff --git a/error.go b/error.go index f68e92875..b2d01e22e 100644 --- a/error.go +++ b/error.go @@ -7,12 +7,20 @@ import ( ) // APIError provides error information returned by the OpenAI API. +// InnerError struct is only valid for Azure OpenAI Service. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - HTTPStatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatusCode int `json:"-"` + InnerError *InnerError `json:"innererror,omitempty"` +} + +// InnerError Azure Content filtering. Only valid for Azure OpenAI Service. +type InnerError struct { + Code string `json:"code,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } // RequestError provides informations about generic request errors. @@ -61,6 +69,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } } + if _, ok := rawMap["innererror"]; ok { + err = json.Unmarshal(rawMap["innererror"], &e.InnerError) + if err != nil { + return + } + } + // optional fields if _, ok := rawMap["param"]; ok { err = json.Unmarshal(rawMap["param"], &e.Param) diff --git a/error_test.go b/error_test.go index e2309abd7..a0806b7ed 100644 --- a/error_test.go +++ b/error_test.go @@ -3,6 +3,7 @@ package openai_test import ( "errors" "net/http" + "reflect" "testing" . "github.com/sashabaranov/go-openai" @@ -57,6 +58,77 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { assertAPIErrorMessage(t, apiErr, "") }, }, + { + name: "parse succeeds when the innerError is not exists (Azure Openai)", + response: `{ + "message": "test message", + "type": null, + "param": "prompt", + "code": "content_filter", + "status": 400, + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": true, + "severity": "medium" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + }`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{ + Code: "ResponsibleAIPolicyViolation", + ContentFilterResults: ContentFilterResults{ + Hate: Hate{ + Filtered: false, + Severity: "safe", + }, + SelfHarm: SelfHarm{ + Filtered: false, + Severity: "safe", + }, + Sexual: Sexual{ + Filtered: true, + Severity: "medium", + }, + Violence: Violence{ + Filtered: false, + Severity: "safe", + }, + }, + }) + }, + }, + { + name: "parse succeeds when the innerError is empty (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, + { + name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, { name: "parse failed when the message is object", response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, @@ -152,6 +224,12 @@ func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { } } +func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { + if !reflect.DeepEqual(apiErr.InnerError, expected) { + t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) + } +} + func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: