forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Google AI Gemini: replace OkHttp and Retrofit with Java 11 HttpClient (…
…langchain4j#1950) ## Issue Based on langchain4j#1903 ## Change Replaced OkHttp and Retrofit inside the GeminiService with an implementation using the HttpClient (Java 11). ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change (already existing) - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
- Loading branch information
1 parent
17d8384
commit 2ae3983
Showing
5 changed files
with
183 additions
and
302 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 81 additions & 70 deletions
151
langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,96 @@ | ||
package dev.langchain4j.model.googleai; | ||
|
||
//import io.reactivex.rxjava3.core.Observable; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.logging.HttpLoggingInterceptor; | ||
import com.google.gson.Gson; | ||
import com.google.gson.GsonBuilder; | ||
import org.slf4j.Logger; | ||
import retrofit2.Call; | ||
import retrofit2.Retrofit; | ||
import retrofit2.converter.gson.GsonConverterFactory; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.POST; | ||
import retrofit2.http.Path; | ||
import retrofit2.http.Header; | ||
import retrofit2.http.Headers; | ||
|
||
import java.io.IOException; | ||
import java.net.URI; | ||
import java.net.http.HttpClient; | ||
import java.net.http.HttpRequest; | ||
import java.net.http.HttpResponse; | ||
import java.time.Duration; | ||
//import retrofit2.http.Streaming; | ||
|
||
interface GeminiService { | ||
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/"; | ||
String API_KEY_HEADER_NAME = "x-goog-api-key"; | ||
String USER_AGENT = "User-Agent: LangChain4j"; | ||
class GeminiService { | ||
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta"; | ||
private static final String API_KEY_HEADER_NAME = "x-goog-api-key"; | ||
|
||
static GeminiService getGeminiService(Logger logger, Duration timeout) { | ||
Retrofit.Builder retrofitBuilder = new Retrofit.Builder() | ||
.baseUrl(GEMINI_AI_ENDPOINT) | ||
.addConverterFactory(GsonConverterFactory.create()); | ||
private final HttpClient httpClient; | ||
private final Gson gson; | ||
private final Logger logger; | ||
|
||
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder() | ||
.callTimeout(timeout); | ||
GeminiService(Logger logger, Duration timeout) { | ||
this.logger = logger; | ||
this.gson = new GsonBuilder().setPrettyPrinting().create(); | ||
|
||
if (logger != null) { | ||
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug); | ||
logging.redactHeader(API_KEY_HEADER_NAME); | ||
logging.setLevel(HttpLoggingInterceptor.Level.BODY); | ||
this.httpClient = HttpClient.newBuilder() | ||
.connectTimeout(timeout) | ||
.build(); | ||
} | ||
|
||
GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) { | ||
String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName); | ||
return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class); | ||
} | ||
|
||
GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) { | ||
String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName); | ||
return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class); | ||
} | ||
|
||
GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) { | ||
String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName); | ||
return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class); | ||
} | ||
|
||
GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) { | ||
String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName); | ||
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class); | ||
} | ||
|
||
private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) { | ||
String jsonBody = gson.toJson(requestBody); | ||
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody); | ||
|
||
clientBuilder.addInterceptor(logging); | ||
logRequest(jsonBody); | ||
|
||
try { | ||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); | ||
|
||
if (response.statusCode() >= 300) { | ||
throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body())); | ||
} | ||
|
||
logResponse(response.body()); | ||
|
||
return gson.fromJson(response.body(), responseType); | ||
} catch (IOException e) { | ||
throw new RuntimeException("An error occurred while sending the request", e); | ||
} catch (InterruptedException e) { | ||
Thread.currentThread().interrupt(); | ||
throw new RuntimeException("Sending the request was interrupted", e); | ||
} | ||
} | ||
|
||
retrofitBuilder.client(clientBuilder.build()); | ||
Retrofit retrofit = retrofitBuilder.build(); | ||
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) { | ||
return HttpRequest.newBuilder() | ||
.uri(URI.create(url)) | ||
.header("Content-Type", "application/json") | ||
.header("User-Agent", "LangChain4j") | ||
.header(API_KEY_HEADER_NAME, apiKey) | ||
.POST(HttpRequest.BodyPublishers.ofString(jsonBody)) | ||
.build(); | ||
} | ||
|
||
return retrofit.create(GeminiService.class); | ||
private void logRequest(String jsonBody) { | ||
if (logger != null) { | ||
logger.debug("Sending request to Gemini:\n{}", jsonBody); | ||
} | ||
} | ||
|
||
@POST("models/{model}:generateContent") | ||
@Headers(USER_AGENT) | ||
Call<GeminiGenerateContentResponse> generateContent( | ||
@Path("model") String modelName, | ||
@Header(API_KEY_HEADER_NAME) String apiKey, | ||
@Body GeminiGenerateContentRequest request); | ||
|
||
@POST("models/{model}:countTokens") | ||
@Headers(USER_AGENT) | ||
Call<GeminiCountTokensResponse> countTokens( | ||
@Path("model") String modelName, | ||
@Header(API_KEY_HEADER_NAME) String apiKey, | ||
@Body GeminiCountTokensRequest countTokensRequest); | ||
|
||
@POST("models/{model}:embedContent") | ||
@Headers(USER_AGENT) | ||
Call<GoogleAiEmbeddingResponse> embed( | ||
@Path("model") String modelName, | ||
@Header(API_KEY_HEADER_NAME) String apiKey, | ||
@Body GoogleAiEmbeddingRequest embeddingRequest); | ||
|
||
@POST("models/{model}:batchEmbedContents") | ||
@Headers(USER_AGENT) | ||
Call<GoogleAiBatchEmbeddingResponse> batchEmbed( | ||
@Path("model") String modelName, | ||
@Header(API_KEY_HEADER_NAME) String apiKey, | ||
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest); | ||
|
||
/* | ||
@Streaming | ||
@POST("models/{model}:streamGenerateContent") | ||
@Headers("User-Agent: LangChain4j") | ||
Observable<GeminiGenerateContentResponse> streamGenerateContent( | ||
@Path("model") String modelName, | ||
@Header(API_KEY_HEADER_NAME) String apiKey, | ||
@Body GeminiGenerateContentRequest request); | ||
*/ | ||
|
||
|
||
} | ||
private void logResponse(String responseBody) { | ||
if (logger != null) { | ||
logger.debug("Response from Gemini:\n{}", responseBody); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.