Skip to content

Commit

Permalink
add semaphore and illegal state exception to chat (#21)
Browse files Browse the repository at this point in the history
Co-authored-by: David Motsonashvili <davidmotson@google.com>
Co-authored-by: Rodrigo Lazo <rlazo@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent 3f0ccb1 commit f2aedb0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
1 change: 1 addition & 0 deletions .changes/calculator-bag-baby-chair.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["An instance of Chat will now throw an InvalidStateException if multiple requests are made simultaneously."]}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
Expand All @@ -35,33 +36,41 @@ import kotlinx.coroutines.flow.onEach
* Handles the capturing and storage of the communication with the model, providing methods for
* further interaction.
*
* Note: This object is not thread-safe, and calling [sendMessage] multiple times without waiting
* for a response will throw an [InvalidStateException].
*
* @param model the model to use for the interaction
* @property history the previous interactions with the model
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)

/**
* Generates a response from the backend with the provided [Content], and any previous ones
* sent/returned from this chat.
*
* @param prompt A [Content] to send to the model.
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()

val response = model.generateContent(*history.toTypedArray(), prompt)

history.add(prompt)
history.add(response.candidates.first().content)

return response
attemptLock()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
return response
} finally {
lock.release()
}
}

/**
* Generates a response from the backend with the provided text represented [Content].
*
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: String): GenerateContentResponse {
val content = content("user") { text(prompt) }
Expand All @@ -72,6 +81,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* Generates a response from the backend with the provided image represented [Content].
*
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
val content = content("user") { image(prompt) }
Expand All @@ -84,9 +94,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
Expand All @@ -109,6 +121,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
Expand All @@ -134,6 +147,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
val content = content("user") { text(prompt) }
Expand All @@ -145,6 +159,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
val content = content("user") { image(prompt) }
Expand All @@ -156,4 +171,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
throw InvalidStateException("Chat prompts should come from the 'user' role.")
}
}

private fun attemptLock() {
if (!lock.tryAcquire()) {
throw InvalidStateException(
"This chat instance currently has an ongoing request, please wait for it to complete " +
"before sending more messages"
)
}
}
}

0 comments on commit f2aedb0

Please sign in to comment.