diff --git a/langchain4j-ollama/pom.xml b/langchain4j-ollama/pom.xml
index 82375ae9942..9b4081262b7 100644
--- a/langchain4j-ollama/pom.xml
+++ b/langchain4j-ollama/pom.xml
@@ -65,6 +65,18 @@
            <artifactId>testcontainers</artifactId>
            <scope>test</scope>
+
+        <dependency>
+            <groupId>org.tinylog</groupId>
+            <artifactId>tinylog-impl</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.tinylog</groupId>
+            <artifactId>slf4j-tinylog</artifactId>
+            <scope>test</scope>
+        </dependency>
+
\ No newline at end of file
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
index 1e853e22531..a502c50762b 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java
@@ -13,11 +13,9 @@
@Builder
class ChatRequest {
- /**
- * model name
- */
private String model;
private List<Message> messages;
private Options options;
+ private String format;
private Boolean stream;
}
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
index df1f5e067ae..4aaa94b56fa 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatResponse.java
@@ -10,6 +10,7 @@
@AllArgsConstructor
@Builder
class ChatResponse {
+
private String model;
private String createdAt;
private Message message;
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
index 44443fd1c27..583868b5ac5 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/CompletionRequest.java
@@ -11,15 +11,10 @@
@Builder
class CompletionRequest {
- /**
- * model name
- */
private String model;
- /**
- * the prompt to generate a response for
- */
+ private String system;
private String prompt;
private Options options;
- private String system;
+ private String format;
private Boolean stream;
}
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
index e0e4079b087..64feb163f98 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java
@@ -10,6 +10,7 @@
@AllArgsConstructor
@Builder
class Message {
+
private Role role;
private String content;
}
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
index f7c0b7b3599..56d4968e4e2 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaApi.java
@@ -7,10 +7,7 @@
import retrofit2.http.POST;
import retrofit2.http.Streaming;
-/**
- * Ollama API definitions using retrofit2
- */
-public interface OllamaApi {
+interface OllamaApi {
@POST("/api/generate")
@Headers({"Content-Type: application/json"})
@@ -28,4 +25,9 @@ public interface OllamaApi {
@POST("/api/chat")
@Headers({"Content-Type: application/json"})
Call<ChatResponse> chat(@Body ChatRequest chatRequest);
+
+ @POST("/api/chat")
+ @Headers({"Content-Type: application/json"})
+ @Streaming
+ Call<ResponseBody> streamingChat(@Body ChatRequest chatRequest);
}
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
index f24666d2869..0e43ba5c5a9 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
@@ -2,64 +2,105 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import java.time.Duration;
-import java.util.ArrayList;
import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static java.time.Duration.ofSeconds;
+import static java.util.stream.Collectors.toList;
/**
- * Ollama chat model implementation.
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
+ * <br>
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama API parameters</a>.
*/
public class OllamaChatModel implements ChatLanguageModel {
private final OllamaClient client;
- private final Double temperature;
private final String modelName;
+ private final Options options;
+ private final String format;
private final Integer maxRetries;
@Builder
- public OllamaChatModel(String baseUrl, Duration timeout,
- String modelName, Double temperature, Integer maxRetries) {
- this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+ public OllamaChatModel(String baseUrl,
+ String modelName,
+ Double temperature,
+ Integer topK,
+ Double topP,
+ Double repeatPenalty,
+ Integer seed,
+ Integer numPredict,
+ List<String> stop,
+ String format,
+ Duration timeout,
+ Integer maxRetries) {
+ this.client = OllamaClient.builder()
+ .baseUrl(baseUrl)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .build();
this.modelName = ensureNotBlank(modelName, "modelName");
- this.temperature = getOrDefault(temperature, 0.7);
+ this.options = Options.builder()
+ .temperature(temperature)
+ .topK(topK)
+ .topP(topP)
+ .repeatPenalty(repeatPenalty)
+ .seed(seed)
+ .numPredict(numPredict)
+ .stop(stop)
+ .build();
+ this.format = format;
this.maxRetries = getOrDefault(maxRetries, 3);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
- if (messages == null || messages.isEmpty()) {
- throw new IllegalArgumentException("messages must not be null or empty");
- }
-
- ArrayList<Message> messageList = new ArrayList<>();
-
- messages.forEach(message -> {
- Role role = Role.fromChatMessageType(message.type());
- messageList.add(Message.builder()
- .role(role)
- .content(message.text())
- .build());
- });
+ ensureNotEmpty(messages, "messages");
ChatRequest request = ChatRequest.builder()
.model(modelName)
- .messages(messageList)
- .options(Options.builder()
- .temperature(temperature)
- .build())
+ .messages(toOllamaMessages(messages))
+ .options(options)
+ .format(format)
.stream(false)
.build();
ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
- return Response.from(AiMessage.from(response.getMessage().getContent()));
+ return Response.from(
+ AiMessage.from(response.getMessage().getContent()),
+ new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
+ );
+ }
+
+ static List<Message> toOllamaMessages(List<ChatMessage> messages) {
+ return messages.stream()
+ .map(message -> Message.builder()
+ .role(toOllamaRole(message.type()))
+ .content(message.text())
+ .build())
+ .collect(toList());
+ }
+
+ private static Role toOllamaRole(ChatMessageType chatMessageType) {
+ switch (chatMessageType) {
+ case SYSTEM:
+ return Role.SYSTEM;
+ case USER:
+ return Role.USER;
+ case AI:
+ return Role.ASSISTANT;
+ default:
+ throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
+ }
}
}
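For context, a minimal usage sketch of the expanded OllamaChatModel builder. The base URL, model name, and parameter values below are illustrative assumptions, not values taken from this patch:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.output.Response;

import java.time.Duration;

public class OllamaChatModelUsageSketch {

    public static void main(String[] args) {
        // All values below are illustrative assumptions
        OllamaChatModel model = OllamaChatModel.builder()
                .baseUrl("http://localhost:11434") // default local Ollama endpoint
                .modelName("phi")
                .temperature(0.0)
                .topK(40)
                .topP(0.9)
                .repeatPenalty(1.1)
                .seed(42)
                .numPredict(128)                   // cap on generated tokens
                .format("json")                    // ask Ollama to emit valid JSON
                .timeout(Duration.ofSeconds(60))
                .maxRetries(3)
                .build();

        Response<AiMessage> response = model.generate(
                UserMessage.from("Return JSON with two fields: name and age of John Doe, 42 years old."));
        System.out.println(response.content().text());
        System.out.println(response.tokenUsage()); // populated from promptEvalCount/evalCount
    }
}
```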
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
index fa10e91cddf..fbe90aebaf4 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
@@ -2,35 +2,35 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
+import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Callback;
-import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
-import java.util.Optional;
import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
-import static dev.langchain4j.internal.Utils.getOrDefault;
-import static java.time.Duration.ofSeconds;
+import static java.lang.Boolean.TRUE;
class OllamaClient {
- private final OllamaApi ollamaApi;
- private static final Gson GSON = new GsonBuilder().setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
+ private static final Gson GSON = new GsonBuilder()
+ .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.create();
+ private final OllamaApi ollamaApi;
+
@Builder
public OllamaClient(String baseUrl, Duration timeout) {
- timeout = getOrDefault(timeout, ofSeconds(60));
OkHttpClient okHttpClient = new OkHttpClient.Builder()
.callTimeout(timeout)
@@ -50,7 +50,7 @@ public OllamaClient(String baseUrl, Duration timeout) {
public CompletionResponse completion(CompletionRequest request) {
try {
- Response<CompletionResponse> retrofitResponse
+ retrofit2.Response<CompletionResponse> retrofitResponse
= ollamaApi.completion(request).execute();
if (retrofitResponse.isSuccessful()) {
@@ -65,7 +65,7 @@ public CompletionResponse completion(CompletionRequest request) {
public ChatResponse chat(ChatRequest request) {
try {
- Response<ChatResponse> retrofitResponse
+ retrofit2.Response<ChatResponse> retrofitResponse
= ollamaApi.chat(request).execute();
if (retrofitResponse.isSuccessful()) {
@@ -80,32 +80,72 @@ public ChatResponse chat(ChatRequest request) {
public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {
+
@Override
- public void onResponse(Call<ResponseBody> call, Response<ResponseBody> response) {
- try (InputStream inputStream = response.body().byteStream()) {
- StringBuilder content = new StringBuilder();
- int inputTokenCount = 0;
- int outputTokenCount = 0;
+ public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+ try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+ StringBuilder contentBuilder = new StringBuilder();
while (true) {
byte[] bytes = new byte[1024];
int len = inputStream.read(bytes);
String partialResponse = new String(bytes, 0, len);
CompletionResponse completionResponse = GSON.fromJson(partialResponse, CompletionResponse.class);
- // finish streaming response
- if (Boolean.TRUE.equals(completionResponse.getDone())) {
- handler.onComplete(dev.langchain4j.model.output.Response.from(
- content.toString(),
- new TokenUsage(inputTokenCount, outputTokenCount)
- ));
- break;
+ contentBuilder.append(completionResponse.getResponse());
+ handler.onNext(completionResponse.getResponse());
+
+ if (TRUE.equals(completionResponse.getDone())) {
+ Response<String> response = Response.from(
+ contentBuilder.toString(),
+ new TokenUsage(
+ completionResponse.getPromptEvalCount(),
+ completionResponse.getEvalCount()
+ )
+ );
+ handler.onComplete(response);
+ return;
}
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
- // handle cur token and tokenUsage
- content.append(completionResponse.getResponse());
- inputTokenCount += Optional.ofNullable(completionResponse.getPromptEvalCount()).orElse(0);
- outputTokenCount += Optional.ofNullable(completionResponse.getEvalCount()).orElse(0);
- handler.onNext(completionResponse.getResponse());
+ @Override
+ public void onFailure(Call<ResponseBody> call, Throwable throwable) {
+ handler.onError(throwable);
+ }
+ });
+ }
+
+ public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
+ ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
+
+ @Override
+ public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+ try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+ StringBuilder contentBuilder = new StringBuilder();
+ while (true) {
+ byte[] bytes = new byte[1024];
+ int len = inputStream.read(bytes);
+ String partialResponse = new String(bytes, 0, len);
+ ChatResponse chatResponse = GSON.fromJson(partialResponse, ChatResponse.class);
+
+ String content = chatResponse.getMessage().getContent();
+ contentBuilder.append(content);
+ handler.onNext(content);
+
+ if (TRUE.equals(chatResponse.getDone())) {
+ Response<AiMessage> response = Response.from(
+ AiMessage.from(contentBuilder.toString()),
+ new TokenUsage(
+ chatResponse.getPromptEvalCount(),
+ chatResponse.getEvalCount()
+ )
+ );
+ handler.onComplete(response);
+ return;
+ }
}
} catch (IOException e) {
throw new RuntimeException(e);
@@ -132,7 +172,7 @@ public EmbeddingResponse embed(EmbeddingRequest request) {
}
}
- private RuntimeException toException(Response<?> response) throws IOException {
+ private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
int code = response.code();
String body = response.errorBody().string();
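Note that both streaming callbacks above parse each fixed-size 1024-byte chunk as one JSON document, which holds only while every read() returns exactly one complete object. Since Ollama streams newline-delimited JSON, a line-oriented reader is a more robust alternative. A hedged sketch of that idea (a hypothetical helper, not part of this patch; it sits in the same package because CompletionResponse is package-private, and omits the onComplete/token-usage wiring for brevity):

```java
package dev.langchain4j.model.ollama;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.model.StreamingResponseHandler;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;

class NdjsonStreamingSketch {

    private static final Gson GSON = new GsonBuilder()
            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
            .create();

    // Reads one JSON object per line instead of assuming each fixed-size
    // read() holds exactly one complete JSON document.
    static void streamCompletion(InputStream inputStream,
                                 StreamingResponseHandler<String> handler) throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                CompletionResponse completionResponse = GSON.fromJson(line, CompletionResponse.class);
                handler.onNext(completionResponse.getResponse());
                if (Boolean.TRUE.equals(completionResponse.getDone())) {
                    return; // final object carries done=true and the eval counts
                }
            }
        }
    }
}
```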
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
index 61880a4d7eb..914b374b790 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java
@@ -13,9 +13,10 @@
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static java.time.Duration.ofSeconds;
/**
- * Represents an Ollama embedding model.
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
*/
public class OllamaEmbeddingModel implements EmbeddingModel {
@@ -24,9 +25,14 @@ public class OllamaEmbeddingModel implements EmbeddingModel {
private final Integer maxRetries;
@Builder
- public OllamaEmbeddingModel(String baseUrl, Duration timeout,
- String modelName, Integer maxRetries) {
- this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+ public OllamaEmbeddingModel(String baseUrl,
+ String modelName,
+ Duration timeout,
+ Integer maxRetries) {
+ this.client = OllamaClient.builder()
+ .baseUrl(baseUrl)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .build();
this.modelName = ensureNotBlank(modelName, "modelName");
this.maxRetries = getOrDefault(maxRetries, 3);
}
@@ -34,6 +40,7 @@ public OllamaEmbeddingModel(String baseUrl, Duration timeout,
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<Embedding> embeddings = new ArrayList<>();
+
textSegments.forEach(textSegment -> {
EmbeddingRequest request = EmbeddingRequest.builder()
.model(modelName)
@@ -41,7 +48,8 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
.build();
EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);
- embeddings.add(new Embedding(response.getEmbedding()));
+
+ embeddings.add(Embedding.from(response.getEmbedding()));
});
return Response.from(embeddings);
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
index 8116b345869..778abcbee6d 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java
@@ -6,38 +6,65 @@
import lombok.Builder;
import java.time.Duration;
+import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static java.time.Duration.ofSeconds;
/**
- * Represents an Ollama language model with a completion interface
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
+ * <br>
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama API parameters</a>.
*/
public class OllamaLanguageModel implements LanguageModel {
private final OllamaClient client;
private final String modelName;
- private final Double temperature;
+ private final Options options;
+ private final String format;
private final Integer maxRetries;
@Builder
- public OllamaLanguageModel(String baseUrl, Duration timeout, String modelName,
- Double temperature, Integer maxRetries) {
- this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+ public OllamaLanguageModel(String baseUrl,
+ String modelName,
+ Double temperature,
+ Integer topK,
+ Double topP,
+ Double repeatPenalty,
+ Integer seed,
+ Integer numPredict,
+ List<String> stop,
+ String format,
+ Duration timeout,
+ Integer maxRetries) {
+ this.client = OllamaClient.builder()
+ .baseUrl(baseUrl)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .build();
this.modelName = ensureNotBlank(modelName, "modelName");
- this.temperature = getOrDefault(temperature, 0.7);
+ this.options = Options.builder()
+ .temperature(temperature)
+ .topK(topK)
+ .topP(topP)
+ .repeatPenalty(repeatPenalty)
+ .seed(seed)
+ .numPredict(numPredict)
+ .stop(stop)
+ .build();
+ this.format = format;
this.maxRetries = getOrDefault(maxRetries, 3);
}
@Override
public Response<String> generate(String prompt) {
+
CompletionRequest request = CompletionRequest.builder()
.model(modelName)
.prompt(prompt)
- .options(Options.builder()
- .temperature(temperature)
- .build())
+ .options(options)
+ .format(format)
.stream(false)
.build();
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
new file mode 100644
index 00000000000..c543fe4611e
--- /dev/null
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
@@ -0,0 +1,73 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import lombok.Builder;
+
+import java.time.Duration;
+import java.util.List;
+
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.ollama.OllamaChatModel.toOllamaMessages;
+import static java.time.Duration.ofSeconds;
+
+/**
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
+ * <br>
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama API parameters</a>.
+ */
+public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
+
+ private final OllamaClient client;
+ private final String modelName;
+ private final Options options;
+ private final String format;
+
+ @Builder
+ public OllamaStreamingChatModel(String baseUrl,
+ String modelName,
+ Double temperature,
+ Integer topK,
+ Double topP,
+ Double repeatPenalty,
+ Integer seed,
+ Integer numPredict,
+ List<String> stop,
+ String format,
+ Duration timeout) {
+ this.client = OllamaClient.builder()
+ .baseUrl(baseUrl)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .build();
+ this.modelName = ensureNotBlank(modelName, "modelName");
+ this.options = Options.builder()
+ .temperature(temperature)
+ .topK(topK)
+ .topP(topP)
+ .repeatPenalty(repeatPenalty)
+ .seed(seed)
+ .numPredict(numPredict)
+ .stop(stop)
+ .build();
+ this.format = format;
+ }
+
+ @Override
+ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
+ ensureNotEmpty(messages, "messages");
+
+ ChatRequest request = ChatRequest.builder()
+ .model(modelName)
+ .messages(toOllamaMessages(messages))
+ .options(options)
+ .format(format)
+ .stream(true)
+ .build();
+
+ client.streamingChat(request, handler);
+ }
+}
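A minimal streaming usage sketch for the new model class; the base URL and model name are illustrative assumptions:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

public class OllamaStreamingUsageSketch {

    public static void main(String[] args) {
        StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
                .baseUrl("http://localhost:11434") // assumed local Ollama instance
                .modelName("phi")                  // assumed model
                .temperature(0.0)
                .build();

        model.generate("Why is the sky blue?", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive incrementally
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\ntoken usage: " + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```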
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java
index 17304a5a7df..9fe37fe6f08 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java
@@ -5,25 +5,51 @@
import lombok.Builder;
import java.time.Duration;
+import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static java.time.Duration.ofSeconds;
/**
- * Represents an Ollama streaming language model with a completion interface
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
+ * <br>
+ * <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama API parameters</a>.
*/
public class OllamaStreamingLanguageModel implements StreamingLanguageModel {
private final OllamaClient client;
private final String modelName;
- private final Double temperature;
+ private final Options options;
+ private final String format;
@Builder
- public OllamaStreamingLanguageModel(String baseUrl, Duration timeout,
- String modelName, Double temperature) {
- this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
+ public OllamaStreamingLanguageModel(String baseUrl,
+ String modelName,
+ Double temperature,
+ Integer topK,
+ Double topP,
+ Double repeatPenalty,
+ Integer seed,
+ Integer numPredict,
+ List<String> stop,
+ String format,
+ Duration timeout) {
+ this.client = OllamaClient.builder()
+ .baseUrl(baseUrl)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .build();
this.modelName = ensureNotBlank(modelName, "modelName");
- this.temperature = getOrDefault(temperature, 0.7);
+ this.options = Options.builder()
+ .temperature(temperature)
+ .topK(topK)
+ .topP(topP)
+ .repeatPenalty(repeatPenalty)
+ .seed(seed)
+ .numPredict(numPredict)
+ .stop(stop)
+ .build();
+ this.format = format;
}
@Override
@@ -31,9 +57,8 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
CompletionRequest request = CompletionRequest.builder()
.model(modelName)
.prompt(prompt)
- .options(Options.builder()
- .temperature(temperature)
- .build())
+ .options(options)
+ .format(format)
.stream(true)
.build();
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
index eff2e4b38e0..65577719b9f 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Options.java
@@ -5,6 +5,8 @@
import lombok.Data;
import lombok.NoArgsConstructor;
+import java.util.List;
+
/**
* request options in completion/embedding API
*
@@ -16,8 +18,11 @@
@Builder
class Options {
- /**
- * The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
- */
private Double temperature;
+ private Integer topK;
+ private Double topP;
+ private Double repeatPenalty;
+ private Integer seed;
+ private Integer numPredict;
+ private List<String> stop;
}
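Since OllamaClient builds its Gson with LOWER_CASE_WITH_UNDERSCORES, these camelCase fields serialize to the snake_case keys the Ollama options object expects (topK becomes top_k, numPredict becomes num_predict, and so on). A small standalone sketch of that mapping, using stand-in fields because Options itself is package-private:

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;

public class OptionsNamingSketch {

    // Stand-in mirroring a few Options fields above
    static class Options {
        Double temperature = 0.0;
        Integer topK = 40;
        Integer numPredict = 128;
    }

    public static void main(String[] args) {
        Gson gson = new GsonBuilder()
                .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
                .create();

        // Prints: {"temperature":0.0,"top_k":40,"num_predict":128}
        System.out.println(gson.toJson(new Options()));
    }
}
```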
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
index 26b4e8d77a2..eca1f02bdf0 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java
@@ -1,22 +1,8 @@
package dev.langchain4j.model.ollama;
-import dev.langchain4j.data.message.ChatMessageType;
-
enum Role {
+
SYSTEM,
USER,
ASSISTANT;
-
- public static Role fromChatMessageType(ChatMessageType chatMessageType) {
- switch (chatMessageType) {
- case SYSTEM:
- return SYSTEM;
- case USER:
- return USER;
- case AI:
- return Role.ASSISTANT;
- default:
- throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
- }
- }
}
\ No newline at end of file
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
index 5bcbb5437c0..3845c83da56 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaInfrastructure.java
@@ -18,9 +18,9 @@ public class AbstractOllamaInfrastructure {
private static final String OLLAMA_IMAGE = "ollama/ollama:latest";
- private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-orca-mini", OLLAMA_IMAGE);
+ static final String MODEL = "phi";
- static final String ORCA_MINI_MODEL = "orca-mini";
+ private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, MODEL);
static OllamaContainer ollama;
@@ -58,18 +58,18 @@ static class OllamaContainer extends GenericContainer<OllamaContainer> {
super(image.get());
this.dockerImageName = image.get();
withExposedPorts(11434);
- withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(ORCA_MINI_MODEL));
+ withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(MODEL));
}
@Override
protected void containerIsStarted(InspectContainerResponse containerInfo) {
if (!this.dockerImageName.equals(DockerImageName.parse(LOCAL_OLLAMA_IMAGE))) {
try {
- log.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ...");
- execInContainer("ollama", "pull", ORCA_MINI_MODEL);
- log.info("orca-mini pulling competed!");
+ log.info("Start pulling the '{}' model ... would take several minutes ...", MODEL);
+ execInContainer("ollama", "pull", MODEL);
+ log.info("Model pulling competed!");
} catch (IOException | InterruptedException e) {
- throw new RuntimeException("Error pulling orca-mini model", e);
+ throw new RuntimeException("Error pulling model", e);
}
}
}
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java
index ba2be46bfb1..d382aa0fe3c 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java
@@ -2,36 +2,126 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
-import java.util.ArrayList;
import java.util.List;
-import static dev.langchain4j.data.message.SystemMessage.systemMessage;
-import static dev.langchain4j.data.message.UserMessage.userMessage;
+import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaChatModelIT extends AbstractOllamaInfrastructure {
- OllamaChatModel model = OllamaChatModel.builder()
+ ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(getBaseUrl())
- .modelName(ORCA_MINI_MODEL)
+ .modelName(MODEL)
+ .temperature(0.0)
.build();
@Test
- void should_send_messages_with_roles_and_receive_response() {
+ void should_generate_answer() {
- List<ChatMessage> chatMessages = new ArrayList<>();
- chatMessages.add(systemMessage("You are a good friend of mine, who likes to answer politely"));
- chatMessages.add(userMessage("Hello!, How are you?"));
- chatMessages.add(AiMessage.aiMessage("I'm fine, thanks!"));
- chatMessages.add(userMessage("Not too bad, just enjoying a cup of coffee. What about you?"));
+ // given
+ String userMessage = "What is the capital of Germany?";
- Response<AiMessage> response = model.generate(chatMessages);
+ // when
+ String answer = model.generate(userMessage);
+ System.out.println(answer);
+
+ // then
+ assertThat(answer).contains("Berlin");
+ }
+
+ @Test
+ void should_respect_numPredict() {
+
+ // given
+ int numPredict = 1; // max output tokens
+
+ OllamaChatModel model = OllamaChatModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .numPredict(numPredict)
+ .temperature(0.0)
+ .build();
+
+ UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
+
+ // when
+ Response<AiMessage> response = model.generate(userMessage);
System.out.println(response);
- assertThat(response).isNotNull();
- assertThat(response.content().text()).isNotEmpty();
+ // then
+ assertThat(response.content().text()).doesNotContain("Berlin");
+ assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+ }
+
+ @Test
+ void should_respect_system_message() {
+
+ // given
+ SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
+ UserMessage userMessage = UserMessage.from("I love you");
+
+ // when
+ Response<AiMessage> response = model.generate(systemMessage, userMessage);
+ System.out.println(response);
+
+ // then
+ assertThat(response.content().text()).containsIgnoringCase("liebe");
+
+ TokenUsage tokenUsage = response.tokenUsage();
+ assertThat(tokenUsage.inputTokenCount()).isEqualTo(18);
+ assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+ assertThat(tokenUsage.totalTokenCount())
+ .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+ assertThat(response.finishReason()).isNull();
+ }
+
+ @Test
+ void should_respond_to_few_shot() {
+
+ // given
+ List<ChatMessage> messages = asList(
+ UserMessage.from("1 + 1 ="),
+ AiMessage.from(">>> 2"),
+
+ UserMessage.from("2 + 2 ="),
+ AiMessage.from(">>> 4"),
+
+ UserMessage.from("4 + 4 =")
+ );
+
+ // when
+ Response<AiMessage> response = model.generate(messages);
+ System.out.println(response);
+
+ // then
+ assertThat(response.content().text()).isEqualTo(">>> 8");
+ }
+
+ @Test
+ void should_generate_valid_json() {
+
+ // given
+ ChatLanguageModel model = OllamaChatModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .format("json")
+ .temperature(0.0)
+ .build();
+
+ String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+ // when
+ String json = model.generate(userMessage);
+
+ // then
+ assertThat(json).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
}
}
\ No newline at end of file
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java
index 5f16ec2f3b3..3f53076fa95 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaEmbeddingModelIT.java
@@ -11,15 +11,23 @@ class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure {
EmbeddingModel model = OllamaEmbeddingModel.builder()
.baseUrl(getBaseUrl())
- .modelName(ORCA_MINI_MODEL)
+ .modelName(MODEL)
.build();
@Test
void should_embed() {
- Response<Embedding> response = model.embed("hello world");
+ // given
+ String text = "hello world";
+
+ // when
+ Response<Embedding> response = model.embed(text);
System.out.println(response);
- assertThat(response.content().vector()).isNotEmpty();
+ // then
+ assertThat(response.content().vector()).isNotEmpty();
+
+ assertThat(response.tokenUsage()).isNull();
+ assertThat(response.finishReason()).isNull();
}
}
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
index 73df2a3ceb9..040895d0fa1 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaLanguageModelIT.java
@@ -2,6 +2,7 @@
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
@@ -10,18 +11,73 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(getBaseUrl())
- .modelName(ORCA_MINI_MODEL)
+ .modelName(MODEL)
+ .temperature(0.0)
.build();
@Test
void should_generate_answer() {
- String prompt = "Hello, how are you?";
+ // given
+ String userMessage = "What is the capital of Germany?";
+ // when
+ Response<String> response = model.generate(userMessage);
+ System.out.println(response);
+
+ // then
+ assertThat(response.content()).contains("Berlin");
+
+ TokenUsage tokenUsage = response.tokenUsage();
+ assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+ assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+ assertThat(tokenUsage.totalTokenCount())
+ .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+ assertThat(response.finishReason()).isNull();
+ }
+
+ @Test
+ void should_respect_numPredict() {
+
+ // given
+ int numPredict = 1; // max output tokens
+
+ LanguageModel model = OllamaLanguageModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .numPredict(numPredict)
+ .temperature(0.0)
+ .build();
+
+ String prompt = "What is the capital of Germany?";
+
+ // when
Response<String> response = model.generate(prompt);
System.out.println(response);
- assertThat(response.content()).isNotBlank();
- assertThat(response.tokenUsage()).isNotNull();
+ // then
+ assertThat(response.content()).doesNotContain("Berlin");
+ assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+ }
+
+ @Test
+ void should_generate_valid_json() {
+
+ // given
+ LanguageModel model = OllamaLanguageModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .format("json")
+ .temperature(0.0)
+ .build();
+
+ String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+ // when
+ Response<String> response = model.generate(userMessage);
+
+ // then
+ assertThat(response.content()).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
}
}
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java
new file mode 100644
index 00000000000..a375d319d08
--- /dev/null
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java
@@ -0,0 +1,276 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
+
+ StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .temperature(0.0)
+ .build();
+
+ @Test
+ void should_stream_answer() throws Exception {
+
+ // given
+ String userMessage = "What is the capital of Germany?";
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+ model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<AiMessage> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).contains("Berlin");
+ assertThat(response.content().text()).isEqualTo(answer);
+
+ TokenUsage tokenUsage = response.tokenUsage();
+ assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+ assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+ assertThat(tokenUsage.totalTokenCount())
+ .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+ assertThat(response.finishReason()).isNull();
+ }
+
+ @Test
+ void should_respect_numPredict() throws Exception {
+
+ // given
+ int numPredict = 1; // max output tokens
+
+ StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .numPredict(numPredict)
+ .temperature(0.0)
+ .build();
+
+ UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+ model.generate(singletonList(userMessage), new StreamingResponseHandler<AiMessage>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<AiMessage> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).doesNotContain("Berlin");
+ assertThat(response.content().text()).isEqualTo(answer);
+
+ assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+ }
+
+
+ @Test
+ void should_respect_system_message() throws Exception {
+
+ // given
+ SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
+ UserMessage userMessage = UserMessage.from("I love you");
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+ model.generate(asList(systemMessage, userMessage), new StreamingResponseHandler<AiMessage>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<AiMessage> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).containsIgnoringCase("liebe");
+ assertThat(response.content().text()).isEqualTo(answer);
+
+ assertThat(response.finishReason()).isNull();
+ }
+
+ @Test
+ void should_respond_to_few_shot() throws Exception {
+
+ // given
+ List messages = asList(
+ UserMessage.from("1 + 1 ="),
+ AiMessage.from(">>> 2"),
+
+ UserMessage.from("2 + 2 ="),
+ AiMessage.from(">>> 4"),
+
+ UserMessage.from("4 + 4 =")
+ );
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+ model.generate(messages, new StreamingResponseHandler<AiMessage>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<AiMessage> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).isEqualTo(">>> 8");
+ assertThat(response.content().text()).isEqualTo(answer);
+ }
+
+ @Test
+ void should_generate_valid_json() throws Exception {
+
+ // given
+ StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .format("json")
+ .temperature(0.0)
+ .build();
+
+ String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+ model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<AiMessage> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<AiMessage> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
+ assertThat(response.content().text()).isEqualTo(answer);
+ }
+}
\ No newline at end of file
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
index afef80ba61c..b1cbb997066 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModelIT.java
@@ -3,29 +3,33 @@
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeoutException;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
- StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
- .baseUrl(getBaseUrl())
- .modelName(ORCA_MINI_MODEL)
- .build();
-
@Test
- void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {
+ void should_stream_answer() throws Exception {
+
+ // given
+ String prompt = "What is the capital of Germany?";
+
+ StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .temperature(0.0)
+ .build();
+ // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
- model.generate("What is the capital of Germany?", new StreamingResponseHandler() {
+ model.generate(prompt, new StreamingResponseHandler() {
private final StringBuilder answerBuilder = new StringBuilder();
@@ -52,7 +56,118 @@ public void onError(Throwable error) {
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
+ // then
assertThat(answer).contains("Berlin");
assertThat(response.content()).isEqualTo(answer);
+
+ TokenUsage tokenUsage = response.tokenUsage();
+ assertThat(tokenUsage.inputTokenCount()).isEqualTo(43);
+ assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+ assertThat(tokenUsage.totalTokenCount())
+ .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+ assertThat(response.finishReason()).isNull();
+ }
+
+ @Test
+ void should_respect_numPredict() throws Exception {
+
+ // given
+ int numPredict = 1; // max output tokens
+
+ StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .numPredict(numPredict)
+ .temperature(0.0)
+ .build();
+
+ String prompt = "What is the capital of Germany?";
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
+
+ model.generate(prompt, new StreamingResponseHandler<String>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<String> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<String> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).doesNotContain("Berlin");
+ assertThat(response.content()).isEqualTo(answer);
+
+ assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(numPredict + 2); // bug in Ollama
+ }
+
+ @Test
+ void should_stream_valid_json() throws Exception {
+
+ // given
+ StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
+ .baseUrl(getBaseUrl())
+ .modelName(MODEL)
+ .format("json")
+ .temperature(0.0)
+ .build();
+
+ String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
+
+ // when
+ CompletableFuture<String> futureAnswer = new CompletableFuture<>();
+ CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
+
+ model.generate(prompt, new StreamingResponseHandler<String>() {
+
+ private final StringBuilder answerBuilder = new StringBuilder();
+
+ @Override
+ public void onNext(String token) {
+ System.out.println("onNext: '" + token + "'");
+ answerBuilder.append(token);
+ }
+
+ @Override
+ public void onComplete(Response<String> response) {
+ System.out.println("onComplete: '" + response + "'");
+ futureAnswer.complete(answerBuilder.toString());
+ futureResponse.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ futureAnswer.completeExceptionally(error);
+ futureResponse.completeExceptionally(error);
+ }
+ });
+
+ String answer = futureAnswer.get(30, SECONDS);
+ Response<String> response = futureResponse.get(30, SECONDS);
+
+ // then
+ assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
+ assertThat(response.content()).isEqualTo(answer);
}
}