Skip to content

Commit

Permalink
make args nullable in parsing and add test, also specify that we expe… (
Browse files Browse the repository at this point in the history
google-gemini#165)

…ct it to not be null

---------

Co-authored-by: David Motsonashvili <davidmotson@google.com>
  • Loading branch information
davidmotson and David Motsonashvili authored May 31, 2024
1 parent 5f02a32 commit b1803c4
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String>)
@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.server.HarmProbability
import com.google.ai.client.generativeai.common.server.HarmSeverity
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.goldenUnaryFile
Expand Down Expand Up @@ -301,4 +302,15 @@ internal class UnarySnapshotTests {
}
}
}

@Test
fun `function call contains null param`() =
goldenUnaryFile("success-function-call-null.json") {
withTimeout(testTimeout) {
val response = apiController.generateContent(textGenerateContentRequest("prompt"))
val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart)

callPart.functionCall.args["season"] shouldBe null
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "functionName",
"args": {
"original_title": "String",
"season": null
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 774,
"candidatesTokenCount": 4176,
"totalTokenCount": 4950
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ internal fun FunctionDeclaration.toInternal() =
properties = getParameters().associate { it.name to it.toInternal() },
required = getParameters().map { it.name },
type = "OBJECT",
nullable = false,
),
)

Expand All @@ -158,6 +159,7 @@ internal fun <T> com.google.ai.client.generativeai.type.Schema<T>.toInternal():
type.name,
description,
format,
nullable,
enum,
properties?.mapValues { it.value.toInternal() },
required,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class Schema<T>(
val name: String,
val description: String,
val format: String? = null,
val nullable: Boolean? = null,
val enum: List<String>? = null,
val properties: Map<String, Schema<out Any>>? = null,
val required: List<String>? = null,
Expand All @@ -184,19 +185,39 @@ class Schema<T>(
companion object {
/** Registers a schema for an integer number */
fun int(name: String, description: String) =
Schema<Long>(name = name, description = description, type = FunctionType.INTEGER)
Schema<Long>(
name = name,
description = description,
type = FunctionType.INTEGER,
nullable = false,
)

/** Registers a schema for a string */
fun str(name: String, description: String) =
Schema<String>(name = name, description = description, type = FunctionType.STRING)
Schema<String>(
name = name,
description = description,
type = FunctionType.STRING,
nullable = false,
)

/** Registers a schema for a boolean */
fun bool(name: String, description: String) =
Schema<Boolean>(name = name, description = description, type = FunctionType.BOOLEAN)
Schema<Boolean>(
name = name,
description = description,
type = FunctionType.BOOLEAN,
nullable = false,
)

/** Registers a schema for a floating point number */
fun num(name: String, description: String) =
Schema<Double>(name = name, description = description, type = FunctionType.NUMBER)
Schema<Double>(
name = name,
description = description,
type = FunctionType.NUMBER,
nullable = false,
)

/**
* Registers a schema for a complex object. In a function it will be returned as a [JSONObject]
Expand All @@ -208,11 +229,17 @@ class Schema<T>(
type = FunctionType.OBJECT,
required = contents.map { it.name },
properties = contents.associateBy { it.name }.toMap(),
nullable = false,
)

/** Registers a schema for an array */
fun arr(name: String, description: String) =
Schema<List<String>>(name = name, description = description, type = FunctionType.ARRAY)
Schema<List<String>>(
name = name,
description = description,
type = FunctionType.ARRAY,
nullable = false,
)

/** Registers a schema for an enum */
fun enum(name: String, description: String, values: List<String>) =
Expand All @@ -222,6 +249,7 @@ class Schema<T>(
format = "enum",
enum = values,
type = FunctionType.STRING,
nullable = false,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class FileDataPart(val uri: String, val mimeType: String) : Part
fun Part.asFileDataPartOrNull(): FileDataPart? = this as? FileDataPart

/** Represents function call name and params received from requests. */
class FunctionCallPart(val name: String, val args: Map<String, String>) : Part
class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part

/** Represents function call output to be returned to the model when it requests a function call */
class FunctionResponsePart(val name: String, val response: JSONObject) : Part
Expand Down

0 comments on commit b1803c4

Please sign in to comment.