diff --git a/langchain4j-ollama/pom.xml b/langchain4j-ollama/pom.xml
index 82375ae9942..9b4081262b7 100644
--- a/langchain4j-ollama/pom.xml
+++ b/langchain4j-ollama/pom.xml
@@ -65,6 +65,18 @@
             <artifactId>testcontainers</artifactId>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.tinylog</groupId>
+            <artifactId>tinylog-impl</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.tinylog</groupId>
+            <artifactId>slf4j-tinylog</artifactId>
+            <scope>test</scope>
+        </dependency>
+
     </dependencies>
 </project>
\ No newline at end of file
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
index 1e853e22531..a502c50762b 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
@@ -13,11 +13,9 @@
 @Builder
 class ChatRequest {
 
-    /**
-     * model name
-     */
     private String model;
     private List<Message> messages;
     private Options options;
+    private String format;
     private Boolean stream;
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
index df1f5e067ae..4aaa94b56fa 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
@@ -10,6 +10,7 @@
 @AllArgsConstructor
 @Builder
 class ChatResponse {
+
     private String model;
     private String createdAt;
     private Message message;
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
index 44443fd1c27..583868b5ac5 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
@@ -11,15 +11,10 @@
 @Builder
 class CompletionRequest {
 
-    /**
-     * model name
-     */
     private String model;
-    /**
-     * the prompt to generate a response for
-     */
+    private String system;
     private String prompt;
     private Options options;
-    private String system;
+    private String format;
     private Boolean stream;
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
index e0e4079b087..64feb163f98 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
@@ -10,6 +10,7 @@
 @AllArgsConstructor
 @Builder
 class Message {
+
     private Role role;
     private String content;
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
index f7c0b7b3599..56d4968e4e2 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
@@ -7,10 +7,7 @@
 import retrofit2.http.POST;
 import retrofit2.http.Streaming;
 
-/**
- * Ollama API definitions using retrofit2
- */
-public interface OllamaApi {
+interface OllamaApi {
 
     @POST("/api/generate")
     @Headers({"Content-Type: application/json"})
     Call<CompletionResponse> completion(@Body CompletionRequest completionRequest);
@@ -28,4 +25,9 @@ public interface OllamaApi {
     @POST("/api/chat")
     @Headers({"Content-Type: application/json"})
     Call<ChatResponse> chat(@Body ChatRequest chatRequest);
+
+    @POST("/api/chat")
+    @Headers({"Content-Type: application/json"})
+    @Streaming
+    Call<ResponseBody> streamingChat(@Body ChatRequest chatRequest);
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
index f24666d2869..0e43ba5c5a9 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
@@ -2,64 +2,105 @@
 
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ChatMessageType;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
 import lombok.Builder;
 
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.List;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static java.time.Duration.ofSeconds;
+import static java.util.stream.Collectors.toList;
 
 /**
- * Ollama chat model implementation.
+ * Ollama API reference
+ * <br>
+ * Ollama API parameters.
  */
 public class OllamaChatModel implements ChatLanguageModel {
 
     private final OllamaClient client;
-    private final Double temperature;
     private final String modelName;
+    private final Options options;
+    private final String format;
     private final Integer maxRetries;
 
     @Builder
-    public OllamaChatModel(String baseUrl, Duration timeout,
-                           String modelName, Double temperature, Integer maxRetries) {
-        this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+    public OllamaChatModel(String baseUrl,
+                           String modelName,
+                           Double temperature,
+                           Integer topK,
+                           Double topP,
+                           Double repeatPenalty,
+                           Integer seed,
+                           Integer numPredict,
+                           List<String> stop,
+                           String format,
+                           Duration timeout,
+                           Integer maxRetries) {
+        this.client = OllamaClient.builder()
+                .baseUrl(baseUrl)
+                .timeout(getOrDefault(timeout, ofSeconds(60)))
+                .build();
         this.modelName = ensureNotBlank(modelName, "modelName");
-        this.temperature = getOrDefault(temperature, 0.7);
+        this.options = Options.builder()
+                .temperature(temperature)
+                .topK(topK)
+                .topP(topP)
+                .repeatPenalty(repeatPenalty)
+                .seed(seed)
+                .numPredict(numPredict)
+                .stop(stop)
+                .build();
+        this.format = format;
         this.maxRetries = getOrDefault(maxRetries, 3);
     }
 
     @Override
     public Response<AiMessage> generate(List<ChatMessage> messages) {
-        if (messages == null || messages.isEmpty()) {
-            throw new IllegalArgumentException("messages must not be null or empty");
-        }
-
-        ArrayList<Message> messageList = new ArrayList<>();
-
-        messages.forEach(message -> {
-            Role role = Role.fromChatMessageType(message.type());
-            messageList.add(Message.builder()
-                    .role(role)
-                    .content(message.text())
-                    .build());
-        });
+        ensureNotEmpty(messages, "messages");
 
         ChatRequest request = ChatRequest.builder()
                 .model(modelName)
-                .messages(messageList)
-                .options(Options.builder()
-                        .temperature(temperature)
-                        .build())
+                .messages(toOllamaMessages(messages))
+                .options(options)
+                .format(format)
                 .stream(false)
                 .build();
 
         ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
 
-        return Response.from(AiMessage.from(response.getMessage().getContent()));
+        return Response.from(
+                AiMessage.from(response.getMessage().getContent()),
+                new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
+        );
+    }
+
+    static List<Message> toOllamaMessages(List<ChatMessage> messages) {
+        return messages.stream()
+                .map(message -> Message.builder()
+                        .role(toOllamaRole(message.type()))
+                        .content(message.text())
+                        .build())
+                .collect(toList());
+    }
+
+    private static Role toOllamaRole(ChatMessageType chatMessageType) {
+        switch (chatMessageType) {
+            case SYSTEM:
+                return Role.SYSTEM;
+            case USER:
+                return Role.USER;
+            case AI:
+                return Role.ASSISTANT;
+            default:
+                throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
+        }
     }
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
index fa10e91cddf..fbe90aebaf4 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
@@ -2,35 +2,35 @@
 
 import com.google.gson.Gson;
 import com.google.gson.GsonBuilder;
+import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import lombok.Builder;
 import okhttp3.OkHttpClient;
 import okhttp3.ResponseBody;
 import retrofit2.Call;
 import retrofit2.Callback;
-import retrofit2.Response;
 import retrofit2.Retrofit;
 import retrofit2.converter.gson.GsonConverterFactory;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.time.Duration;
-import java.util.Optional;
 
 import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
-import static dev.langchain4j.internal.Utils.getOrDefault;
-import static java.time.Duration.ofSeconds;
+import static java.lang.Boolean.TRUE;
 
 class OllamaClient {
 
-    private final OllamaApi ollamaApi;
-    private static final Gson GSON = new GsonBuilder().setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
+    private static final Gson GSON = new GsonBuilder()
+            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
             .create();
 
+    private final OllamaApi ollamaApi;
+
     @Builder
     public OllamaClient(String baseUrl, Duration timeout) {
-        timeout = getOrDefault(timeout, ofSeconds(60));
 
         OkHttpClient okHttpClient = new OkHttpClient.Builder()
                 .callTimeout(timeout)
@@ -50,7 +50,7 @@ public OllamaClient(String baseUrl, Duration timeout) {
 
     public CompletionResponse completion(CompletionRequest request) {
         try {
-            Response<CompletionResponse> retrofitResponse
+            retrofit2.Response<CompletionResponse> retrofitResponse
                     = ollamaApi.completion(request).execute();
 
             if (retrofitResponse.isSuccessful()) {
@@ -65,7 +65,7 @@ public CompletionResponse completion(CompletionRequest request) {
 
     public ChatResponse chat(ChatRequest request) {
         try {
-            Response<ChatResponse> retrofitResponse
+            retrofit2.Response<ChatResponse> retrofitResponse
                     = ollamaApi.chat(request).execute();
 
             if (retrofitResponse.isSuccessful()) {
@@ -80,32 +80,72 @@ public ChatResponse chat(ChatRequest request) {
 
     public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
         ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {
+
             @Override
-            public void onResponse(Call<ResponseBody> call, Response<ResponseBody> response) {
-                try (InputStream inputStream = response.body().byteStream()) {
-                    StringBuilder content = new StringBuilder();
-                    int inputTokenCount = 0;
-                    int outputTokenCount = 0;
+            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+                    StringBuilder contentBuilder = new StringBuilder();
                     while (true) {
                         byte[] bytes = new byte[1024];
                         int len = inputStream.read(bytes);
                         String partialResponse = new String(bytes, 0, len);
                         CompletionResponse completionResponse = GSON.fromJson(partialResponse, CompletionResponse.class);
 
-                        // finish streaming response
-                        if (Boolean.TRUE.equals(completionResponse.getDone())) {
-                            handler.onComplete(dev.langchain4j.model.output.Response.from(
-                                    content.toString(),
-                                    new TokenUsage(inputTokenCount, outputTokenCount)
-                            ));
-                            break;
+                        contentBuilder.append(completionResponse.getResponse());
+                        handler.onNext(completionResponse.getResponse());
+
+                        if (TRUE.equals(completionResponse.getDone())) {
+                            Response<String> response = Response.from(
+                                    contentBuilder.toString(),
+                                    new TokenUsage(
+                                            completionResponse.getPromptEvalCount(),
+                                            completionResponse.getEvalCount()
+                                    )
+                            );
+                            handler.onComplete(response);
+                            return;
                         }
+                    }
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            }
 
-                        // handle cur token and tokenUsage
-                        content.append(completionResponse.getResponse());
-                        inputTokenCount += Optional.ofNullable(completionResponse.getPromptEvalCount()).orElse(0);
-                        outputTokenCount += Optional.ofNullable(completionResponse.getEvalCount()).orElse(0);
-                        handler.onNext(completionResponse.getResponse());
+            @Override
+            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
+                handler.onError(throwable);
+            }
+        });
+    }
+
+    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
+        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
+
+            @Override
+            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+                    StringBuilder contentBuilder = new StringBuilder();
+                    while (true) {
+                        byte[] bytes = new byte[1024];
+                        int len = inputStream.read(bytes);
+                        String partialResponse = new String(bytes, 0, len);
+                        ChatResponse chatResponse = GSON.fromJson(partialResponse, ChatResponse.class);
+
+                        String content = chatResponse.getMessage().getContent();
+                        contentBuilder.append(content);
+                        handler.onNext(content);
+
+                        if (TRUE.equals(chatResponse.getDone())) {
+                            Response<AiMessage> response = Response.from(
+                                    AiMessage.from(contentBuilder.toString()),
+                                    new TokenUsage(
+                                            chatResponse.getPromptEvalCount(),
+                                            chatResponse.getEvalCount()
+                                    )
+                            );
+                            handler.onComplete(response);
+                            return;
+                        }
                     }
                 } catch (IOException e) {
                     throw new RuntimeException(e);
@@ -132,7 +172,7 @@ public EmbeddingResponse embed(EmbeddingRequest request) {
         }
     }
 
-    private RuntimeException toException(Response<?> response) throws IOException {
+    private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
         int code = response.code();
         String body = response.errorBody().string();
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
index 61880a4d7eb..914b374b790 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
@@ -13,9 +13,10 @@
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static java.time.Duration.ofSeconds;
 
 /**
- * Represents an Ollama embedding model.
+ * Ollama API reference
  */
 public class OllamaEmbeddingModel implements EmbeddingModel {
@@ -24,9 +25,14 @@ public class OllamaEmbeddingModel implements EmbeddingModel {
     private final Integer maxRetries;
 
     @Builder
-    public OllamaEmbeddingModel(String baseUrl, Duration timeout,
-                                String modelName, Integer maxRetries) {
-        this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+    public OllamaEmbeddingModel(String baseUrl,
+                                String modelName,
+                                Duration timeout,
+                                Integer maxRetries) {
+        this.client = OllamaClient.builder()
+                .baseUrl(baseUrl)
+                .timeout(getOrDefault(timeout, ofSeconds(60)))
+                .build();
         this.modelName = ensureNotBlank(modelName, "modelName");
         this.maxRetries = getOrDefault(maxRetries, 3);
     }
@@ -34,6 +40,7 @@ public OllamaEmbeddingModel(String baseUrl, Duration timeout,
     @Override
     public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
         List<Embedding> embeddings = new ArrayList<>();
+
         textSegments.forEach(textSegment -> {
             EmbeddingRequest request = EmbeddingRequest.builder()
                     .model(modelName)
@@ -41,7 +48,8 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
                     .build();
 
             EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);
-            embeddings.add(new Embedding(response.getEmbedding()));
+
+            embeddings.add(Embedding.from(response.getEmbedding()));
         });
 
         return Response.from(embeddings);
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
index 8116b345869..778abcbee6d 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
@@ -6,38 +6,65 @@
 import lombok.Builder;
 
 import java.time.Duration;
+import java.util.List;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static java.time.Duration.ofSeconds;
 
 /**
- * Represents an Ollama language model with a completion interface
+ * Ollama API reference
+ * <br>
+ * Ollama API parameters.
  */
 public class OllamaLanguageModel implements LanguageModel {
 
     private final OllamaClient client;
     private final String modelName;
-    private final Double temperature;
+    private final Options options;
+    private final String format;
     private final Integer maxRetries;
 
     @Builder
-    public OllamaLanguageModel(String baseUrl, Duration timeout, String modelName,
-                               Double temperature, Integer maxRetries) {
-        this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+    public OllamaLanguageModel(String baseUrl,
+                               String modelName,
+                               Double temperature,
+                               Integer topK,
+                               Double topP,
+                               Double repeatPenalty,
+                               Integer seed,
+                               Integer numPredict,
+                               List<String> stop,
+                               String format,
+                               Duration timeout,
+                               Integer maxRetries) {
+        this.client = OllamaClient.builder()
+                .baseUrl(baseUrl)
+                .timeout(getOrDefault(timeout, ofSeconds(60)))
+                .build();
         this.modelName = ensureNotBlank(modelName, "modelName");
-        this.temperature = getOrDefault(temperature, 0.7);
+        this.options = Options.builder()
+                .temperature(temperature)
+                .topK(topK)
+                .topP(topP)
+                .repeatPenalty(repeatPenalty)
+                .seed(seed)
+                .numPredict(numPredict)
+                .stop(stop)
+                .build();
+        this.format = format;
         this.maxRetries = getOrDefault(maxRetries, 3);
     }
 
     @Override
     public Response<String> generate(String prompt) {
+
         CompletionRequest request = CompletionRequest.builder()
                 .model(modelName)
                 .prompt(prompt)
-                .options(Options.builder()
-                        .temperature(temperature)
-                        .build())
+                .options(options)
+                .format(format)
                 .stream(false)
                 .build();
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
new file mode 100644
index 00000000000..c543fe4611e
--- /dev/null
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
@@ -0,0 +1,73 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import lombok.Builder;
+
+import java.time.Duration;
+import java.util.List;
+
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.ollama.OllamaChatModel.toOllamaMessages;
+import static java.time.Duration.ofSeconds;
+
+/**
+ * Ollama API reference
+ * <br>
+ * Ollama API parameters. + */ +public class OllamaStreamingChatModel implements StreamingChatLanguageModel { + + private final OllamaClient client; + private final String modelName; + private final Options options; + private final String format; + + @Builder + public OllamaStreamingChatModel(String baseUrl, + String modelName, + Double temperature, + Integer topK, + Double topP, + Double repeatPenalty, + Integer seed, + Integer numPredict, + List stop, + String format, + Duration timeout) { + this.client = OllamaClient.builder() + .baseUrl(baseUrl) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .build(); + this.modelName = ensureNotBlank(modelName, "modelName"); + this.options = Options.builder() + .temperature(temperature) + .topK(topK) + .topP(topP) + .repeatPenalty(repeatPenalty) + .seed(seed) + .numPredict(numPredict) + .stop(stop) + .build(); + this.format = format; + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + ensureNotEmpty(messages, "messages"); + + ChatRequest request = ChatRequest.builder() + .model(modelName) + .messages(toOllamaMessages(messages)) + .options(options) + .format(format) + .stream(true) + .build(); + + client.streamingChat(request, handler); + } +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java index 17304a5a7df..9fe37fe6f08 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java @@ -5,25 +5,51 @@ import lombok.Builder; import java.time.Duration; +import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static java.time.Duration.ofSeconds; /** - * Represents an Ollama streaming language model with a completion interface + * Ollama API reference + *
+ * Ollama API parameters.
  */
 public class OllamaStreamingLanguageModel implements StreamingLanguageModel {
 
     private final OllamaClient client;
     private final String modelName;
-    private final Double temperature;
+    private final Options options;
+    private final String format;
 
     @Builder
-    public OllamaStreamingLanguageModel(String baseUrl, Duration timeout,
-                                        String modelName, Double temperature) {
-        this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+    public OllamaStreamingLanguageModel(String baseUrl,
+                                        String modelName,
+                                        Double temperature,
+                                        Integer topK,
+                                        Double topP,
+                                        Double repeatPenalty,
+                                        Integer seed,
+                                        Integer numPredict,
+                                        List<String> stop,
+                                        String format,
+                                        Duration timeout) {
+        this.client = OllamaClient.builder()
+                .baseUrl(baseUrl)
+                .timeout(getOrDefault(timeout, ofSeconds(60)))
+                .build();
         this.modelName = ensureNotBlank(modelName, "modelName");
-        this.temperature = getOrDefault(temperature, 0.7);
+        this.options = Options.builder()
+                .temperature(temperature)
+                .topK(topK)
+                .topP(topP)
+                .repeatPenalty(repeatPenalty)
+                .seed(seed)
+                .numPredict(numPredict)
+                .stop(stop)
+                .build();
+        this.format = format;
     }
 
     @Override
@@ -31,9 +57,8 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
         CompletionRequest request = CompletionRequest.builder()
                 .model(modelName)
                 .prompt(prompt)
-                .options(Options.builder()
-                        .temperature(temperature)
-                        .build())
+                .options(options)
+                .format(format)
                 .stream(true)
                 .build();
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
index eff2e4b38e0..65577719b9f 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
@@ -5,6 +5,8 @@
 import lombok.Data;
 import lombok.NoArgsConstructor;
 
+import java.util.List;
+
 /**
  * request options in completion/embedding API
  *
@@ -16,8 +18,11 @@
 @Builder
 class Options {
 
-    /**
-     * The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
-     */
     private Double temperature;
+    private Integer topK;
+    private Double topP;
+    private Double repeatPenalty;
+    private Integer seed;
+    private Integer numPredict;
+    private List<String> stop;
 }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
index 26b4e8d77a2..eca1f02bdf0 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
@@ -1,22 +1,8 @@
 package dev.langchain4j.model.ollama;
 
-import dev.langchain4j.data.message.ChatMessageType;
-
 enum Role {
+
     SYSTEM,
     USER,
     ASSISTANT;
-
-    public static Role fromChatMessageType(ChatMessageType chatMessageType) {
-        switch (chatMessageType) {
-            case SYSTEM:
-                return SYSTEM;
-            case USER:
-                return USER;
-            case AI:
-                return Role.ASSISTANT;
-            default:
-                throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
-        }
-    }
 }
\ No newline at end of file
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
index 5bcbb5437c0..3845c83da56 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
@@ -18,9 +18,9 @@ public class AbstractOllamaInfrastructure {
 
     private static final String OLLAMA_IMAGE = "ollama/ollama:latest";
-    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-orca-mini", OLLAMA_IMAGE);
+    static final String MODEL = "phi";
 
-    static final String ORCA_MINI_MODEL = "orca-mini";
+    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, MODEL);
 
     static OllamaContainer ollama;
@@ -58,18 +58,18 @@ static class OllamaContainer extends GenericContainer<OllamaContainer> {
             super(image.get());
             this.dockerImageName = image.get();
             withExposedPorts(11434);
-            withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(ORCA_MINI_MODEL));
+            withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(MODEL));
         }
 
         @Override
         protected void containerIsStarted(InspectContainerResponse containerInfo) {
             if (!this.dockerImageName.equals(DockerImageName.parse(LOCAL_OLLAMA_IMAGE))) {
                 try {
-                    log.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ...");
-                    execInContainer("ollama", "pull", ORCA_MINI_MODEL);
-                    log.info("orca-mini pulling competed!");
+                    log.info("Start pulling the '{}' model ... would take several minutes ...", MODEL);
would take several minutes ...", MODEL); + execInContainer("ollama", "pull", MODEL); + log.info("Model pulling competed!"); } catch (IOException | InterruptedException e) { - throw new RuntimeException("Error pulling orca-mini model", e); + throw new RuntimeException("Error pulling model", e); } } } diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java index ba2be46bfb1..d382aa0fe3c 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java @@ -2,36 +2,126 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; -import java.util.ArrayList; import java.util.List; -import static dev.langchain4j.data.message.SystemMessage.systemMessage; -import static dev.langchain4j.data.message.UserMessage.userMessage; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; class OllamaChatModelIT extends AbstractOllamaInfrastructure { - OllamaChatModel model = OllamaChatModel.builder() + ChatLanguageModel model = OllamaChatModel.builder() .baseUrl(getBaseUrl()) - .modelName(ORCA_MINI_MODEL) + .modelName(MODEL) + .temperature(0.0) .build(); @Test - void should_send_messages_with_roles_and_receive_response() { + void should_generate_answer() { - List chatMessages = new ArrayList<>(); - chatMessages.add(systemMessage("You are a good friend of mine, who likes to answer politely")); - chatMessages.add(userMessage("Hello!, How are you?")); - chatMessages.add(AiMessage.aiMessage("I'm fine, thanks!")); - chatMessages.add(userMessage("Not too bad, just enjoying a cup of coffee. 
What about you?")); + // given + String userMessage = "What is the capital of Germany?"; - Response response = model.generate(chatMessages); + // when + String answer = model.generate(userMessage); + System.out.println(answer); + + // then + assertThat(answer).contains("Berlin"); + } + + @Test + void should_respect_numPredict() { + + // given + int numPredict = 1; // max output tokens + + OllamaChatModel model = OllamaChatModel.builder() + .baseUrl(getBaseUrl()) + .modelName(MODEL) + .numPredict(numPredict) + .temperature(0.0) + .build(); + + UserMessage userMessage = UserMessage.from("What is the capital of Germany?"); + + // when + Response response = model.generate(userMessage); System.out.println(response); - assertThat(response).isNotNull(); - assertThat(response.content().text()).isNotEmpty(); + // then + assertThat(response.content().text()).doesNotContain("Berlin"); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama + } + + @Test + void should_respect_system_message() { + + // given + SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German"); + UserMessage userMessage = UserMessage.from("I love you"); + + // when + Response response = model.generate(systemMessage, userMessage); + System.out.println(response); + + // then + assertThat(response.content().text()).containsIgnoringCase("liebe"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(18); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isNull(); + } + + @Test + void should_respond_to_few_shot() { + + // given + List messages = asList( + UserMessage.from("1 + 1 ="), + AiMessage.from(">>> 2"), + + UserMessage.from("2 + 2 ="), + AiMessage.from(">>> 4"), + + UserMessage.from("4 + 4 =") + ); + + // when + Response response = model.generate(messages); + System.out.println(response); + + // then + assertThat(response.content().text()).isEqualTo(">>> 8"); + } + + @Test + void should_generate_valid_json() { + + // given + ChatLanguageModel model = OllamaChatModel.builder() + .baseUrl(getBaseUrl()) + .modelName(MODEL) + .format("json") + .temperature(0.0) + .build(); + + String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old."; + + // when + String json = model.generate(userMessage); + + // then + assertThat(json).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}"); } } \ No newline at end of file diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java index 5f16ec2f3b3..3f53076fa95 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java @@ -11,15 +11,23 @@ class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure { EmbeddingModel model = OllamaEmbeddingModel.builder() .baseUrl(getBaseUrl()) - .modelName(ORCA_MINI_MODEL) + .modelName(MODEL) .build(); @Test void should_embed() { - Response response = model.embed("hello world"); + // given + String text = "hello world"; + + // when + Response response = model.embed(text); System.out.println(response); - 
-        assertThat(response.content().vector()).isNotEmpty();
+        // then
+        assertThat(response.content().vector()).isNotEmpty();
+
+        assertThat(response.tokenUsage()).isNull();
+        assertThat(response.finishReason()).isNull();
     }
 }
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
index 73df2a3ceb9..040895d0fa1 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
@@ -2,6 +2,7 @@
 
 import dev.langchain4j.model.language.LanguageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -10,18 +11,73 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
 
     LanguageModel model = OllamaLanguageModel.builder()
             .baseUrl(getBaseUrl())
-            .modelName(ORCA_MINI_MODEL)
+            .modelName(MODEL)
+            .temperature(0.0)
             .build();
 
     @Test
     void should_generate_answer() {
 
-        String prompt = "Hello, how are you?";
+        // given
+        String userMessage = "What is the capital of Germany?";
 
+        // when
+        Response<String> response = model.generate(userMessage);
+        System.out.println(response);
+
+        // then
+        assertThat(response.content()).contains("Berlin");
+
+        TokenUsage tokenUsage = response.tokenUsage();
+        assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage.totalTokenCount())
+                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+        assertThat(response.finishReason()).isNull();
+    }
+
+    @Test
+    void should_respect_numPredict() {
+
+        // given
+        int numPredict = 1; // max output tokens
+
+        LanguageModel model = OllamaLanguageModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .numPredict(numPredict)
+                .temperature(0.0)
+                .build();
+
+        String prompt = "What is the capital of Germany?";
+
+        // when
         Response<String> response = model.generate(prompt);
         System.out.println(response);
 
-        assertThat(response.content()).isNotBlank();
-        assertThat(response.tokenUsage()).isNotNull();
+        // then
+        assertThat(response.content()).doesNotContain("Berlin");
+        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+    }
+
+    @Test
+    void should_generate_valid_json() {
+
+        // given
+        LanguageModel model = OllamaLanguageModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .format("json")
+                .temperature(0.0)
+                .build();
+
+        String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+        // when
+        Response<String> response = model.generate(userMessage);
+
+        // then
+        assertThat(response.content()).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
     }
 }
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java
new file mode 100644
index 00000000000..a375d319d08
--- /dev/null
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java
@@ -0,0 +1,276 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
+
+    StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
+            .baseUrl(getBaseUrl())
+            .modelName(MODEL)
+            .temperature(0.0)
+            .build();
+
+    @Test
+    void should_stream_answer() throws Exception {
+
+        // given
+        String userMessage = "What is the capital of Germany?";
+
+        // when
+        CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+        model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
+
+            private final StringBuilder answerBuilder = new StringBuilder();
+
+            @Override
+            public void onNext(String token) {
+                System.out.println("onNext: '" + token + "'");
+                answerBuilder.append(token);
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                System.out.println("onComplete: '" + response + "'");
+                futureAnswer.complete(answerBuilder.toString());
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureAnswer.completeExceptionally(error);
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        String answer = futureAnswer.get(30, SECONDS);
+        Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+        // then
+        assertThat(answer).contains("Berlin");
+        assertThat(response.content().text()).isEqualTo(answer);
+
+        TokenUsage tokenUsage = response.tokenUsage();
+        assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage.totalTokenCount())
+                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+        assertThat(response.finishReason()).isNull();
+    }
+
+    @Test
+    void should_respect_numPredict() throws Exception {
+
+        // given
+        int numPredict = 1; // max output tokens
+
+        StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .numPredict(numPredict)
+                .temperature(0.0)
+                .build();
+
+        UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
+
+        // when
+        CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+        model.generate(singletonList(userMessage), new StreamingResponseHandler<AiMessage>() {
+
+            private final StringBuilder answerBuilder = new StringBuilder();
+
+            @Override
+            public void onNext(String token) {
+                System.out.println("onNext: '" + token + "'");
+                answerBuilder.append(token);
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                System.out.println("onComplete: '" + response + "'");
+                futureAnswer.complete(answerBuilder.toString());
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureAnswer.completeExceptionally(error);
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        String answer = futureAnswer.get(30, SECONDS);
+        Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+        // then
assertThat(answer).doesNotContain("Berlin"); + assertThat(response.content().text()).isEqualTo(answer); + + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama + } + + + @Test + void should_respect_system_message() throws Exception { + + // given + SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German"); + UserMessage userMessage = UserMessage.from("I love you"); + + // when + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + model.generate(asList(systemMessage, userMessage), new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String answer = futureAnswer.get(30, SECONDS); + Response response = futureResponse.get(30, SECONDS); + + // then + assertThat(answer).containsIgnoringCase("liebe"); + assertThat(response.content().text()).isEqualTo(answer); + + assertThat(response.finishReason()).isNull(); + } + + @Test + void should_respond_to_few_shot() throws Exception { + + // given + List messages = asList( + UserMessage.from("1 + 1 ="), + AiMessage.from(">>> 2"), + + UserMessage.from("2 + 2 ="), + AiMessage.from(">>> 4"), + + UserMessage.from("4 + 4 =") + ); + + // when + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + model.generate(messages, new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String answer = futureAnswer.get(30, SECONDS); + Response response = futureResponse.get(30, SECONDS); + + // then + assertThat(answer).isEqualTo(">>> 8"); + assertThat(response.content().text()).isEqualTo(answer); + } + + @Test + void should_generate_valid_json() throws Exception { + + // given + StreamingChatLanguageModel model = OllamaStreamingChatModel.builder() + .baseUrl(getBaseUrl()) + .modelName(MODEL) + .format("json") + .temperature(0.0) + .build(); + + String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old."; + + // when + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + model.generate(userMessage, new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + } + + @Override + public 
+            public void onComplete(Response<AiMessage> response) {
+                System.out.println("onComplete: '" + response + "'");
+                futureAnswer.complete(answerBuilder.toString());
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureAnswer.completeExceptionally(error);
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        String answer = futureAnswer.get(30, SECONDS);
+        Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+        // then
+        assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
+        assertThat(response.content().text()).isEqualTo(answer);
+    }
+}
\ No newline at end of file
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
index afef80ba61c..b1cbb997066 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
@@ -3,29 +3,33 @@
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.language.StreamingLanguageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeoutException;
 
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
 
 class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
 
-    StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
-            .baseUrl(getBaseUrl())
-            .modelName(ORCA_MINI_MODEL)
-            .build();
-
     @Test
-    void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {
+    void should_stream_answer() throws Exception {
+
+        // given
+        String prompt = "What is the capital of Germany?";
+
+        StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .temperature(0.0)
+                .build();
 
+        // when
         CompletableFuture<String> futureAnswer = new CompletableFuture<>();
         CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
 
-        model.generate("What is the capital of Germany?", new StreamingResponseHandler<String>() {
+        model.generate(prompt, new StreamingResponseHandler<String>() {
 
             private final StringBuilder answerBuilder = new StringBuilder();
 
@@ -52,7 +56,118 @@ public void onError(Throwable error) {
         String answer = futureAnswer.get(30, SECONDS);
         Response<String> response = futureResponse.get(30, SECONDS);
 
+        // then
         assertThat(answer).contains("Berlin");
         assertThat(response.content()).isEqualTo(answer);
+
+        TokenUsage tokenUsage = response.tokenUsage();
+        assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage.totalTokenCount())
+                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+        assertThat(response.finishReason()).isNull();
+    }
+
+    @Test
+    void should_respect_numPredict() throws Exception {
+
+        // given
+        int numPredict = 1; // max output tokens
+
+        StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .numPredict(numPredict)
+                .temperature(0.0)
+                .build();
+
+        String prompt = "What is the capital of Germany?";
+
+        // when
+        CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+        CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
+
+        model.generate(prompt, new StreamingResponseHandler<String>() {
+
+            private final StringBuilder answerBuilder = new StringBuilder();
+
+            @Override
+            public void onNext(String token) {
+                System.out.println("onNext: '" + token + "'");
+                answerBuilder.append(token);
+            }
+
+            @Override
+            public void onComplete(Response<String> response) {
+                System.out.println("onComplete: '" + response + "'");
+                futureAnswer.complete(answerBuilder.toString());
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureAnswer.completeExceptionally(error);
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        String answer = futureAnswer.get(30, SECONDS);
+        Response<String> response = futureResponse.get(30, SECONDS);
+
+        // then
+        assertThat(answer).doesNotContain("Berlin");
+        assertThat(response.content()).isEqualTo(answer);
+
+        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+    }
+
+    @Test
+    void should_stream_valid_json() throws Exception {
+
+        // given
+        StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+                .baseUrl(getBaseUrl())
+                .modelName(MODEL)
+                .format("json")
+                .temperature(0.0)
+                .build();
+
+        String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+        // when
+        CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+        CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
+
+        model.generate(prompt, new StreamingResponseHandler<String>() {
+
+            private final StringBuilder answerBuilder = new StringBuilder();
+
+            @Override
+            public void onNext(String token) {
+                System.out.println("onNext: '" + token + "'");
+                answerBuilder.append(token);
+            }
+
+            @Override
+            public void onComplete(Response<String> response) {
+                System.out.println("onComplete: '" + response + "'");
+                futureAnswer.complete(answerBuilder.toString());
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureAnswer.completeExceptionally(error);
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        String answer = futureAnswer.get(30, SECONDS);
+        Response<String> response = futureResponse.get(30, SECONDS);
+
+        // then
+        assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
+        assertThat(response.content()).isEqualTo(answer);
     }
 }
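
Usage sketch for the new streaming chat model (illustration only, not part of the patch). The class and method names come straight from the code added above; the `OllamaStreamingChatExample` wrapper, the base URL, and the model choice are assumptions for the example, with `http://localhost:11434` being the default Ollama port exposed by the test container and `phi` the model the integration tests pull:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

public class OllamaStreamingChatExample {

    public static void main(String[] args) {
        StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
                .baseUrl("http://localhost:11434") // assumed local Ollama endpoint
                .modelName("phi")                  // model used by the integration tests above
                .temperature(0.0)
                .build();

        model.generate("What is the capital of Germany?", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                // called once per streamed chunk, fed by OllamaClient.streamingChat()
                System.out.print(token);
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                // token usage is built from promptEvalCount/evalCount of the final chunk
                System.out.println("\n" + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```

The same handler shape works for `OllamaStreamingLanguageModel`, with `Response<String>` in place of `Response<AiMessage>`, as the two integration tests above demonstrate.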