Skip to content

Commit

Permalink
Fix langchain4j#757: Gemini: allow SystemMessage(s), merge them into …
Browse files Browse the repository at this point in the history
…the first UserMessage, warn in logs (langchain4j#812)

## Context
See langchain4j#757

## Change
All `SystemMessage`s from the input are now merged together into the
first `UserMessage`.
A warning about this is logged (once).

## Checklist
Before submitting this PR, please check the following points:
- [X] I have added unit and integration tests for my change
- [X] All unit and integration tests in the module I have added/changed
are green
- [X] All unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules are green
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
(only when a new module is added)
  • Loading branch information
langchain4j authored Mar 25, 2024
1 parent b3c0dad commit 016f0b6
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 128 deletions.
6 changes: 6 additions & 0 deletions docs/docs/integrations/language-models/google-gemini.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ Caused by: io.grpc.StatusRuntimeException:
`projects/{YOUR_PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-ultra`
```

## Warning

Please note that Gemini does not support `SystemMessage`s.
If there are `SystemMessage`s provided to the `generate()` methods, they will be merged into the first
`UserMessage` (before the content of the `UserMessage`).

## Apply for early access

[Early access for Gemma](https://docs.google.com/forms/d/e/1FAIpQLSe0grG6mRFW6dNF3Rb1h_YvKqUp2GaXiglZBgA2Os5iTLWlcg/viewform)
Expand Down
11 changes: 11 additions & 0 deletions langchain4j-vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<dependencyManagement>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,75 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.Content;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.*;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.stream.Collectors.toList;

@Slf4j
class ContentsMapper {

static List<Content> map(List<ChatMessage> messages) {
private static volatile boolean warned = false;

static List<com.google.cloud.vertexai.api.Content> map(List<ChatMessage> messages) {

List<SystemMessage> systemMessages = messages.stream()
.filter(message -> message instanceof SystemMessage)
.map(message -> (SystemMessage) message)
.collect(toList());

if (!systemMessages.isEmpty()) {
if (!warned) {
log.warn("Gemini does not support SystemMessage(s). " +
"All SystemMessage(s) will be merged into the first UserMessage.");
warned = true;
}
messages = mergeSystemMessagesIntoUserMessage(messages, systemMessages);
}

// TODO what if only a single system message?

return messages.stream()
.peek(message -> {
if (message instanceof SystemMessage) {
throw new IllegalArgumentException("SystemMessage is currently not supported by Gemini");
}
})
.map(message -> Content.newBuilder()
.map(message -> com.google.cloud.vertexai.api.Content.newBuilder()
.setRole(RoleMapper.map(message.type()))
.addAllParts(PartsMapper.map(message))
.build())
.collect(toList());
}

/**
 * Removes every {@link SystemMessage} from the conversation and prepends their texts,
 * as {@link TextContent}s, to the contents of the first {@link UserMessage}.
 * <p>
 * The name of that first {@link UserMessage} is kept when it is set. Every other message
 * is passed through unchanged and in its original order. If the conversation contains no
 * {@link UserMessage}, the system messages are simply removed and nothing is injected.
 *
 * @param messages       the full conversation, possibly containing system messages
 * @param systemMessages the system messages whose texts are injected, in this order
 * @return a new list without system messages, with the first user message enriched
 */
private static List<ChatMessage> mergeSystemMessagesIntoUserMessage(List<ChatMessage> messages,
                                                                    List<SystemMessage> systemMessages) {
    List<ChatMessage> merged = new ArrayList<>();
    boolean alreadyInjected = false;

    for (ChatMessage current : messages) {
        // System messages never reach the output; their text lives on in the user message.
        if (current instanceof SystemMessage) {
            continue;
        }

        // Only the very first UserMessage is rewritten; everything else is copied as-is.
        if (alreadyInjected || !(current instanceof UserMessage)) {
            merged.add(current);
            continue;
        }

        UserMessage firstUserMessage = (UserMessage) current;

        // System texts go first, followed by the user's own contents.
        List<Content> combined = new ArrayList<>();
        for (SystemMessage systemMessage : systemMessages) {
            combined.add(TextContent.from(systemMessage.text()));
        }
        combined.addAll(firstUserMessage.contents());

        alreadyInjected = true;

        merged.add(firstUserMessage.name() == null
                ? UserMessage.from(combined)
                : UserMessage.from(firstUserMessage.name(), combined));
    }

    return merged;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServices;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.stream.Stream;

import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;

class VertexAiGeminiChatModelIT {
Expand Down Expand Up @@ -65,17 +69,65 @@ void should_generate_response() {
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
}

@Test
void should_deny_system_message() {
@ParameterizedTest
@MethodSource
void should_merge_system_messages_into_user_message(List<ChatMessage> messages) {

// given
SystemMessage systemMessage = SystemMessage.from("Be polite");
UserMessage userMessage = UserMessage.from("Tell me a joke");
// when
Response<AiMessage> response = model.generate(messages);

// when-then
assertThatThrownBy(() -> model.generate(systemMessage, userMessage))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("SystemMessage is currently not supported by Gemini");
// then
assertThat(response.content().text()).containsIgnoringCase("liebe");
}

/**
 * Supplies the message lists for the parameterized test of the same name.
 * Each case mixes system messages with user (and AI) messages in a different arrangement.
 */
static Stream<Arguments> should_merge_system_messages_into_user_message() {
    return Stream.of(
            // system message before the user message
            Arguments.of(asList(
                    SystemMessage.from("Translate in German"),
                    UserMessage.from("I love you")
            )),
            // system message after the user message
            Arguments.of(asList(
                    UserMessage.from("I love you"),
                    SystemMessage.from("Translate in German")
            )),
            // two system messages surrounding the user message
            Arguments.of(asList(
                    SystemMessage.from("Translate in Italian"),
                    UserMessage.from("I love you"),
                    SystemMessage.from("No, translate in German!")
            )),
            // user message with several text contents
            Arguments.of(asList(
                    SystemMessage.from("Translate in German"),
                    UserMessage.from(asList(
                            TextContent.from("I love you"),
                            TextContent.from("I see you")
                    ))
            )),
            // same contents in the opposite order
            Arguments.of(asList(
                    SystemMessage.from("Translate in German"),
                    UserMessage.from(asList(
                            TextContent.from("I see you"),
                            TextContent.from("I love you")
                    ))
            )),
            // multi-turn conversation with a second user message
            Arguments.of(asList(
                    SystemMessage.from("Translate in German"),
                    UserMessage.from("I see you"),
                    AiMessage.from("Ich sehe dich"),
                    UserMessage.from("I love you")
            ))
    );
}

@Test
Expand Down Expand Up @@ -104,7 +156,7 @@ void should_respect_maxOutputTokens() {
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

assertThat(response.finishReason()).isEqualTo(LENGTH);
assertThat(response.finishReason()).isEqualTo(STOP);
}

@Test
Expand Down
Loading

0 comments on commit 016f0b6

Please sign in to comment.