Skip to content

Commit

Permalink
Migrate Ollama from gson to jackson (langchain4j#1697)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1688 

## Change
Migrate Ollama from gson to jackson

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [ ] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
Martin7-1 authored Sep 10, 2024
1 parent 0eb6de2 commit ff243b0
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
Expand All @@ -27,15 +26,13 @@
import java.util.Map;
import java.util.Optional;

import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.getObjectMapper;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;
import static java.lang.Boolean.TRUE;

@Slf4j
class OllamaClient {

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper()
.enable(INDENT_OUTPUT);

private final OllamaApi ollamaApi;
private final boolean logStreamingResponses;

Expand Down Expand Up @@ -66,7 +63,7 @@ public OllamaClient(String baseUrl,
Retrofit retrofit = new Retrofit.Builder()
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.addConverterFactory(JacksonConverterFactory.create(getObjectMapper()))
.build();

ollamaApi = retrofit.create(OllamaApi.class);
Expand Down Expand Up @@ -118,7 +115,7 @@ public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody>
log.debug("Streaming partial response: {}", partialResponse);
}

CompletionResponse completionResponse = OBJECT_MAPPER.readValue(partialResponse, CompletionResponse.class);
CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
contentBuilder.append(completionResponse.getResponse());
handler.onNext(completionResponse.getResponse());

Expand Down Expand Up @@ -161,7 +158,7 @@ public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody>
log.debug("Streaming partial response: {}", partialResponse);
}

ChatResponse chatResponse = OBJECT_MAPPER.readValue(partialResponse, ChatResponse.class);
ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
String content = chatResponse.getMessage().getContent();
contentBuilder.append(content);
handler.onNext(content);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;

class OllamaJsonUtils {

private OllamaJsonUtils() throws InstantiationException {
throw new InstantiationException("Can't instantiate this utility class.");
}

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper()
.enable(INDENT_OUTPUT);

static String toJson(Object object) {
try {
return OBJECT_MAPPER.writeValueAsString(object);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

static <T> T toObject(String jsonStr, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(jsonStr, clazz);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

static <T> T toObject(String jsonStr, TypeReference<T> typeReference) {
try {
return OBJECT_MAPPER.readValue(jsonStr, typeReference);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

static ObjectMapper getObjectMapper() {
return OBJECT_MAPPER;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.internal.Json;

import java.util.List;
import java.util.Map;
Expand All @@ -12,6 +11,7 @@

import static dev.langchain4j.data.message.ContentType.IMAGE;
import static dev.langchain4j.data.message.ContentType.TEXT;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toJson;

class OllamaMessagesUtils {

Expand Down Expand Up @@ -48,7 +48,7 @@ static List<ToolExecutionRequest> toToolExecutionRequest(List<ToolCall> toolCall
return toolCalls.stream().map(toolCall ->
ToolExecutionRequest.builder()
.name(toolCall.getFunction().getName())
.arguments(Json.toJson(toolCall.getFunction().getArguments()))
.arguments(toJson(toolCall.getFunction().getArguments()))
.build())
.collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
package dev.langchain4j.model.ollama;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static org.assertj.core.api.Assertions.assertThat;

public class OllamaApiIT {

private static MockWebServer mockWebServer;

private static final Gson GSON = new GsonBuilder()
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.create();
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

@BeforeAll
public static void init() throws IOException {
Expand All @@ -40,9 +36,16 @@ public MockResponse dispatch(RecordedRequest recordedRequest) {
.message(message)
.build();

String jsonBody;
try {
jsonBody = OBJECT_MAPPER.writeValueAsString(chatResponse);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return new MockResponse()
.setResponseCode(200)
.setBody(GSON.toJson(chatResponse));
.setBody(jsonBody);
}
};
mockWebServer.setDispatcher(dispatcher);
Expand Down

0 comments on commit ff243b0

Please sign in to comment.