Skip to content

Commit

Permalink
Google AI Gemini: replace OkHttp and Retrofit with Java 11 HttpClient (
Browse files Browse the repository at this point in the history
…langchain4j#1950)

## Issue
Based on langchain4j#1903

## Change
Replaced OkHttp and Retrofit inside the GeminiService with an
implementation using the HttpClient (Java 11).


## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change (already
existing)
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
Bjarne-Kinkel authored Oct 23, 2024
1 parent 17d8384 commit 2ae3983
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 302 deletions.
33 changes: 0 additions & 33 deletions langchain4j-google-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,6 @@
<artifactId>langchain4j-core</artifactId>
</dependency>

<!-- Retrofit REST client -->
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-gson</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>logging-interceptor</artifactId>
<version>4.12.0</version>
<scope>compile</scope>
</dependency>

<!--
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
<version>3.1.9</version>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>adapter-rxjava3</artifactId>
<version>2.11.0</version>
</dependency>
-->

<!-- Lombok for @Data and @Builder -->
<dependency>
<groupId>org.projectlombok</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,85 +1,96 @@
package dev.langchain4j.model.googleai;

//import io.reactivex.rxjava3.core.Observable;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.slf4j.Logger;
import retrofit2.Call;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
import retrofit2.http.Body;
import retrofit2.http.POST;
import retrofit2.http.Path;
import retrofit2.http.Header;
import retrofit2.http.Headers;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
//import retrofit2.http.Streaming;

interface GeminiService {
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
String API_KEY_HEADER_NAME = "x-goog-api-key";
String USER_AGENT = "User-Agent: LangChain4j";
class GeminiService {
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
private static final String API_KEY_HEADER_NAME = "x-goog-api-key";

static GeminiService getGeminiService(Logger logger, Duration timeout) {
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
.baseUrl(GEMINI_AI_ENDPOINT)
.addConverterFactory(GsonConverterFactory.create());
private final HttpClient httpClient;
private final Gson gson;
private final Logger logger;

OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout);
GeminiService(Logger logger, Duration timeout) {
this.logger = logger;
this.gson = new GsonBuilder().setPrettyPrinting().create();

if (logger != null) {
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
logging.redactHeader(API_KEY_HEADER_NAME);
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
this.httpClient = HttpClient.newBuilder()
.connectTimeout(timeout)
.build();
}

GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) {
String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
}

GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) {
String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class);
}

GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) {
String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class);
}

GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) {
String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
}

private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);

clientBuilder.addInterceptor(logging);
logRequest(jsonBody);

try {
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());

if (response.statusCode() >= 300) {
throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body()));
}

logResponse(response.body());

return gson.fromJson(response.body(), responseType);
} catch (IOException e) {
throw new RuntimeException("An error occurred while sending the request", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Sending the request was interrupted", e);
}
}

retrofitBuilder.client(clientBuilder.build());
Retrofit retrofit = retrofitBuilder.build();
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
return HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Content-Type", "application/json")
.header("User-Agent", "LangChain4j")
.header(API_KEY_HEADER_NAME, apiKey)
.POST(HttpRequest.BodyPublishers.ofString(jsonBody))
.build();
}

return retrofit.create(GeminiService.class);
private void logRequest(String jsonBody) {
if (logger != null) {
logger.debug("Sending request to Gemini:\n{}", jsonBody);
}
}

@POST("models/{model}:generateContent")
@Headers(USER_AGENT)
Call<GeminiGenerateContentResponse> generateContent(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiGenerateContentRequest request);

@POST("models/{model}:countTokens")
@Headers(USER_AGENT)
Call<GeminiCountTokensResponse> countTokens(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiCountTokensRequest countTokensRequest);

@POST("models/{model}:embedContent")
@Headers(USER_AGENT)
Call<GoogleAiEmbeddingResponse> embed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiEmbeddingRequest embeddingRequest);

@POST("models/{model}:batchEmbedContents")
@Headers(USER_AGENT)
Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);

/*
@Streaming
@POST("models/{model}:streamGenerateContent")
@Headers("User-Agent: LangChain4j")
Observable<GeminiGenerateContentResponse> streamGenerateContent(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiGenerateContentRequest request);
*/


}
private void logResponse(String responseBody) {
if (logger != null) {
logger.debug("Response from Gemini:\n{}", responseBody);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -29,8 +26,6 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {

private final GeminiService geminiService;

private final Gson GSON = new Gson();

private final String modelName;
private final String apiKey;
private final Integer maxRetries;
Expand All @@ -40,14 +35,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {

@Builder
public GoogleAiEmbeddingModel(
String modelName,
String apiKey,
Integer maxRetries,
TaskType taskType,
String titleMetadataKey,
Integer outputDimensionality,
Duration timeout,
Boolean logRequestsAndResponses
String modelName,
String apiKey,
Integer maxRetries,
TaskType taskType,
String titleMetadataKey,
Integer outputDimensionality,
Duration timeout,
Boolean logRequestsAndResponses
) {

this.modelName = ensureNotBlank(modelName, "modelName");
Expand All @@ -64,33 +59,14 @@ public GoogleAiEmbeddingModel(

boolean logRequestsAndResponses1 = logRequestsAndResponses != null && logRequestsAndResponses;

this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses1 ? log : null, timeout1);
this.geminiService = new GeminiService(logRequestsAndResponses1 ? log : null, timeout1);
}

@Override
public Response<Embedding> embed(TextSegment textSegment) {
GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);

Call<GoogleAiEmbeddingResponse> geminiEmbeddingResponseCall =
withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);

GoogleAiEmbeddingResponse geminiResponse;
try {
retrofit2.Response<GoogleAiEmbeddingResponse> executed = geminiEmbeddingResponseCall.execute();
geminiResponse = executed.body();

if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();

throw new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
}
}
} catch (IOException e) {

throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embed).", e);
}
GoogleAiEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);

if (geminiResponse != null) {
return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
Expand All @@ -107,8 +83,8 @@ public Response<Embedding> embed(String text) {
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
.map(this::getGoogleAiEmbeddingRequest)
.collect(Collectors.toList());
.map(this::getGoogleAiEmbeddingRequest)
.collect(Collectors.toList());

List<Embedding> allEmbeddings = new ArrayList<>();
int numberOfEmbeddings = embeddingRequests.size();
Expand All @@ -123,30 +99,12 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));

Call<GoogleAiBatchEmbeddingResponse> geminiBatchEmbeddingResponseCall =
withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));

GoogleAiBatchEmbeddingResponse geminiResponse;
try {
retrofit2.Response<GoogleAiBatchEmbeddingResponse> executed = geminiBatchEmbeddingResponseCall.execute();
geminiResponse = executed.body();

if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();

throw new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
}
}
} catch (IOException e) {
throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embedAll).", e);
}
GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));

if (geminiResponse != null) {
allEmbeddings.addAll(geminiResponse.getEmbeddings().stream()
.map(values -> Embedding.from(values.getValues()))
.collect(Collectors.toList()));
.map(values -> Embedding.from(values.getValues()))
.collect(Collectors.toList()));
} else {
throw new RuntimeException("Gemini embedding response was null (embedAll)");
}
Expand All @@ -157,8 +115,8 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {

private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
GeminiPart geminiPart = GeminiPart.builder()
.text(textSegment.text())
.build();
.text(textSegment.text())
.build();

GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);

Expand All @@ -170,11 +128,11 @@ private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSeg
}

return new GoogleAiEmbeddingRequest(
"models/" + this.modelName,
content,
this.taskType,
title,
this.outputDimensionality
"models/" + this.modelName,
content,
this.taskType,
title,
this.outputDimensionality
);
}

Expand Down
Loading

0 comments on commit 2ae3983

Please sign in to comment.