Skip to content

Commit

Permalink
Correctly handle UsageMetadata (#135)
Browse files Browse the repository at this point in the history
Also, exposes the data when available.

Addresses issue #134
  • Loading branch information
rlazo authored May 3, 2024
1 parent a081fb3 commit 146c4d6
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters

@Serializable
data class UsageMetadata(
val promptTokenCount: Int,
val candidatesTokenCount: Int?,
val totalTokenCount: Int
val promptTokenCount: Int? = null,
val candidatesTokenCount: Int? = null,
val totalTokenCount: Int? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ internal class UnarySnapshotTests {
}
}

@Test
fun `response includes partial usage metadata`() =
goldenUnaryFile("success-partial-usage-metadata.json") {
withTimeout(testTimeout) {
val response = apiController.generateContent(textGenerateContentRequest("prompt"))

response.candidates?.isEmpty() shouldBe false
response.candidates?.first()?.finishReason shouldBe FinishReason.STOP
response.usageMetadata shouldNotBe null
response.usageMetadata?.promptTokenCount shouldBe 6
response.usageMetadata?.totalTokenCount shouldBe null
}
}

@Test
fun `citation returns correctly when using alternative name`() =
goldenUnaryFile("success-citations-altname.json") {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"candidates": [
{
"content": {
"parts": [
{
"text": "Mountain View, California, United States"
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 6
},
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.ToolConfig
import com.google.ai.client.generativeai.type.UsageMetadata
import com.google.ai.client.generativeai.type.content
import java.io.ByteArrayOutputStream
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -136,6 +137,9 @@ internal fun ToolConfig.toInternal() =
)
)

internal fun com.google.ai.client.generativeai.common.UsageMetadata.toPublic(): UsageMetadata =
UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0)

internal fun FunctionDeclaration.toInternal() =
com.google.ai.client.generativeai.common.client.FunctionDeclaration(
name,
Expand Down Expand Up @@ -269,6 +273,7 @@ internal fun GenerateContentResponse.toPublic() =
com.google.ai.client.generativeai.type.GenerateContentResponse(
candidates?.map { it.toPublic() }.orEmpty(),
promptFeedback?.toPublic(),
usageMetadata?.toPublic()
)

internal fun CountTokensResponse.toPublic() =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import android.util.Log
class GenerateContentResponse(
val candidates: List<Candidate>,
val promptFeedback: PromptFeedback?,
val usageMetadata: UsageMetadata?
) {
/** Convenience field representing the first text part in the response, if it exists. */
val text: String? by lazy { firstPartAs<TextPart>()?.text }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/**
* Usage metadata about response(s).
*
* @param promptTokenCount Number of tokens in the request.
* @param candidatesTokenCount Number of tokens in the response(s).
* @param totalTokenCount Total number of tokens.
*/
class UsageMetadata(
val promptTokenCount: Int,
val candidatesTokenCount: Int,
val totalTokenCount: Int
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.google.ai.client.generativeai.common.GenerateContentRequest as Genera
import com.google.ai.client.generativeai.common.GenerateContentResponse as GenerateContentResponse_Common
import com.google.ai.client.generativeai.common.InvalidAPIKeyException as InvalidAPIKeyException_Common
import com.google.ai.client.generativeai.common.UnsupportedUserLocationException as UnsupportedUserLocationException_Common
import com.google.ai.client.generativeai.common.UsageMetadata as UsageMetadata_Common
import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Common
import com.google.ai.client.generativeai.common.server.CitationMetadata as CitationMetadata_Common
import com.google.ai.client.generativeai.common.server.CitationSources
Expand All @@ -34,6 +35,7 @@ import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.PromptFeedback
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
import com.google.ai.client.generativeai.type.UsageMetadata
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.equality.shouldBeEqualToUsingFields
Expand Down Expand Up @@ -79,7 +81,8 @@ internal class GenerativeModelTests {
)
)
)
)
),
usageMetadata = UsageMetadata_Common(promptTokenCount = 10)
)

val expectedResponse =
Expand All @@ -100,7 +103,8 @@ internal class GenerativeModelTests {
finishReason = null
)
),
PromptFeedback(null, listOf())
PromptFeedback(null, listOf()),
UsageMetadata(10, 0, 0 /* default to 0*/)
)

val response = model.generateContent("Why's the sky blue?")
Expand Down

0 comments on commit 146c4d6

Please sign in to comment.