Skip to content

Commit 7cdb063

Browse files
authored
[aisdk] Updates to openai jsonschema encoding (#586)
## Summary OpenAI uses [jsonschema](https://json-schema.org/understanding-json-schema/reference) to define a function tool's input (and other things), but they don't support any arbitrary json schema. Here we implement some (but not all) of the [supported subset](https://platform.openai.com/docs/guides/structured-outputs#supported-schemas), and force that schemas always contain `properties`, even if empty. These were being omitted when empty because the `Schema` struct has `json:"omitempty"`) Implementing all the OpenAI restrictions seems like unnecessary eager validation. The API itself will return an error if an invalid schema is passed in. ## How was it tested? Unit tests ## Community Contribution License All community contributions in this pull request are licensed to the project maintainers under the terms of the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0). By creating this pull request I represent that I have the right to license the contributions to the project maintainers under the Apache 2 License as stated in the [Community Contribution License](https://github.com/jetify-com/opensource/blob/main/CONTRIBUTING.md#community-contribution-license). Co-authored-by: Rodrigo Ipince <ipince@users.noreply.github.com>
1 parent 0bc4434 commit 7cdb063

File tree

2 files changed

+86
-41
lines changed

2 files changed

+86
-41
lines changed

aisdk/ai/provider/openai/internal/codec/jsonschema.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) {
1616
return nil, nil
1717
}
1818

19+
// Enforce OpenAI restrictions
20+
// https://platform.openai.com/docs/guides/structured-outputs#root-objects-must-not-be-anyof-and-must-be-an-object
21+
// NOTE: we could simply encode the input schema, pass it through to OpenAI and let it return an error, but there are
22+
// other encoding rules we want to enforce later, and limiting the scope here allows us to limit the scope later.
23+
if schema.Type != "object" {
24+
return nil, fmt.Errorf("schema root must be of type object, got: %s", schema.Type)
25+
}
26+
if schema.AnyOf != nil {
27+
return nil, fmt.Errorf("schema root cannot use AnyOf")
28+
}
29+
1930
// Marshal to JSON and unmarshal back to interface{} to convert the types
2031
data, err := json.Marshal(schema)
2132
if err != nil {
@@ -32,6 +43,12 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) {
3243
return nil, fmt.Errorf("failed to unmarshal properties: %w\n\n%s", err, data)
3344
}
3445

46+
// Ensure properties field is set, even if it's empty. It's unclear whether OpenAI requires
47+
// this to be set for nested schema objects too. For now we only set it at the top-level.
48+
if _, ok := result["properties"]; !ok {
49+
result["properties"] = map[string]any{}
50+
}
51+
3552
// Convert {"not": {}} patterns to false throughout the schema
3653
normalizeSchemaMap(result)
3754

aisdk/ai/provider/openai/internal/codec/jsonschema_test.go

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -143,20 +143,42 @@ func TestEncodeSchema(t *testing.T) {
143143
}`,
144144
},
145145
{
146-
name: "schema with allOf containing additionalProperties",
146+
name: "schema with nested AnyOf",
147147
input: &jsonschema.Schema{
148-
AllOf: []*jsonschema.Schema{
149-
{
150-
Type: "object",
151-
AdditionalProperties: api.FalseSchema(),
148+
Type: "object",
149+
Properties: map[string]*jsonschema.Schema{
150+
"numeric": {
151+
AnyOf: []*jsonschema.Schema{
152+
{
153+
Type: "string",
154+
},
155+
{
156+
Type: "number",
157+
},
158+
},
152159
},
153160
},
154161
},
155162
want: `{
156-
"allOf": [{
157-
"type": "object",
158-
"additionalProperties": false
159-
}]
163+
"type": "object",
164+
"properties": {
165+
"numeric": {
166+
"anyOf": [
167+
{ "type": "string" },
168+
{ "type": "number" }
169+
]
170+
}
171+
}
172+
}`,
173+
},
174+
{
175+
name: "schema without properties gets empty properties map",
176+
input: &jsonschema.Schema{
177+
Type: "object",
178+
},
179+
want: `{
180+
"type": "object",
181+
"properties": {}
160182
}`,
161183
},
162184
{
@@ -210,6 +232,44 @@ func TestEncodeSchema(t *testing.T) {
210232
"required": ["id"]
211233
}`,
212234
},
235+
236+
// Edge/error cases
237+
{
238+
name: "schema with non-object root",
239+
input: &jsonschema.Schema{
240+
Properties: map[string]*jsonschema.Schema{
241+
"name": {
242+
Type: "string",
243+
Description: "The name",
244+
},
245+
},
246+
},
247+
wantErr: true,
248+
},
249+
{
250+
name: "empty schema",
251+
input: &jsonschema.Schema{},
252+
wantErr: true,
253+
},
254+
{
255+
name: "schema with only additional properties",
256+
input: &jsonschema.Schema{
257+
AdditionalProperties: api.FalseSchema(),
258+
},
259+
wantErr: true,
260+
},
261+
{
262+
name: "schema with AnyOf at rool level",
263+
input: &jsonschema.Schema{
264+
AnyOf: []*jsonschema.Schema{
265+
{
266+
Type: "object",
267+
AdditionalProperties: api.FalseSchema(),
268+
},
269+
},
270+
},
271+
wantErr: true,
272+
},
213273
}
214274

215275
for _, tt := range tests {
@@ -451,35 +511,3 @@ func TestNormalizeSchemaMap(t *testing.T) {
451511
})
452512
}
453513
}
454-
455-
func TestEncodeSchema_EdgeCases(t *testing.T) {
456-
t.Run("schema with only additionalProperties", func(t *testing.T) {
457-
schema := &jsonschema.Schema{
458-
AdditionalProperties: api.FalseSchema(),
459-
}
460-
461-
got, err := encodeSchema(schema)
462-
require.NoError(t, err)
463-
464-
gotJSON, err := json.Marshal(got)
465-
require.NoError(t, err)
466-
467-
expectedJSON := `{
468-
"additionalProperties": false
469-
}`
470-
471-
assert.JSONEq(t, expectedJSON, string(gotJSON))
472-
})
473-
474-
t.Run("empty schema", func(t *testing.T) {
475-
schema := &jsonschema.Schema{}
476-
477-
got, err := encodeSchema(schema)
478-
require.NoError(t, err)
479-
480-
gotJSON, err := json.Marshal(got)
481-
require.NoError(t, err)
482-
483-
assert.JSONEq(t, "{}", string(gotJSON))
484-
})
485-
}

0 commit comments

Comments
 (0)