forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Freature langchain4j#1005 - Add streaming API for Bedrock Anthropics (l…
…angchain4j#1006) ## Context Feature adds Bedrock Antrhopics Streaming capability langchain4j#1005 Previous PR langchain4j#679 ## Change Added new streaming model `AbstractBedrockStreamingChatModel` ## Checklist Before submitting this PR, please check the following points: - [x] I have added unit and integration tests for my change - [x] All unit and integration tests in the module I have added/changed are green - [x] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green - [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added) ## Checklist for adding new embedding store integration - [ ] I have added a {NameOfIntegration}EmbeddingStoreIT that extends from either EmbeddingStoreIT or EmbeddingStoreWithFilteringIT
- Loading branch information
1 parent
050e93b
commit c2a1520
Showing
7 changed files
with
283 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
...drock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { | ||
@Builder.Default | ||
private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue(); | ||
|
||
@Override | ||
protected String getModelId() { | ||
return model; | ||
} | ||
|
||
@Getter | ||
/** | ||
* Bedrock Anthropic model ids | ||
*/ | ||
public enum Types { | ||
AnthropicClaudeV2("anthropic.claude-v2"), | ||
AnthropicClaudeV2_1("anthropic.claude-v2:1"); | ||
|
||
private final String value; | ||
|
||
Types(String modelID) { | ||
this.value = modelID; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
...c/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.internal.Json; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.core.SdkBytes; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
/** | ||
* Bedrock Streaming chat model | ||
*/ | ||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { | ||
@Getter | ||
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); | ||
|
||
class StreamingResponse { | ||
public String completion; | ||
} | ||
|
||
@Override | ||
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) { | ||
List<ChatMessage> messages = new ArrayList<>(); | ||
messages.add(new UserMessage(userMessage)); | ||
generate(messages, handler); | ||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() | ||
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) | ||
.modelId(getModelId()) | ||
.contentType("application/json") | ||
.accept("application/json") | ||
.build(); | ||
|
||
StringBuffer finalCompletion = new StringBuffer(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() | ||
.onChunk(chunk -> { | ||
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); | ||
finalCompletion.append(sr.completion); | ||
handler.onNext(sr.completion); | ||
}) | ||
.build(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() | ||
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) | ||
.onComplete(() -> { | ||
handler.onComplete(Response.from(new AiMessage(finalCompletion.toString()))); | ||
}) | ||
.onError(handler::onError) | ||
.build(); | ||
asyncClient.invokeModelWithResponseStream(request, h).join(); | ||
|
||
} | ||
|
||
/** | ||
* Initialize async bedrock client | ||
* | ||
* @return async bedrock client | ||
*/ | ||
private BedrockRuntimeAsyncClient initAsyncClient() { | ||
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() | ||
.region(region) | ||
.credentialsProvider(credentialsProvider) | ||
.build(); | ||
return client; | ||
} | ||
|
||
|
||
|
||
} |
112 changes: 112 additions & 0 deletions
112
.../src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.internal.Json; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; | ||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; | ||
import software.amazon.awssdk.regions.Region; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static java.util.stream.Collectors.joining; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractSharedBedrockChatModel { | ||
// Claude requires you to enclose the prompt as follows: | ||
// String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:"; | ||
protected static final String HUMAN_PROMPT = "Human:"; | ||
protected static final String ASSISTANT_PROMPT = "Assistant:"; | ||
protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; | ||
|
||
@Builder.Default | ||
protected final String humanPrompt = HUMAN_PROMPT; | ||
@Builder.Default | ||
protected final String assistantPrompt = ASSISTANT_PROMPT; | ||
@Builder.Default | ||
protected final Integer maxRetries = 5; | ||
@Builder.Default | ||
protected final Region region = Region.US_EAST_1; | ||
@Builder.Default | ||
protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); | ||
@Builder.Default | ||
protected final int maxTokens = 300; | ||
@Builder.Default | ||
protected final double temperature = 1; | ||
@Builder.Default | ||
protected final float topP = 0.999f; | ||
@Builder.Default | ||
protected final String[] stopSequences = new String[]{}; | ||
@Builder.Default | ||
protected final int topK = 250; | ||
@Builder.Default | ||
protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; | ||
|
||
|
||
/** | ||
* Convert chat message to string | ||
* | ||
* @param message chat message | ||
* @return string | ||
*/ | ||
protected String chatMessageToString(ChatMessage message) { | ||
switch (message.type()) { | ||
case SYSTEM: | ||
return message.text(); | ||
case USER: | ||
return humanPrompt + " " + message.text(); | ||
case AI: | ||
return assistantPrompt + " " + message.text(); | ||
case TOOL_EXECUTION_RESULT: | ||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); | ||
} | ||
|
||
throw new IllegalArgumentException("Unknown message type: " + message.type()); | ||
} | ||
|
||
protected String convertMessagesToAwsBody(List<ChatMessage> messages) { | ||
final String context = messages.stream() | ||
.filter(message -> message.type() == ChatMessageType.SYSTEM) | ||
.map(ChatMessage::text) | ||
.collect(joining("\n")); | ||
|
||
final String userMessages = messages.stream() | ||
.filter(message -> message.type() != ChatMessageType.SYSTEM) | ||
.map(this::chatMessageToString) | ||
.collect(joining("\n")); | ||
|
||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); | ||
final Map<String, Object> requestParameters = getRequestParameters(prompt); | ||
final String body = Json.toJson(requestParameters); | ||
return body; | ||
} | ||
|
||
protected Map<String, Object> getRequestParameters(String prompt) { | ||
final Map<String, Object> parameters = new HashMap<>(7); | ||
|
||
parameters.put("prompt", prompt); | ||
parameters.put("max_tokens_to_sample", getMaxTokens()); | ||
parameters.put("temperature", getTemperature()); | ||
parameters.put("top_k", topK); | ||
parameters.put("top_p", getTopP()); | ||
parameters.put("stop_sequences", getStopSequences()); | ||
parameters.put("anthropic_version", anthropicVersion); | ||
|
||
return parameters; | ||
} | ||
|
||
/** | ||
* Get model id | ||
* | ||
* @return model id | ||
*/ | ||
protected abstract String getModelId(); | ||
|
||
} |
39 changes: 39 additions & 0 deletions
39
...in4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.chat.TestStreamingResponseHandler; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
import software.amazon.awssdk.regions.Region; | ||
|
||
import static dev.langchain4j.data.message.UserMessage.userMessage; | ||
import static java.util.Collections.singletonList; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
public class BedrockStreamingChatModelIT { | ||
@Test | ||
@Disabled("To run this test, you must have provide your own access key, secret, region") | ||
void testBedrockAnthropicStreamingChatModel() { | ||
//given | ||
BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel | ||
.builder() | ||
.temperature(0.5) | ||
.maxTokens(300) | ||
.region(Region.US_EAST_1) | ||
.maxRetries(1) | ||
.build(); | ||
UserMessage userMessage = userMessage("What's the capital of Poland?"); | ||
|
||
//when | ||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>(); | ||
bedrockChatModel.generate(singletonList(userMessage), handler); | ||
Response<AiMessage> response = handler.get(); | ||
|
||
//then | ||
assertThat(response.content().text()).contains("Warsaw"); | ||
} | ||
|
||
|
||
} |