Skip to content

Commit

Permalink
OpenAI: support shortened embeddings (langchain4j#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j authored Jan 26, 2024
1 parent 2a59976 commit 9ab9b1e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator

private final OpenAiClient client;
private final String modelName;
private final Integer dimensions;
private final String user;
private final Integer maxRetries;
private final Tokenizer tokenizer;
Expand All @@ -40,6 +41,7 @@ public OpenAiEmbeddingModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Integer dimensions,
String user,
Duration timeout,
Integer maxRetries,
Expand Down Expand Up @@ -68,6 +70,7 @@ public OpenAiEmbeddingModel(String baseUrl,
.logResponses(logResponses)
.build();
this.modelName = getOrDefault(modelName, TEXT_EMBEDDING_ADA_002);
this.dimensions = dimensions;
this.user = user;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
Expand All @@ -88,6 +91,7 @@ private Response<List<Embedding>> embedTexts(List<String> texts) {
EmbeddingRequest request = EmbeddingRequest.builder()
.input(texts)
.model(modelName)
.dimensions(dimensions)
.user(user)
.build();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
package dev.langchain4j.model.openai;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;

import java.util.List;

import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;

class OpenAiEmbeddingModelIT {

EmbeddingModel model = OpenAiEmbeddingModel.withApiKey(System.getenv("OPENAI_API_KEY"));
EmbeddingModel model = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.logRequests(true)
.logResponses(true)
.build();

@Test
void should_embed_and_return_token_usage() {
void should_embed_single_text() {

// given
String text = "hello world";

Response<Embedding> response = model.embed("hello world");
// when
Response<Embedding> response = model.embed(text);
System.out.println(response);

// then
assertThat(response.content().vector()).hasSize(1536);

TokenUsage tokenUsage = response.tokenUsage();
Expand All @@ -27,4 +41,62 @@ void should_embed_and_return_token_usage() {

assertThat(response.finishReason()).isNull();
}

/**
 * Embeds two text segments in a single {@code embedAll} call on the shared {@code model}
 * and checks that two 1536-dimensional embeddings come back, that the reported token usage
 * is 2 input / 2 total tokens with no output token count, and that no finish reason is set.
 */
@Test
void should_embed_multiple_segments() {

    // given
    TextSegment hello = TextSegment.from("hello");
    TextSegment world = TextSegment.from("world");
    List<TextSegment> segments = asList(hello, world);

    // when
    Response<List<Embedding>> response = model.embedAll(segments);
    System.out.println(response);

    // then
    List<Embedding> embeddings = response.content();
    assertThat(embeddings).hasSize(2);
    assertThat(embeddings.get(0).dimension()).isEqualTo(1536);
    assertThat(embeddings.get(1).dimension()).isEqualTo(1536);

    TokenUsage usage = response.tokenUsage();
    assertThat(usage.inputTokenCount()).isEqualTo(2);
    assertThat(usage.outputTokenCount()).isNull();
    assertThat(usage.totalTokenCount()).isEqualTo(2);

    assertThat(response.finishReason()).isNull();
}

/**
 * Builds a dedicated model for {@code text-embedding-3-small} with an explicit
 * {@code dimensions} setting of 42 and checks that the embedding of a single text has
 * exactly that dimension, that the reported token usage is 2 input / 2 total tokens with
 * no output token count, and that no finish reason is set.
 */
@Test
void should_embed_text_with_embedding_shortening() {

    // given
    String text = "hello world";
    int expectedDimension = 42;

    // A separate instance is built here because the requested dimension is a
    // construction-time setting of the model.
    EmbeddingModel shorteningModel = OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
            .modelName("text-embedding-3-small")
            .dimensions(expectedDimension)
            .logRequests(true)
            .logResponses(true)
            .build();

    // when
    Response<Embedding> response = shorteningModel.embed(text);
    System.out.println(response);

    // then
    Embedding embedding = response.content();
    assertThat(embedding.dimension()).isEqualTo(expectedDimension);

    TokenUsage usage = response.tokenUsage();
    assertThat(usage.inputTokenCount()).isEqualTo(2);
    assertThat(usage.outputTokenCount()).isNull();
    assertThat(usage.totalTokenCount()).isEqualTo(2);

    assertThat(response.finishReason()).isNull();
}
}
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.build.outputTimestamp>1705675871</project.build.outputTimestamp>
<openai4j.version>0.12.4</openai4j.version>
<openai4j.version>0.13.0</openai4j.version>
<azure-ai-openai.version>1.0.0-beta.6</azure-ai-openai.version>
<azure.storage-blob.version>12.25.1</azure.storage-blob.version>
<azure.storage-common.version>12.24.1</azure.storage-common.version>
Expand Down

0 comments on commit 9ab9b1e

Please sign in to comment.