Skip to content

Commit

Permalink
split the auto function calling changes between common and generative…
Browse files Browse the repository at this point in the history
…ai (#90)

Co-authored-by: David Motsonashvili <davidmotson@google.com>
  • Loading branch information
davidmotson and David Motsonashvili authored Apr 2, 2024
1 parent 40f496e commit f22a52f
Show file tree
Hide file tree
Showing 13 changed files with 523 additions and 11 deletions.
1 change: 1 addition & 0 deletions .changes/cloud-camp-bait-calculator.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Add function calling"]}
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ dependencies {
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("org.reactivestreams:reactive-streams:1.0.3")

implementation("com.google.guava:listenablefuture:1.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

/**
 * Guards chat input: throws [InvalidStateException] when this [Content]'s role is not one that
 * is accepted as a chat prompt.
 *
 * NOTE(review): this span is a rendered diff hunk with its +/- markers stripped, so two versions
 * of the check appear back to back. The first `if`/`throw` pair (only "user" accepted) appears to
 * be the removed version; the second pair (accepting "user" or "function") appears to be its
 * replacement — confirm against the actual file.
 */
private fun Content.assertComesFromUser() {
// Earlier form of the guard: only the "user" role passes.
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
// Later form of the guard: the "function" role passes as well as "user".
if (role !in listOf("user", "function")) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,29 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi
import org.json.JSONObject

/**
* A facilitator for a given multimodal model (eg; Gemini).
Expand All @@ -48,14 +59,16 @@ import kotlinx.coroutines.flow.map
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
@OptIn(ExperimentalSerializationApi::class)
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val tools: List<Tool>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
private val controller: APIController,
) {

@JvmOverloads
Expand All @@ -64,14 +77,16 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
tools,
requestOptions,
APIController(apiKey, modelName, requestOptions.toInternal())
APIController(apiKey, modelName, requestOptions.toInternal()),
)

/**
Expand Down Expand Up @@ -171,12 +186,45 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
 * Runs the locally registered function that the model asked for.
 *
 * The function is looked up by name among the declarations of every registered [Tool]; the
 * first declaration whose name equals [FunctionCallPart.name] is executed.
 *
 * @param functionCallPart A [FunctionCallPart] from the model, carrying the function name and
 *   the arguments to call it with
 * @return The output of the requested function call
 * @throws InvalidStateException if no tools are registered, or if none of the registered
 *   declarations has the requested name
 */
@OptIn(GenerativeBeta::class)
suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
  val registeredTools = tools ?: throw InvalidStateException("No registered tools")
  val requestedName = functionCallPart.name

  val declaration =
    registeredTools
      .flatMap { tool -> tool.functionDeclarations }
      .firstOrNull { candidate -> candidate.name == requestedName }
      ?: throw InvalidStateException("No registered function named $requestedName")

  // The type arguments are erased at runtime, so each arity is cast to its `Any?` form and the
  // declaration itself reads its arguments out of the call part.
  return when (declaration) {
    is NoParameterFunction -> declaration.execute()
    is OneParameterFunction<*> ->
      (declaration as OneParameterFunction<Any?>).execute(functionCallPart)
    is TwoParameterFunction<*, *> ->
      (declaration as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
    is ThreeParameterFunction<*, *, *> ->
      (declaration as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
    is FourParameterFunction<*, *, *, *> ->
      (declaration as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
    else -> throw RuntimeException("UNREACHABLE")
  }
}

/**
 * Builds the [GenerateContentRequest] for the given prompt contents.
 *
 * The request carries the model name, each prompt converted with `toInternal()`, and the
 * model's safety settings, generation config and tools converted with `toInternal()` when they
 * are non-null.
 *
 * NOTE(review): this span is a rendered diff hunk with its +/- markers stripped. The two
 * consecutive `generationConfig?.toInternal()` lines appear to be the removed line (no trailing
 * comma) followed by its replacement — confirm against the actual file.
 */
@OptIn(GenerativeBeta::class)
private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
prompt.map { it.toInternal() },
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal()
generationConfig?.toInternal(),
tools?.map { it.toInternal() },
)

private fun constructCountTokensRequest(vararg prompt: Content) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.google.ai.client.generativeai.common.CountTokensResponse
import com.google.ai.client.generativeai.common.GenerateContentResponse
import com.google.ai.client.generativeai.common.RequestOptions
import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Schema
import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.Candidate
import com.google.ai.client.generativeai.common.server.CitationSources
Expand All @@ -33,17 +34,27 @@ import com.google.ai.client.generativeai.common.server.SafetyRating
import com.google.ai.client.generativeai.common.shared.Blob
import com.google.ai.client.generativeai.common.shared.BlobPart
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.FunctionCall
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.FunctionResponse
import com.google.ai.client.generativeai.common.shared.FunctionResponsePart
import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Part
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.FunctionDeclaration
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.content
import java.io.ByteArrayOutputStream
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import org.json.JSONObject

private const val BASE_64_FLAGS = Base64.NO_WRAP

Expand All @@ -59,6 +70,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
is ImagePart -> BlobPart(Blob("image/jpeg", encodeBitmapToBase64Png(image)))
is com.google.ai.client.generativeai.type.BlobPart ->
BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS)))
is com.google.ai.client.generativeai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
is com.google.ai.client.generativeai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
else ->
throw SerializationException(
"The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet."
Expand All @@ -76,7 +91,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal(
topK = topK,
candidateCount = candidateCount,
maxOutputTokens = maxOutputTokens,
stopSequences = stopSequences
stopSequences = stopSequences,
)

internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() =
Expand All @@ -99,6 +114,35 @@ internal fun BlockThreshold.toInternal() =
BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED
}

/**
 * Converts this public [Tool] into its internal request representation by converting each of
 * its function declarations.
 */
@GenerativeBeta
internal fun Tool.toInternal() =
  com.google.ai.client.generativeai.common.client.Tool(
    functionDeclarations.map { declaration -> declaration.toInternal() }
  )

/**
 * Converts this public [FunctionDeclaration] into its internal request representation.
 *
 * The declaration's parameters are described as a single `OBJECT` schema: every parameter
 * becomes a property keyed by its name, and every parameter name is listed as required.
 */
@GenerativeBeta
internal fun FunctionDeclaration.toInternal() =
  // Read the parameter list once so that `properties` and `required` are always derived from
  // the same list instead of from two separate getParameters() calls.
  getParameters().let { parameters ->
    com.google.ai.client.generativeai.common.client.FunctionDeclaration(
      name,
      description,
      Schema(
        properties = parameters.associate { it.name to it.toInternal() },
        required = parameters.map { it.name },
        type = "OBJECT",
      ),
    )
  }

/**
 * Maps this public schema onto the internal [Schema].
 *
 * The type is passed by its enum name; nested property schemas and the item schema, when
 * present, are converted through the same function.
 */
internal fun <T> com.google.ai.client.generativeai.type.Schema<T>.toInternal(): Schema {
  val nestedProperties =
    properties?.mapValues { (_, propertySchema) -> propertySchema.toInternal() }
  val itemSchema = items?.toInternal()
  return Schema(type.name, description, format, enum, nestedProperties, required, itemSchema)
}

/**
 * Converts this org.json [JSONObject] into a kotlinx.serialization [JsonObject] by serializing
 * it to a string and decoding that string.
 */
internal fun JSONObject.toInternal(): JsonObject {
  val serialized = toString()
  return Json.decodeFromString<JsonObject>(serialized)
}

internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate {
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty()
Expand All @@ -108,7 +152,7 @@ internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candid
this.content?.toPublic() ?: content("model") {},
safetyRatings,
citations,
finishReason
finishReason,
)
}

Expand All @@ -126,6 +170,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
com.google.ai.client.generativeai.type.BlobPart(inlineData.mimeType, data)
}
}
is FunctionCallPart ->
com.google.ai.client.generativeai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
)
is FunctionResponsePart ->
com.google.ai.client.generativeai.type.FunctionResponsePart(
functionResponse.name,
functionResponse.response.toPublic(),
)
else ->
throw SerializationException(
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."
Expand Down Expand Up @@ -192,12 +246,14 @@ internal fun BlockReason.toPublic() =
/**
 * Converts the internal [GenerateContentResponse] into the public response type.
 *
 * Each candidate is converted with `toPublic()` (a null candidate list becomes an empty list),
 * and the prompt feedback is converted when it is non-null.
 *
 * NOTE(review): this span is a rendered diff hunk with its +/- markers stripped. The two
 * consecutive `promptFeedback?.toPublic()` lines appear to be the removed line (no trailing
 * comma) followed by its replacement — confirm against the actual file.
 */
internal fun GenerateContentResponse.toPublic() =
com.google.ai.client.generativeai.type.GenerateContentResponse(
candidates?.map { it.toPublic() }.orEmpty(),
promptFeedback?.toPublic()
promptFeedback?.toPublic(),
)

/** Wraps the internal response's total token count in the public `CountTokensResponse` type. */
internal fun CountTokensResponse.toPublic():
  com.google.ai.client.generativeai.type.CountTokensResponse {
  val tokenCount = totalTokens
  return com.google.ai.client.generativeai.type.CountTokensResponse(tokenCount)
}

/**
 * Converts this kotlinx.serialization [JsonObject] into an org.json [JSONObject] by re-parsing
 * its string form.
 */
internal fun JsonObject.toPublic(): JSONObject {
  val serialized = toString()
  return JSONObject(serialized)
}

private fun encodeBitmapToBase64Png(input: Bitmap): String {
ByteArrayOutputStream().let {
input.compress(Bitmap.CompressFormat.JPEG, 80, it)
Expand Down
Loading

0 comments on commit f22a52f

Please sign in to comment.