Skip to content

Commit

Permalink
Check if the model param is valid for moderations endpoint (sashabara…
Browse files Browse the repository at this point in the history
…nov#437)

* chore: check for models before sending moderation requets to openai endpoint

* chore: table driven tests to include more model cases for moderations endpoint
  • Loading branch information
MunarYesen committed Jul 13, 2023
1 parent 39b2acb commit e22a29d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
17 changes: 16 additions & 1 deletion moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"errors"
"net/http"
)

Expand All @@ -15,9 +16,19 @@ import (
const (
ModerationTextStable = "text-moderation-stable"
ModerationTextLatest = "text-moderation-latest"
ModerationText001 = "text-moderation-001"
// Deprecated: use ModerationTextStable and ModerationTextLatest instead.
ModerationText001 = "text-moderation-001"
)

var (
ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll
)

var validModerationModel = map[string]struct{}{
ModerationTextStable: {},
ModerationTextLatest: {},
}

// ModerationRequest represents a request structure for moderation API.
type ModerationRequest struct {
Input string `json:"input,omitempty"`
Expand Down Expand Up @@ -63,6 +74,10 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok {
err = ErrModerationInvalidModel
return
}
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
if err != nil {
return
Expand Down
35 changes: 35 additions & 0 deletions moderation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,41 @@ func TestModerations(t *testing.T) {
checks.NoError(t, err, "Moderation error")
}

// TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint.
func TestModerationsWithDifferentModelOptions(t *testing.T) {
var modelOptions []struct {
model string
expect error
}
modelOptions = append(modelOptions,
getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel),
getModerationModelTestOption(ModerationTextStable, nil),
getModerationModelTestOption(ModerationTextLatest, nil),
getModerationModelTestOption("", nil),
)
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
for _, modelTest := range modelOptions {
_, err := client.Moderations(context.Background(), ModerationRequest{
Model: modelTest.model,
Input: "I want to kill them.",
})
checks.ErrorIs(t, err, modelTest.expect,
fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err))
}
}

func getModerationModelTestOption(model string, expect error) struct {
model string
expect error
} {
return struct {
model string
expect error
}{model: model, expect: expect}
}

// handleModerationEndpoint Handles the moderation endpoint by the test server.
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down

0 comments on commit e22a29d

Please sign in to comment.