Skip to content

Commit

Permalink
feat: allow more input types to functions, fix tests (sashabaranov#377)
Browse files Browse the repository at this point in the history
* feat: use json.rawMessage, test functions

* chore: lint

* fix: tests

the ChatCompletion mock server doesn't actually run otherwise. N=0
is the default request but the server will treat it as n=1

* fix: tests should default to n=1 completions

* chore: add back removed interfaces, custom marshal

* chore: lint

* chore: lint

* chore: add some tests

* chore: appease lint

* clean up JSON schema + tests

* chore: lint

* feat: remove backwards compatible functions

for illustrative purposes

* fix: revert params change

* chore: use interface{}

* chore: add test

* chore: add back FunctionDefine

* chore: /s/interface{}/any

* chore: add back jsonschemadefinition

* chore: testcov

* chore: lint

* chore: remove pointers

* chore: update comment

* chore: address CR

added test for compatibility as well

---------

Co-authored-by: James <jmacwhyte@MacBooger-II.local>
  • Loading branch information
stillmatic and James authored Jun 21, 2023
1 parent e948150 commit f22da8a
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 21 deletions.
34 changes: 19 additions & 15 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ type ChatCompletionRequest struct {
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Functions []*FunctionDefine `json:"functions,omitempty"`
FunctionCall string `json:"function_call,omitempty"`
Functions []FunctionDefinition `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
}

type FunctionDefine struct {
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
// it's required in function call
Parameters *FunctionParams `json:"parameters"`
// Parameters is an object describing the function.
// You can pass a raw byte array describing the schema,
// or you can pass in a struct which serializes to the proper JSONSchema.
// The JSONSchemaDefinition struct is provided for convenience, but you should
// consider another specialized library for more complex schemas.
Parameters any `json:"parameters"`
}

type FunctionParams struct {
// the Type must be JSONSchemaTypeObject
Type JSONSchemaType `json:"type"`
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
// Deprecated: use FunctionDefinition instead.
type FunctionDefine = FunctionDefinition

type JSONSchemaType string

Expand All @@ -83,22 +83,26 @@ const (
JSONSchemaTypeBoolean JSONSchemaType = "boolean"
)

// JSONSchemaDefine is a struct for JSON Schema.
type JSONSchemaDefine struct {
// JSONSchemaDefinition is a struct for JSON Schema.
// It is fairly limited and you may have better luck using a third-party library.
type JSONSchemaDefinition struct {
// Type is a type of JSON Schema.
Type JSONSchemaType `json:"type,omitempty"`
// Description is a description of JSON Schema.
Description string `json:"description,omitempty"`
// Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString.
Enum []string `json:"enum,omitempty"`
// Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject.
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"`
// Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject.
Required []string `json:"required,omitempty"`
// Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray.
Items *JSONSchemaDefine `json:"items,omitempty"`
Items *JSONSchemaDefinition `json:"items,omitempty"`
}

// Deprecated: use JSONSchemaDefinition instead.
type JSONSchemaDefine = JSONSchemaDefinition

type FinishReason string

const (
Expand Down
157 changes: 154 additions & 3 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
t.Run("bytes", func(t *testing.T) {
//nolint:lll
msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefine{{
Name: "test",
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("struct", func(t *testing.T) {
type testMessage struct {
Count int `json:"count"`
Words []string `json:"words"`
}
msg := testMessage{
Count: 2,
Words: []string{"hello", "world"},
}
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefinition{{
Name: "test",
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("JSONSchemaDefine", func(t *testing.T) {
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefinition{{
Name: "test",
Parameters: &JSONSchemaDefinition{
Type: JSONSchemaTypeObject,
Properties: map[string]JSONSchemaDefinition{
"count": {
Type: JSONSchemaTypeNumber,
Description: "total number of words in sentence",
},
"words": {
Type: JSONSchemaTypeArray,
Description: "list of words in sentence",
Items: &JSONSchemaDefinition{
Type: JSONSchemaTypeString,
},
},
"enumTest": {
Type: JSONSchemaTypeString,
Enum: []string{"hello", "world"},
},
},
},
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) {
// this is a compatibility check
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefine{{
Name: "test",
Parameters: &JSONSchemaDefine{
Type: JSONSchemaTypeObject,
Properties: map[string]JSONSchemaDefine{
"count": {
Type: JSONSchemaTypeNumber,
Description: "total number of words in sentence",
},
"words": {
Type: JSONSchemaTypeArray,
Description: "list of words in sentence",
Items: &JSONSchemaDefine{
Type: JSONSchemaTypeString,
},
},
"enumTest": {
Type: JSONSchemaTypeString,
Enum: []string{"hello", "world"},
},
},
},
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
}

func TestAzureChatCompletions(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()
Expand Down Expand Up @@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// if there are functions, include them
if len(completionReq.Functions) > 0 {
var fcb []byte
b := completionReq.Functions[0].Parameters
fcb, err = json.Marshal(b)
if err != nil {
http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
return
}

res.Choices = append(res.Choices, ChatCompletionChoice{
Message: ChatCompletionMessage{
Role: ChatMessageRoleFunction,
// this is valid json so it should be fine
FunctionCall: &FunctionCall{
Name: completionReq.Functions[0].Name,
Arguments: string(fcb),
},
},
Index: i,
})
continue
}
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)

Expand All @@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
Expand Down
10 changes: 7 additions & 3 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
Expand All @@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
inputTokens := numTokens(completionReq.Prompt.(string)) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
Expand Down

0 comments on commit f22da8a

Please sign in to comment.