Skip to content

Commit

Permalink
Jina AI Embedding model integration (langchain4j#997)
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed May 22, 2024
1 parent 0ad92d5 commit 050e93b
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 69 deletions.
8 changes: 7 additions & 1 deletion langchain4j-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-jina</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-local-ai</artifactId>
Expand Down Expand Up @@ -375,7 +381,7 @@
<artifactId>langchain4j-web-search-engine-google-custom</artifactId>
<version>${project.version}</version>
</dependency>

<!-- experimental -->
<dependency>
<groupId>dev.langchain4j</groupId>
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.jina.internal.api.EmbeddingRequest;
import dev.langchain4j.model.jina.internal.api.EmbeddingResponse;
import dev.langchain4j.model.jina.internal.api.JinaEmbedding;
import dev.langchain4j.model.jina.internal.client.JinaClient;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
Expand All @@ -16,13 +20,11 @@
import static java.util.stream.Collectors.toList;

/**
* An integration with Nomic Atlas's Text Embeddings API.
* See more details <a href="https://api.jina.ai/redoc#tag/embeddings">Jina API reference</a>
* An integration with Jina Embeddings API.
* See more details <a href="https://api.jina.ai/redoc#tag/embeddings">here</a>.
*/

public class JinaEmbeddingModel implements EmbeddingModel {


private static final String DEFAULT_BASE_URL = "https://api.jina.ai/";

private final JinaClient client;
Expand All @@ -36,7 +38,7 @@ public JinaEmbeddingModel(String baseUrl,
Duration timeout,
Integer maxRetries) {
this.client = JinaClient.builder()
.baseUrl(getOrDefault(baseUrl,DEFAULT_BASE_URL))
.baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL))
.apiKey(apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.build();
Expand All @@ -48,21 +50,21 @@ public static JinaEmbeddingModel withApiKey(String apiKey) {
return JinaEmbeddingModel.builder().apiKey(apiKey).build();
}


@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {

EmbeddingRequest request = EmbeddingRequest.builder()
.model(modelName)
.input(textSegments.stream().map(TextSegment::text).collect(toList()))
.build();

EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);

List<Embedding> embeddings = response.getData().stream()
.map(JinaEmbedding::toEmbedding).collect(toList());
List<Embedding> embeddings = response.data.stream()
.map(JinaEmbedding::toEmbedding)
.collect(toList());

TokenUsage tokenUsage = new TokenUsage(response.getUsage().getPromptTokens(),0 );
return Response.from(embeddings,tokenUsage);
TokenUsage tokenUsage = new TokenUsage(response.usage.promptTokens, 0);
return Response.from(embeddings, tokenUsage);
}

}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.model.jina.internal.api;

import lombok.Builder;

import java.util.List;

/**
 * Request body sent to the Jina embeddings endpoint (see {@code JinaApi#embed}).
 * Instances are created through the Lombok-generated builder.
 */
@Builder
public class EmbeddingRequest {

/** Name of the embedding model to use. */
public String model;
/** Texts to embed, one entry per input text. */
public List<String> input;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dev.langchain4j.model.jina.internal.api;

import java.util.List;

/**
 * Response body returned by {@code JinaApi#embed}.
 */
public class EmbeddingResponse {

/** Embeddings returned by the API. */
public List<JinaEmbedding> data;
/** Token usage reported for the request. */
public Usage usage;
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dev.langchain4j.model.jina;
package dev.langchain4j.model.jina.internal.api;

import retrofit2.Call;
import retrofit2.http.Body;
Expand All @@ -7,8 +7,8 @@
import retrofit2.http.POST;

/**
 * Retrofit declaration of the Jina HTTP API used by this module.
 */
public interface JinaApi {

/**
 * Issues a {@code POST} to {@code v1/embeddings} (relative to the configured base URL)
 * with a JSON content type.
 *
 * @param request             the request body
 * @param authorizationHeader the value sent as the {@code Authorization} header
 * @return a Retrofit call that yields the embedding response when executed
 */
@POST("v1/embeddings")
@Headers({"Content-Type: application/json"})
Call<EmbeddingResponse> embed(@Body EmbeddingRequest request, @Header("Authorization") String authorizationHeader);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dev.langchain4j.model.jina.internal.api;

import dev.langchain4j.data.embedding.Embedding;

/**
 * A single item of the {@code data} list in an embedding response.
 */
public class JinaEmbedding {

// NOTE(review): presumably the position of the corresponding input text
// in the request - confirm against the Jina API reference.
public long index;
/** The raw embedding vector. */
public float[] embedding;
// Value of the "object" field of the response item; not read by this class.
public String object;

/**
 * Converts this item into a langchain4j {@link Embedding}.
 *
 * @return an {@link Embedding} created from the {@code embedding} vector
 */
public Embedding toEmbedding() {
return Embedding.from(embedding);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.model.jina.internal.api;

import com.google.gson.annotations.SerializedName;

/**
 * Token usage section of an embedding response.
 */
public class Usage {

/** Mapped from the JSON property {@code total_tokens}. */
@SerializedName("total_tokens")
public Integer totalTokens;

/** Mapped from the JSON property {@code prompt_tokens}. */
@SerializedName("prompt_tokens")
public Integer promptTokens;
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package dev.langchain4j.model.jina;
package dev.langchain4j.model.jina.internal.client;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.model.jina.internal.api.EmbeddingRequest;
import dev.langchain4j.model.jina.internal.api.EmbeddingResponse;
import dev.langchain4j.model.jina.internal.api.JinaApi;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
Expand All @@ -14,6 +17,7 @@
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

public class JinaClient {

private static final Gson GSON = new GsonBuilder()
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.setPrettyPrinting()
Expand All @@ -23,7 +27,7 @@ public class JinaClient {
private final String authorizationHeader;

@Builder
JinaClient(String baseUrl, String apiKey, Duration timeout){
JinaClient(String baseUrl, String apiKey, Duration timeout) {
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout)
.connectTimeout(timeout)
Expand All @@ -35,7 +39,6 @@ public class JinaClient {
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();


this.jinaApi = retrofit.create(JinaApi.class);
this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey");
}
Expand All @@ -44,7 +47,6 @@ public EmbeddingResponse embed(EmbeddingRequest request) {
try {
retrofit2.Response<EmbeddingResponse> retrofitResponse
= jinaApi.embed(request, authorizationHeader).execute();

if (retrofitResponse.isSuccessful()) {
return retrofitResponse.body();
} else {
Expand All @@ -55,13 +57,10 @@ public EmbeddingResponse embed(EmbeddingRequest request) {
}
}



/**
 * Builds the exception reported for an HTTP response, describing its status code
 * and the content of its error body.
 *
 * @param response the Retrofit response to describe
 * @return a {@link RuntimeException} whose message contains the status code and body
 * @throws IOException if the error body cannot be read
 */
private static RuntimeException toException(retrofit2.Response<?> response) throws IOException {
    // Arguments are evaluated left to right: status code first, then the error body.
    final String message = String.format(
            "status code: %s; body: %s",
            response.code(),
            response.errorBody().string());
    return new RuntimeException(message);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;



public class JinaEmbeddingModelIT {

@Test
public void should_embed_single_text() {

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
<module>langchain4j-cohere</module>
<module>langchain4j-dashscope</module>
<module>langchain4j-hugging-face</module>
<module>langchain4j-jina</module>
<module>langchain4j-local-ai</module>
<module>langchain4j-mistral-ai</module>
<module>langchain4j-nomic</module>
Expand All @@ -37,7 +38,6 @@
<module>langchain4j-vertex-ai</module>
<module>langchain4j-vertex-ai-gemini</module>
<module>langchain4j-zhipu-ai</module>
<module>langchain4j-jina</module>

<!-- embedding stores -->
<module>langchain4j-azure-ai-search</module>
Expand Down

0 comments on commit 050e93b

Please sign in to comment.