Skip to content

Commit 264ebad

Browse files
davidmotsonDavid Motsonashvili
andauthored
Enable JSONSchema encoding (#7474)
* Add JsonSchema types * Add Encoding Switches * Add Tests TODO: Still need to actually switch over behavior depending on which model is selected --------- Co-authored-by: David Motsonashvili <davidmotson@google.com>
1 parent 1c3812f commit 264ebad

File tree

8 files changed

+186
-18
lines changed

8 files changed

+186
-18
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ import kotlinx.coroutines.flow.map
7777
import kotlinx.coroutines.launch
7878
import kotlinx.coroutines.withTimeout
7979
import kotlinx.serialization.ExperimentalSerializationApi
80+
import kotlinx.serialization.json.ClassDiscriminatorMode
8081
import kotlinx.serialization.json.Json
8182

8283
@OptIn(ExperimentalSerializationApi::class)
@@ -85,6 +86,7 @@ internal val JSON = Json {
8586
prettyPrint = false
8687
isLenient = true
8788
explicitNulls = false
89+
classDiscriminatorMode = ClassDiscriminatorMode.NONE
8890
}
8991

9092
/**

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ public class FunctionDeclaration(
6161
internal val schema: Schema =
6262
Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false)
6363

64-
internal fun toInternal() = Internal(name, description, schema.toInternal())
64+
internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi())
6565

6666
@Serializable
6767
internal data class Internal(
6868
val name: String,
6969
val description: String,
70-
val parameters: Schema.Internal
70+
val parameters: Schema.InternalOpenAPI
7171
)
7272
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ private constructor(
200200
frequencyPenalty = frequencyPenalty,
201201
presencePenalty = presencePenalty,
202202
responseMimeType = responseMimeType,
203-
responseSchema = responseSchema?.toInternal(),
203+
responseSchema = responseSchema?.toInternalOpenApi(),
204204
responseModalities = responseModalities?.map { it.toInternal() },
205205
thinkingConfig = thinkingConfig?.toInternal()
206206
)
@@ -216,7 +216,7 @@ private constructor(
216216
@SerialName("response_mime_type") val responseMimeType: String? = null,
217217
@SerialName("presence_penalty") val presencePenalty: Float? = null,
218218
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
219-
@SerialName("response_schema") val responseSchema: Schema.Internal? = null,
219+
@SerialName("response_schema") val responseSchema: Schema.InternalOpenAPI? = null,
220220
@SerialName("response_modalities") val responseModalities: List<String>? = null,
221221
@SerialName("thinking_config") val thinkingConfig: ThinkingConfig.Internal? = null
222222
)

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import java.io.ByteArrayOutputStream
2424
import kotlinx.serialization.DeserializationStrategy
2525
import kotlinx.serialization.SerialName
2626
import kotlinx.serialization.Serializable
27-
import kotlinx.serialization.SerializationException
2827
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
2928
import kotlinx.serialization.json.JsonElement
3029
import kotlinx.serialization.json.JsonNull

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -322,46 +322,147 @@ internal constructor(
322322
public fun anyOf(schemas: List<Schema>): Schema = Schema(type = "ANYOF", anyOf = schemas)
323323
}
324324

325-
internal fun toInternal(): Internal {
325+
internal fun toInternalOpenApi(): InternalOpenAPI {
326326
val cleanedType =
327327
if (type == "ANYOF") {
328328
null
329329
} else {
330330
type
331331
}
332-
return Internal(
332+
return InternalOpenAPI(
333333
cleanedType,
334334
description,
335335
format,
336336
nullable,
337337
enum,
338-
properties?.mapValues { it.value.toInternal() },
338+
properties?.mapValues { it.value.toInternalOpenApi() },
339339
required,
340-
items?.toInternal(),
340+
items?.toInternalOpenApi(),
341341
title,
342342
minItems,
343343
maxItems,
344344
minimum,
345345
maximum,
346-
anyOf?.map { it.toInternal() },
346+
anyOf?.map { it.toInternalOpenApi() },
347+
)
348+
}
349+
350+
internal fun toInternalJson(): InternalJson {
351+
val outType =
352+
if (type == "ANYOF" || (type == "STRING" && format == "enum")) {
353+
null
354+
} else {
355+
type.lowercase()
356+
}
357+
358+
val (outMinimum, outMaximum) =
359+
if (outType == "integer" && format == "int32") {
360+
(minimum ?: Integer.MIN_VALUE.toDouble()) to (maximum ?: Integer.MAX_VALUE.toDouble())
361+
} else {
362+
minimum to maximum
363+
}
364+
365+
val outFormat =
366+
if (
367+
(outType == "integer" && format == "int32") ||
368+
(outType == "number" && format == "float") ||
369+
format == "enum"
370+
) {
371+
null
372+
} else {
373+
format
374+
}
375+
376+
if (nullable == true) {
377+
return InternalJsonNullable(
378+
outType?.let { listOf(it, "null") },
379+
description,
380+
outFormat,
381+
enum?.let {
382+
buildList {
383+
addAll(it)
384+
add("null")
385+
}
386+
},
387+
properties?.mapValues { it.value.toInternalJson() },
388+
required,
389+
items?.toInternalJson(),
390+
title,
391+
minItems,
392+
maxItems,
393+
outMinimum,
394+
outMaximum,
395+
anyOf?.map { it.toInternalJson() },
396+
)
397+
}
398+
return InternalJsonNonNull(
399+
outType,
400+
description,
401+
outFormat,
402+
enum,
403+
properties?.mapValues { it.value.toInternalJson() },
404+
required,
405+
items?.toInternalJson(),
406+
title,
407+
minItems,
408+
maxItems,
409+
outMinimum,
410+
outMaximum,
411+
anyOf?.map { it.toInternalJson() },
347412
)
348413
}
349414

350415
@Serializable
351-
internal data class Internal(
416+
internal data class InternalOpenAPI(
352417
val type: String? = null,
353418
val description: String? = null,
354419
val format: String? = null,
355420
val nullable: Boolean? = false,
356421
val enum: List<String>? = null,
357-
val properties: Map<String, Internal>? = null,
422+
val properties: Map<String, InternalOpenAPI>? = null,
358423
val required: List<String>? = null,
359-
val items: Internal? = null,
424+
val items: InternalOpenAPI? = null,
360425
val title: String? = null,
361426
val minItems: Int? = null,
362427
val maxItems: Int? = null,
363428
val minimum: Double? = null,
364429
val maximum: Double? = null,
365-
val anyOf: List<Internal>? = null,
430+
val anyOf: List<InternalOpenAPI>? = null,
366431
)
432+
433+
@Serializable internal sealed interface InternalJson
434+
435+
@Serializable
436+
internal data class InternalJsonNonNull(
437+
val type: String? = null,
438+
val description: String? = null,
439+
val format: String? = null,
440+
val enum: List<String>? = null,
441+
val properties: Map<String, InternalJson>? = null,
442+
val required: List<String>? = null,
443+
val items: InternalJson? = null,
444+
val title: String? = null,
445+
val minItems: Int? = null,
446+
val maxItems: Int? = null,
447+
val minimum: Double? = null,
448+
val maximum: Double? = null,
449+
val anyOf: List<InternalJson>? = null,
450+
) : InternalJson
451+
452+
@Serializable
453+
internal data class InternalJsonNullable(
454+
val type: List<String>? = null,
455+
val description: String? = null,
456+
val format: String? = null,
457+
val enum: List<String>? = null,
458+
val properties: Map<String, InternalJson>? = null,
459+
val required: List<String>? = null,
460+
val items: InternalJson? = null,
461+
val title: String? = null,
462+
val minItems: Int? = null,
463+
val maxItems: Int? = null,
464+
val minimum: Double? = null,
465+
val maximum: Double? = null,
466+
val anyOf: List<InternalJson>? = null,
467+
) : InternalJson
367468
}

firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package com.google.firebase.ai
1919
import com.google.firebase.ai.type.Schema
2020
import com.google.firebase.ai.type.StringFormat
2121
import io.kotest.assertions.json.shouldEqualJson
22+
import java.io.File
2223
import kotlinx.serialization.encodeToString
24+
import kotlinx.serialization.json.ClassDiscriminatorMode
2325
import kotlinx.serialization.json.Json
2426
import org.junit.Test
2527

@@ -93,7 +95,7 @@ internal class SchemaTests {
9395
"""
9496
.trimIndent()
9597

96-
Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
98+
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
9799
}
98100

99101
@Test
@@ -216,6 +218,70 @@ internal class SchemaTests {
216218
"""
217219
.trimIndent()
218220

219-
Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
221+
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
220222
}
223+
224+
@Test
225+
fun `schema encoding openAPI spec test`() {
226+
val expectedSerialization = getSchemaJson("open-api-schema.json")
227+
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalOpenApi())
228+
serializedSchema.shouldEqualJson(expectedSerialization)
229+
}
230+
231+
@Test
232+
fun `schema encoding jsonSchema spec test`() {
233+
val expectedSerialization = getSchemaJson("json-schema.json")
234+
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalJson())
235+
serializedSchema.shouldEqualJson(expectedSerialization)
236+
}
237+
238+
internal fun getSchemaJson(filename: String): String {
239+
return File("src/test/resources/vertexai-sdk-test-data/mock-responses/schema/${filename}")
240+
.readText()
241+
}
242+
243+
private val JSON_ENCODER = Json { classDiscriminatorMode = ClassDiscriminatorMode.NONE }
244+
245+
private val TEST_SCHEMA =
246+
Schema.obj(
247+
properties =
248+
mapOf(
249+
"integerTest" to Schema.integer(title = "integerTest", nullable = true),
250+
"longTest" to
251+
Schema.long(
252+
title = "longTest",
253+
nullable = false,
254+
minimum = 0.0,
255+
maximum = 5.0,
256+
description = "a test long"
257+
),
258+
"floatTest" to Schema.float(title = "floatTest", nullable = false),
259+
"doubleTest" to Schema.double(title = "doubleTest", nullable = true),
260+
"listTest" to
261+
Schema.array(
262+
items = Schema.integer(nullable = false),
263+
title = "listTest",
264+
nullable = false,
265+
minItems = 0,
266+
maxItems = 5
267+
),
268+
"booleanTest" to Schema.boolean(title = "booleanTest", nullable = false),
269+
"stringTest" to
270+
Schema.string(title = "stringTest", format = StringFormat.Custom("email")),
271+
"objTest" to
272+
Schema.obj(
273+
properties =
274+
mapOf(
275+
"testInt" to Schema.integer(title = "testInt", nullable = false),
276+
),
277+
title = "objTest",
278+
description = "class kdoc should be used if property kdocs aren't present",
279+
nullable = false
280+
),
281+
"enumTest" to Schema.enumeration(values = listOf("val1", "val2", "val3"))
282+
),
283+
optionalProperties = listOf("booleanTest"),
284+
description = "A test kdoc",
285+
nullable = false
286+
)
221287
}

firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ internal class SerializationTests {
437437
}
438438
"""
439439
.trimIndent()
440-
val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor)
440+
val actualJson = descriptorToJson(Schema.InternalOpenAPI.serializer().descriptor)
441441
expectedJsonAsString shouldEqualJson actualJson.toString()
442442
}
443443

firebase-ai/update_responses.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# This script replaces mock response files for Vertex AI unit tests with a fresh
1818
# clone of the shared repository of Vertex AI test data.
1919

20-
RESPONSES_VERSION='v14.*' # The major version of mock responses to use
20+
RESPONSES_VERSION='v15.*' # The major version of mock responses to use
2121
REPO_NAME="vertexai-sdk-test-data"
2222
REPO_LINK="https://github.com/FirebaseExtended/$REPO_NAME.git"
2323

0 commit comments

Comments
 (0)