forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New ScoringModel for Google Cloud Vertex AI Ranking API (langchain4j#…
…1820) ## Issue Closes langchain4j#1819 ## Change Add support for the Vertex AI Ranking API, by implement a `ScoringModel` for it. ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
- Loading branch information
Showing
4 changed files
with
373 additions
and
0 deletions.
There are no files selected for viewing
97 changes: 97 additions & 0 deletions
97
docs/docs/integrations/scoring-reranking-models/3-vertex-ai-ranking.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
--- | ||
sidebar_position: 3 | ||
--- | ||
|
||
# Google Cloud Vertex AI Ranking API | ||
|
||
- [Google Cloud Vertex AI Ranking documentation](https://cloud.google.com/generative-ai-app-builder/docs/ranking) | ||
- [Google Cloud Vertex AI Ranking API description](https://cloud.google.com/generative-ai-app-builder/docs/reference/rest/v1/projects.locations.rankingConfigs/rank) | ||
|
||
|
||
### Introduction | ||
|
||
The Google Cloud Vertex AI Ranking API is a powerful tool that enhances search results by refining the relevance of | ||
retrieved documents to a given query. Unlike traditional search methods, it leverages advanced machine learning | ||
algorithms to understand the semantic context of both the query and the documents, delivering more precise and relevant | ||
results. By analyzing the semantic relationship between the query and each document, the API can reorder the candidate | ||
documents based on their calculated relevance scores, ensuring that the most relevant results appear at the top of the | ||
search results page. | ||
|
||
### Maven Dependency | ||
|
||
```xml | ||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-vertex-ai</artifactId> | ||
<version>0.34.0</version> | ||
</dependency> | ||
``` | ||
|
||
### Usage | ||
|
||
To configure the model, you'll have to specify: | ||
* the Google Cloud project ID, | ||
* the project number, | ||
* the location (ex. `us-central1`, `europe-west1`), | ||
* and the model you want to use. | ||
|
||
> Note: You can find the project number in the Google Cloud console, or by running `gcloud projects describe your-project-id`. | ||
You can score a single string or `TextSegment` against a query | ||
thanks to the `score(text, query)` and `score(segment, query)` methods. | ||
|
||
It is also possible to score several strings or `TextSegment`s against the query, | ||
with the `scoreAll(segments, query)` method: | ||
|
||
```java | ||
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder() | ||
.projectId(System.getenv("GCP_PROJECT_ID")) | ||
.projectNumber(System.getenv("GCP_PROJECT_NUM")) | ||
.projectLocation(System.getenv("GCP_LOCATION")) | ||
.model("semantic-ranker-512") | ||
.build(); | ||
|
||
Response<List<Double>> score = scoringModel.scoreAll(Stream.of( | ||
"The sky appears blue due to a phenomenon called Rayleigh scattering. " + | ||
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " + | ||
"wavelengths than other colors, and is thus scattered more easily.", | ||
|
||
"A canvas stretched across the day,\n" + | ||
"Where sunlight learns to dance and play.\n" + | ||
"Blue, a hue of scattered light,\n" + | ||
"A gentle whisper, soft and bright." | ||
).map(TextSegment::from).collect(Collectors.toList()), | ||
"Why is the sky blue?"); | ||
|
||
// [0.8199999928474426, 0.4300000071525574] | ||
``` | ||
|
||
If you pass `TextSegment`s which have a particular `title` key, the Ranker model can take this metadata into account in its calculation. | ||
To specify a custom title key, you can use the `titleMetadataKey()` builder method.` | ||
|
||
You can use scoring models with `AiServices` and its `contentAgregator()` method, | ||
which takes a `ContentAggregator` class that can specify a scoring model: | ||
|
||
```java | ||
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder() | ||
.projectId(System.getenv("GCP_PROJECT_ID")) | ||
.projectNumber(System.getenv("GCP_PROJECT_NUM")) | ||
.projectLocation(System.getenv("GCP_LOCATION")) | ||
.model("semantic-ranker-512") | ||
.build(); | ||
|
||
ContentAggregator contentAggregator = ReRankingContentAggregator.builder() | ||
.scoringModel(scoringModel) | ||
... | ||
.build(); | ||
|
||
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() | ||
... | ||
.contentAggregator(contentAggregator) | ||
.build(); | ||
|
||
return AiServices.builder(Assistant.class) | ||
.chatLanguageModel(...) | ||
.retrievalAugmentor(retrievalAugmentor) | ||
.build(); | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
langchain4j-vertex-ai/src/main/java/dev/langchain4j/model/vertexai/VertexAiScoringModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.discoveryengine.v1beta.*; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.model.scoring.ScoringModel; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicInteger; | ||
import java.util.stream.Collectors; | ||
|
||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.util.Comparator.comparing; | ||
|
||
/** | ||
* Implementation of a <code>ScoringModel</code> for the Google Cloud Vertex AI | ||
* <a href="https://cloud.google.com/generative-ai-app-builder/docs/ranking">Ranking API</a>. | ||
*/ | ||
public class VertexAiScoringModel implements ScoringModel { | ||
|
||
private final String model; | ||
private final String projectId; | ||
private final String projectNumber; | ||
private final String location; | ||
private final String titleMetadataKey; | ||
|
||
/** | ||
* Constructor for the Vertex AI Ranker Scoring Model. | ||
* | ||
* @param projectId The Google Cloud Project ID. | ||
* @param projectNumber The Google Cloud Project Number. | ||
* @param location The Google Cloud Region. | ||
* @param model The model to use | ||
* @param titleMetadataKey The name of the key to use as a title. | ||
*/ | ||
public VertexAiScoringModel(String projectId, String projectNumber, String location, String model, String titleMetadataKey) { | ||
this.projectId = ensureNotBlank(projectId, "projectId"); | ||
this.projectNumber = ensureNotBlank(projectNumber, "projectNumber"); | ||
this.location = ensureNotBlank(location, "location"); | ||
this.model = ensureNotBlank(model, "model"); | ||
this.titleMetadataKey = titleMetadataKey != null ? titleMetadataKey : "title"; | ||
} | ||
|
||
/** | ||
* Scores all provided {@link TextSegment}s against a given query. | ||
* | ||
* @param segments The list of {@link TextSegment}s to score. | ||
* @param query The query against which to score the segments. | ||
* @return the list of scores. The order of scores corresponds to the order of {@link TextSegment}s. | ||
*/ | ||
@Override | ||
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) { | ||
AtomicInteger counter = new AtomicInteger(); | ||
|
||
try (RankServiceClient rankServiceClient = RankServiceClient.create( | ||
RankServiceSettings.newBuilder().build())) { | ||
|
||
RankRequest.Builder rankingRequestBuilder = RankRequest.newBuilder(); | ||
|
||
if (model != null && !model.isEmpty()) { | ||
rankingRequestBuilder.setModel(model); | ||
} | ||
|
||
rankingRequestBuilder | ||
.setRankingConfig(RankingConfigName.newBuilder() | ||
.setProject(projectId) | ||
.setLocation(location) | ||
.setRankingConfig( | ||
String.format("projects/%s/locations/%s/rankingConfigs/default_ranking_config.", projectNumber, location)) | ||
.build().getRankingConfig()) | ||
.setQuery(query) | ||
.setIgnoreRecordDetailsInResponse(true) | ||
.addAllRecords(segments.stream() | ||
.map(segment -> { | ||
RankingRecord.Builder rankingBuilder = RankingRecord.newBuilder() | ||
.setContent(segment.text()); | ||
// Ranker API takes into account titles in its score calculations | ||
if (segment.metadata().getString(titleMetadataKey) != null) { | ||
rankingBuilder.setTitle(segment.metadata().getString(titleMetadataKey)); | ||
} | ||
// custom ID used to reorder the (sorted) results back into original segment order | ||
rankingBuilder.setId(String.valueOf(counter.getAndIncrement())); | ||
return rankingBuilder.build(); | ||
}) | ||
.collect(Collectors.toList())); | ||
|
||
RankResponse rankResponse = rankServiceClient.rank(rankingRequestBuilder.build()); | ||
|
||
return Response.from(rankResponse.getRecordsList().stream() | ||
// the API returns results sorted by relevance score, so reorder them back to original order | ||
.sorted(comparing(rr -> Double.valueOf(rr.getId()))) | ||
.map(RankingRecord::getScore) | ||
.map(Double::valueOf) | ||
.collect(Collectors.toList())); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
private String model; | ||
private String projectId; | ||
private String projectNumber; | ||
private String location; | ||
private String titleMetadataKey; | ||
|
||
public Builder model(String model) { | ||
this.model = ensureNotBlank(model, "model"); | ||
return this; | ||
} | ||
|
||
public Builder projectId(String projectId) { | ||
this.projectId = projectId; | ||
return this; | ||
} | ||
|
||
public Builder projectNumber(String projectNumber) { | ||
this.projectNumber = projectNumber; | ||
return this; | ||
} | ||
|
||
public Builder location(String location) { | ||
this.location = location; | ||
return this; | ||
} | ||
|
||
public Builder titleMetadataKey(String titleMetadataKey) { | ||
this.titleMetadataKey = ensureNotBlank(titleMetadataKey, "titleMetadataKey"); | ||
return this; | ||
} | ||
|
||
public VertexAiScoringModel build() { | ||
return new VertexAiScoringModel(projectId, projectNumber, location, model, titleMetadataKey); | ||
} | ||
} | ||
} |
106 changes: 106 additions & 0 deletions
106
...hain4j-vertex-ai/src/test/java/dev/langchain4j/model/vertexai/VertexAiScoringModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import dev.langchain4j.data.document.Metadata; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Test; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.Stream; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
public class VertexAiScoringModelIT { | ||
@Test | ||
void should_rank_multiple() { | ||
// given | ||
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder() | ||
.projectId(System.getenv("GCP_PROJECT_ID")) | ||
.projectNumber(System.getenv("GCP_PROJECT_NUM")) | ||
.location(System.getenv("GCP_LOCATION")) | ||
.model("semantic-ranker-512") | ||
.build(); | ||
|
||
// when | ||
Response<List<Double>> score = scoringModel.scoreAll(Stream.of( | ||
"The sky appears blue due to a phenomenon called Rayleigh scattering. " + | ||
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " + | ||
"wavelengths than other colors, and is thus scattered more easily.", | ||
|
||
"A canvas stretched across the day,\n" + | ||
"Where sunlight learns to dance and play.\n" + | ||
"Blue, a hue of scattered light,\n" + | ||
"A gentle whisper, soft and bright." | ||
).map(TextSegment::from).collect(Collectors.toList()), | ||
"Why is the sky blue?"); | ||
|
||
// then | ||
assertThat(score.content().get(0)).isGreaterThan(score.content().get(1)); | ||
} | ||
|
||
@Test | ||
void should_rank_single() { | ||
// given | ||
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder() | ||
.projectId(System.getenv("GCP_PROJECT_ID")) | ||
.projectNumber(System.getenv("GCP_PROJECT_NUM")) | ||
.location(System.getenv("GCP_LOCATION")) | ||
.model("semantic-ranker-512") | ||
.build(); | ||
|
||
// when | ||
Response<Double> score = scoringModel.score( | ||
"The sky appears blue due to a phenomenon called Rayleigh scattering. " + | ||
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " + | ||
"wavelengths than other colors, and is thus scattered more easily.", | ||
"Why is the sky blue?"); | ||
|
||
// then | ||
assertThat(score.content()).isPositive(); | ||
} | ||
|
||
@Test | ||
void should_use_text_segment_titles_into_account() { | ||
// given | ||
String customTitleKey = "customTitle"; | ||
|
||
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder() | ||
.projectId(System.getenv("GCP_PROJECT_ID")) | ||
.projectNumber(System.getenv("GCP_PROJECT_NUM")) | ||
.location(System.getenv("GCP_LOCATION")) | ||
.model("semantic-ranker-512") | ||
.titleMetadataKey(customTitleKey) | ||
.build(); | ||
|
||
List<TextSegment> segments = Arrays.asList( | ||
new TextSegment( | ||
"Your Cymbal Starlight 2024 is not equipped to tow a trailer.", | ||
new Metadata().put(customTitleKey, "trailer")), | ||
new TextSegment( | ||
"The Cymbal Starlight 2024 has a cargo capacity of 13.5 cubic feet.", | ||
new Metadata().put(customTitleKey, "capacity")), | ||
new TextSegment( | ||
"The cargo area is located in the trunk of the vehicle.", | ||
new Metadata().put(customTitleKey, "trunk")), | ||
new TextSegment( | ||
"To access the cargo area, open the trunk lid using the trunk release lever located in the driver's footwell.", | ||
new Metadata().put(customTitleKey, "lever")), | ||
new TextSegment( | ||
"When loading cargo into the trunk, be sure to distribute the weight evenly.", | ||
new Metadata().put(customTitleKey, "weight")), | ||
new TextSegment( | ||
"Do not overload the trunk, as this could affect the vehicle's handling and stability.", | ||
new Metadata().put(customTitleKey, "overload")) | ||
); | ||
|
||
// when | ||
Response<List<Double>> score = scoringModel.scoreAll(segments, "What is the cargo capacity of the car?"); | ||
|
||
// then | ||
double maxScore = score.content().stream().mapToDouble(Double::doubleValue).max().getAsDouble(); | ||
|
||
assertThat(score.content().get(1)).isEqualTo(maxScore); | ||
} | ||
} |