From e22a29d84ebb8c5c911937669f27ac3265f3c982 Mon Sep 17 00:00:00 2001
From: Munar <118156704+MunaerYesiyan@users.noreply.github.com>
Date: Thu, 13 Jul 2023 13:30:58 +0900
Subject: [PATCH] Check if the model param is valid for moderations endpoint (#437)

* chore: check for models before sending moderation requests to openai endpoint

* chore: table driven tests to include more model cases for moderations endpoint
---
 moderation.go      | 17 ++++++++++++++++-
 moderation_test.go | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/moderation.go b/moderation.go
index a58d759c0..a32f123f3 100644
--- a/moderation.go
+++ b/moderation.go
@@ -2,6 +2,7 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"net/http"
 )
 
@@ -15,9 +16,19 @@
 const (
 	ModerationTextStable = "text-moderation-stable"
 	ModerationTextLatest = "text-moderation-latest"
-	ModerationText001    = "text-moderation-001"
+	// Deprecated: use ModerationTextStable and ModerationTextLatest instead.
+	ModerationText001 = "text-moderation-001"
 )
 
+var (
+	ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll
+)
+
+var validModerationModel = map[string]struct{}{
+	ModerationTextStable: {},
+	ModerationTextLatest: {},
+}
+
 // ModerationRequest represents a request structure for moderation API.
 type ModerationRequest struct {
 	Input string `json:"input,omitempty"`
@@ -63,6 +74,10 @@ type ModerationResponse struct {
 // Moderations — perform a moderation api call over a string.
 // Input can be an array or slice but a string will reduce the complexity.
 func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
+	if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok {
+		err = ErrModerationInvalidModel
+		return
+	}
 	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
 	if err != nil {
 		return
diff --git a/moderation_test.go b/moderation_test.go
index 4e756137e..68f9565e1 100644
--- a/moderation_test.go
+++ b/moderation_test.go
@@ -27,6 +27,41 @@ func TestModerations(t *testing.T) {
 	checks.NoError(t, err, "Moderation error")
 }
 
+// TestModerationsWithDifferentModelOptions Tests passing valid and invalid models to the moderations endpoint.
+func TestModerationsWithDifferentModelOptions(t *testing.T) {
+	var modelOptions []struct {
+		model  string
+		expect error
+	}
+	modelOptions = append(modelOptions,
+		getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel),
+		getModerationModelTestOption(ModerationTextStable, nil),
+		getModerationModelTestOption(ModerationTextLatest, nil),
+		getModerationModelTestOption("", nil),
+	)
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
+	for _, modelTest := range modelOptions {
+		_, err := client.Moderations(context.Background(), ModerationRequest{
+			Model: modelTest.model,
+			Input: "I want to kill them.",
+		})
+		checks.ErrorIs(t, err, modelTest.expect,
+			fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err))
+	}
+}
+
+func getModerationModelTestOption(model string, expect error) struct {
+	model  string
+	expect error
+} {
+	return struct {
+		model  string
+		expect error
+	}{model: model, expect: expect}
+}
+
 // handleModerationEndpoint Handles the moderation endpoint by the test server.
 func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
 	var err error
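For context, below is a minimal caller-side sketch of how the new validation behaves. ErrModerationInvalidModel, ModerationRequest, and the Moderations signature come from the patch above; the module import path, the NewClient constructor, and the placeholder API key are assumptions about the surrounding go-openai package rather than part of this change.

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai" // assumed module path
)

func main() {
	client := openai.NewClient("your-api-key") // assumed constructor; key is a placeholder

	// An unsupported model is now rejected locally, before any HTTP request is sent.
	_, err := client.Moderations(context.Background(), openai.ModerationRequest{
		Model: openai.GPT3Dot5Turbo,
		Input: "I want to kill them.",
	})
	if errors.Is(err, openai.ErrModerationInvalidModel) {
		fmt.Println("rejected client-side:", err)
	}

	// A supported model (or an empty Model field) is passed through to the API as before.
	resp, err := client.Moderations(context.Background(), openai.ModerationRequest{
		Model: openai.ModerationTextStable,
		Input: "I want to kill them.",
	})
	if err != nil {
		fmt.Println("moderation request failed:", err)
		return
	}
	fmt.Printf("moderation response: %+v\n", resp)
}

Because the check returns the sentinel error directly, callers can match it with errors.Is as shown, and no request reaches the API when the model is invalid.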