forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix langchain4j#757: Gemini: allow SystemMessage(s), merge them into …
…the first UserMessage, warn in logs (langchain4j#812) ## Context See langchain4j#757 ## Change All `SystemMessage`s from the input are now merged together into the first `UserMessage`. Warning abut this is given (once) in the log. ## Checklist Before submitting this PR, please check the following points: - [X] I have added unit and integration tests for my change - [X] All unit and integration tests in the module I have added/changed are green - [X] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green - [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added)
- Loading branch information
1 parent
b3c0dad
commit 016f0b6
Showing
5 changed files
with
228 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
69 changes: 59 additions & 10 deletions
69
...chain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,75 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.api.Content; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.SystemMessage; | ||
import dev.langchain4j.data.message.*; | ||
import lombok.extern.slf4j.Slf4j; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
|
||
import static java.util.stream.Collectors.toList; | ||
|
||
@Slf4j | ||
class ContentsMapper { | ||
|
||
static List<Content> map(List<ChatMessage> messages) { | ||
private static volatile boolean warned = false; | ||
|
||
static List<com.google.cloud.vertexai.api.Content> map(List<ChatMessage> messages) { | ||
|
||
List<SystemMessage> systemMessages = messages.stream() | ||
.filter(message -> message instanceof SystemMessage) | ||
.map(message -> (SystemMessage) message) | ||
.collect(toList()); | ||
|
||
if (!systemMessages.isEmpty()) { | ||
if (!warned) { | ||
log.warn("Gemini does not support SystemMessage(s). " + | ||
"All SystemMessage(s) will be merged into the first UserMessage."); | ||
warned = true; | ||
} | ||
messages = mergeSystemMessagesIntoUserMessage(messages, systemMessages); | ||
} | ||
|
||
// TODO what if only a single system message? | ||
|
||
return messages.stream() | ||
.peek(message -> { | ||
if (message instanceof SystemMessage) { | ||
throw new IllegalArgumentException("SystemMessage is currently not supported by Gemini"); | ||
} | ||
}) | ||
.map(message -> Content.newBuilder() | ||
.map(message -> com.google.cloud.vertexai.api.Content.newBuilder() | ||
.setRole(RoleMapper.map(message.type())) | ||
.addAllParts(PartsMapper.map(message)) | ||
.build()) | ||
.collect(toList()); | ||
} | ||
|
||
private static List<ChatMessage> mergeSystemMessagesIntoUserMessage(List<ChatMessage> messages, | ||
List<SystemMessage> systemMessages) { | ||
AtomicBoolean injected = new AtomicBoolean(false); | ||
return messages.stream() | ||
.filter(message -> !(message instanceof SystemMessage)) | ||
.map(message -> { | ||
if (injected.get()) { | ||
return message; | ||
} | ||
|
||
if (message instanceof UserMessage) { | ||
UserMessage userMessage = (UserMessage) message; | ||
|
||
List<Content> allContents = new ArrayList<>(); | ||
allContents.addAll(systemMessages.stream() | ||
.map(systemMessage -> TextContent.from(systemMessage.text())) | ||
.collect(toList())); | ||
allContents.addAll(userMessage.contents()); | ||
|
||
injected.set(true); | ||
|
||
if (userMessage.name() != null) { | ||
return UserMessage.from(userMessage.name(), allContents); | ||
} else { | ||
return UserMessage.from(allContents); | ||
} | ||
} | ||
|
||
return message; | ||
}) | ||
.collect(toList()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.