Skip to content

Commit

Permalink
Ollama: add OllamaStreamingChatModel, "format" (json) and other param…
Browse files Browse the repository at this point in the history
…eters (langchain4j#373)

- added `OllamaStreamingChatModel`
- added `format` parameter to all models, now can get valid JSON with
`format="json"`
- added `top_k`, `top_p`, `repeat_penalty`, `seed`, `num_predict`,
`stop` parameters to all models
  • Loading branch information
dliubarskyi authored Dec 21, 2023
1 parent 798a474 commit 7181633
Show file tree
Hide file tree
Showing 20 changed files with 903 additions and 144 deletions.
12 changes: 12 additions & 0 deletions langchain4j-ollama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
// Request payload for Ollama's /api/chat endpoint.
// Serialized by Gson with LOWER_CASE_WITH_UNDERSCORES field naming (see OllamaClient).
@Builder
class ChatRequest {

/**
 * model name
 */
private String model;
// conversation history (system/user/assistant turns) sent to the model
private List<Message> messages;
// sampling parameters (temperature, top_k, top_p, repeat_penalty, seed, num_predict, stop)
private Options options;
// set to "json" to request valid-JSON output; null for plain text
private String format;
// false => single complete response; true => streamed partial responses
private Boolean stream;
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@AllArgsConstructor
@Builder
class ChatResponse {

private String model;
private String createdAt;
private Message message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@
// Request payload for Ollama's /api/generate (completion) endpoint.
// Serialized by Gson with LOWER_CASE_WITH_UNDERSCORES field naming (see OllamaClient).
// NOTE: the original hunk declared `system` twice (a diff-merge artifact, invalid Java);
// each field is declared exactly once here.
@Builder
class CompletionRequest {

    // model name
    private String model;

    // optional system prompt that steers model behavior
    private String system;

    // the prompt to generate a response for
    private String prompt;

    // sampling parameters (temperature, top_k, top_p, repeat_penalty, seed, num_predict, stop)
    private Options options;

    // set to "json" to request valid-JSON output; null for plain text
    private String format;

    // false => single complete response; true => streamed partial responses
    private Boolean stream;
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// A single chat turn exchanged with the Ollama /api/chat endpoint.
@AllArgsConstructor
@Builder
class Message {

// who produced this turn: SYSTEM, USER, or ASSISTANT (see Role)
private Role role;
// the text content of the turn
private String content;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import retrofit2.http.POST;
import retrofit2.http.Streaming;

/**
* Ollama API definitions using retrofit2
*/
public interface OllamaApi {
interface OllamaApi {

@POST("/api/generate")
@Headers({"Content-Type: application/json"})
Expand All @@ -28,4 +25,9 @@ public interface OllamaApi {
@POST("/api/chat")
@Headers({"Content-Type: application/json"})
Call<ChatResponse> chat(@Body ChatRequest chatRequest);

@POST("/api/chat")
@Headers({"Content-Type: application/json"})
@Streaming
Call<ResponseBody> streamingChat(@Body ChatRequest chatRequest);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,105 @@

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;

/**
* Ollama chat model implementation.
* <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
* <br>
* <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama API parameters</a>.
*/
public class OllamaChatModel implements ChatLanguageModel {

private final OllamaClient client;
private final Double temperature;
private final String modelName;
private final Options options;
private final String format;
private final Integer maxRetries;

@Builder
public OllamaChatModel(String baseUrl, Duration timeout,
String modelName, Double temperature, Integer maxRetries) {
this.client = OllamaClient.builder().baseUrl(baseUrl).timeout(timeout).build();
public OllamaChatModel(String baseUrl,
String modelName,
Double temperature,
Integer topK,
Double topP,
Double repeatPenalty,
Integer seed,
Integer numPredict,
List<String> stop,
String format,
Duration timeout,
Integer maxRetries) {
this.client = OllamaClient.builder()
.baseUrl(baseUrl)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.build();
this.modelName = ensureNotBlank(modelName, "modelName");
this.temperature = getOrDefault(temperature, 0.7);
this.options = Options.builder()
.temperature(temperature)
.topK(topK)
.topP(topP)
.repeatPenalty(repeatPenalty)
.seed(seed)
.numPredict(numPredict)
.stop(stop)
.build();
this.format = format;
this.maxRetries = getOrDefault(maxRetries, 3);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
if (messages == null || messages.isEmpty()) {
throw new IllegalArgumentException("messages must not be null or empty");
}

ArrayList<Message> messageList = new ArrayList<>();

messages.forEach(message -> {
Role role = Role.fromChatMessageType(message.type());
messageList.add(Message.builder()
.role(role)
.content(message.text())
.build());
});
ensureNotEmpty(messages, "messages");

ChatRequest request = ChatRequest.builder()
.model(modelName)
.messages(messageList)
.options(Options.builder()
.temperature(temperature)
.build())
.messages(toOllamaMessages(messages))
.options(options)
.format(format)
.stream(false)
.build();

ChatResponse response = withRetry(() -> client.chat(request), maxRetries);

return Response.from(AiMessage.from(response.getMessage().getContent()));
return Response.from(
AiMessage.from(response.getMessage().getContent()),
new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
);
}

static List<Message> toOllamaMessages(List<ChatMessage> messages) {
return messages.stream()
.map(message -> Message.builder()
.role(toOllamaRole(message.type()))
.content(message.text())
.build())
.collect(toList());
}

private static Role toOllamaRole(ChatMessageType chatMessageType) {
switch (chatMessageType) {
case SYSTEM:
return Role.SYSTEM;
case USER:
return Role.USER;
case AI:
return Role.ASSISTANT;
default:
throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,35 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.util.Optional;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.time.Duration.ofSeconds;
import static java.lang.Boolean.TRUE;

class OllamaClient {

private final OllamaApi ollamaApi;
private static final Gson GSON = new GsonBuilder().setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
private static final Gson GSON = new GsonBuilder()
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.create();

private final OllamaApi ollamaApi;

@Builder
public OllamaClient(String baseUrl, Duration timeout) {
timeout = getOrDefault(timeout, ofSeconds(60));

OkHttpClient okHttpClient = new OkHttpClient.Builder()
.callTimeout(timeout)
Expand All @@ -50,7 +50,7 @@ public OllamaClient(String baseUrl, Duration timeout) {

public CompletionResponse completion(CompletionRequest request) {
try {
Response<CompletionResponse> retrofitResponse
retrofit2.Response<CompletionResponse> retrofitResponse
= ollamaApi.completion(request).execute();

if (retrofitResponse.isSuccessful()) {
Expand All @@ -65,7 +65,7 @@ public CompletionResponse completion(CompletionRequest request) {

public ChatResponse chat(ChatRequest request) {
try {
Response<ChatResponse> retrofitResponse
retrofit2.Response<ChatResponse> retrofitResponse
= ollamaApi.chat(request).execute();

if (retrofitResponse.isSuccessful()) {
Expand All @@ -80,32 +80,72 @@ public ChatResponse chat(ChatRequest request) {

public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {

@Override
public void onResponse(Call<ResponseBody> call, Response<ResponseBody> response) {
try (InputStream inputStream = response.body().byteStream()) {
StringBuilder content = new StringBuilder();
int inputTokenCount = 0;
int outputTokenCount = 0;
public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
try (InputStream inputStream = retrofitResponse.body().byteStream()) {
StringBuilder contentBuilder = new StringBuilder();
while (true) {
byte[] bytes = new byte[1024];
int len = inputStream.read(bytes);
String partialResponse = new String(bytes, 0, len);
CompletionResponse completionResponse = GSON.fromJson(partialResponse, CompletionResponse.class);

// finish streaming response
if (Boolean.TRUE.equals(completionResponse.getDone())) {
handler.onComplete(dev.langchain4j.model.output.Response.from(
content.toString(),
new TokenUsage(inputTokenCount, outputTokenCount)
));
break;
contentBuilder.append(completionResponse.getResponse());
handler.onNext(completionResponse.getResponse());

if (TRUE.equals(completionResponse.getDone())) {
Response<String> response = Response.from(
contentBuilder.toString(),
new TokenUsage(
completionResponse.getPromptEvalCount(),
completionResponse.getEvalCount()
)
);
handler.onComplete(response);
return;
}
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}

// handle cur token and tokenUsage
content.append(completionResponse.getResponse());
inputTokenCount += Optional.ofNullable(completionResponse.getPromptEvalCount()).orElse(0);
outputTokenCount += Optional.ofNullable(completionResponse.getEvalCount()).orElse(0);
handler.onNext(completionResponse.getResponse());
@Override
public void onFailure(Call<ResponseBody> call, Throwable throwable) {
handler.onError(throwable);
}
});
}

public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {

@Override
public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
try (InputStream inputStream = retrofitResponse.body().byteStream()) {
StringBuilder contentBuilder = new StringBuilder();
while (true) {
byte[] bytes = new byte[1024];
int len = inputStream.read(bytes);
String partialResponse = new String(bytes, 0, len);
ChatResponse chatResponse = GSON.fromJson(partialResponse, ChatResponse.class);

String content = chatResponse.getMessage().getContent();
contentBuilder.append(content);
handler.onNext(content);

if (TRUE.equals(chatResponse.getDone())) {
Response<AiMessage> response = Response.from(
AiMessage.from(contentBuilder.toString()),
new TokenUsage(
chatResponse.getPromptEvalCount(),
chatResponse.getEvalCount()
)
);
handler.onComplete(response);
return;
}
}
} catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -132,7 +172,7 @@ public EmbeddingResponse embed(EmbeddingRequest request) {
}
}

private RuntimeException toException(Response<?> response) throws IOException {
private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
int code = response.code();
String body = response.errorBody().string();

Expand Down
Loading

0 comments on commit 7181633

Please sign in to comment.