Skip to content

Commit

Permalink
New ScoringModel for Google Cloud Vertex AI Ranking API (langchain4j#…
Browse files Browse the repository at this point in the history
…1820)

## Issue
Closes langchain4j#1819

## Change
Add support for the Vertex AI Ranking API, by implement a `ScoringModel`
for it.

## General checklist

- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
glaforge authored Sep 24, 2024
1 parent ef2f1ec commit 9e2ee93
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
---
sidebar_position: 3
---

# Google Cloud Vertex AI Ranking API

- [Google Cloud Vertex AI Ranking documentation](https://cloud.google.com/generative-ai-app-builder/docs/ranking)
- [Google Cloud Vertex AI Ranking API description](https://cloud.google.com/generative-ai-app-builder/docs/reference/rest/v1/projects.locations.rankingConfigs/rank)


### Introduction

The Google Cloud Vertex AI Ranking API is a powerful tool that enhances search results by refining the relevance of
retrieved documents to a given query. Unlike traditional search methods, it leverages advanced machine learning
algorithms to understand the semantic context of both the query and the documents, delivering more precise and relevant
results. By analyzing the semantic relationship between the query and each document, the API can reorder the candidate
documents based on their calculated relevance scores, ensuring that the most relevant results appear at the top of the
search results page.

### Maven Dependency

```xml
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-vertex-ai</artifactId>
<version>0.34.0</version>
</dependency>
```

### Usage

To configure the model, you'll have to specify:
* the Google Cloud project ID,
* the project number,
* the location (ex. `us-central1`, `europe-west1`),
* and the model you want to use.

> Note: You can find the project number in the Google Cloud console, or by running `gcloud projects describe your-project-id`.
You can score a single string or `TextSegment` against a query
thanks to the `score(text, query)` and `score(segment, query)` methods.

It is also possible to score several strings or `TextSegment`s against the query,
with the `scoreAll(segments, query)` method:

```java
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
.projectId(System.getenv("GCP_PROJECT_ID"))
.projectNumber(System.getenv("GCP_PROJECT_NUM"))
.projectLocation(System.getenv("GCP_LOCATION"))
.model("semantic-ranker-512")
.build();

Response<List<Double>> score = scoringModel.scoreAll(Stream.of(
"The sky appears blue due to a phenomenon called Rayleigh scattering. " +
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
"wavelengths than other colors, and is thus scattered more easily.",

"A canvas stretched across the day,\n" +
"Where sunlight learns to dance and play.\n" +
"Blue, a hue of scattered light,\n" +
"A gentle whisper, soft and bright."
).map(TextSegment::from).collect(Collectors.toList()),
"Why is the sky blue?");

// [0.8199999928474426, 0.4300000071525574]
```

If you pass `TextSegment`s which have a particular `title` key, the Ranker model can take this metadata into account in its calculation.
To specify a custom title key, you can use the `titleMetadataKey()` builder method.`

You can use scoring models with `AiServices` and its `contentAgregator()` method,
which takes a `ContentAggregator` class that can specify a scoring model:

```java
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
.projectId(System.getenv("GCP_PROJECT_ID"))
.projectNumber(System.getenv("GCP_PROJECT_NUM"))
.projectLocation(System.getenv("GCP_LOCATION"))
.model("semantic-ranker-512")
.build();

ContentAggregator contentAggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
...
.build();

RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
...
.contentAggregator(contentAggregator)
.build();

return AiServices.builder(Assistant.class)
.chatLanguageModel(...)
.retrievalAugmentor(retrievalAugmentor)
.build();
```
29 changes: 29 additions & 0 deletions langchain4j-vertex-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
<artifactId>langchain4j-core</artifactId>
</dependency>

<!-- Google Vertex AI library -->

<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-aiplatform</artifactId>
Expand All @@ -33,6 +35,21 @@
</exclusions>
</dependency>

<!-- Ranking API -->

<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-discoveryengine</artifactId>
</dependency>

<!-- testing dependencies -->

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
Expand All @@ -59,6 +76,18 @@
</dependency>
</dependencies>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-discoveryengine-bom</artifactId>
<scope>import</scope>
<type>pom</type>
<version>0.45.0</version>
</dependency>
</dependencies>
</dependencyManagement>

<build>
<plugins>
<plugin>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.discoveryengine.v1beta.*;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.util.Comparator.comparing;

/**
* Implementation of a <code>ScoringModel</code> for the Google Cloud Vertex AI
* <a href="https://cloud.google.com/generative-ai-app-builder/docs/ranking">Ranking API</a>.
*/
public class VertexAiScoringModel implements ScoringModel {

private final String model;
private final String projectId;
private final String projectNumber;
private final String location;
private final String titleMetadataKey;

/**
* Constructor for the Vertex AI Ranker Scoring Model.
*
* @param projectId The Google Cloud Project ID.
* @param projectNumber The Google Cloud Project Number.
* @param location The Google Cloud Region.
* @param model The model to use
* @param titleMetadataKey The name of the key to use as a title.
*/
public VertexAiScoringModel(String projectId, String projectNumber, String location, String model, String titleMetadataKey) {
this.projectId = ensureNotBlank(projectId, "projectId");
this.projectNumber = ensureNotBlank(projectNumber, "projectNumber");
this.location = ensureNotBlank(location, "location");
this.model = ensureNotBlank(model, "model");
this.titleMetadataKey = titleMetadataKey != null ? titleMetadataKey : "title";
}

/**
* Scores all provided {@link TextSegment}s against a given query.
*
* @param segments The list of {@link TextSegment}s to score.
* @param query The query against which to score the segments.
* @return the list of scores. The order of scores corresponds to the order of {@link TextSegment}s.
*/
@Override
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
AtomicInteger counter = new AtomicInteger();

try (RankServiceClient rankServiceClient = RankServiceClient.create(
RankServiceSettings.newBuilder().build())) {

RankRequest.Builder rankingRequestBuilder = RankRequest.newBuilder();

if (model != null && !model.isEmpty()) {
rankingRequestBuilder.setModel(model);
}

rankingRequestBuilder
.setRankingConfig(RankingConfigName.newBuilder()
.setProject(projectId)
.setLocation(location)
.setRankingConfig(
String.format("projects/%s/locations/%s/rankingConfigs/default_ranking_config.", projectNumber, location))
.build().getRankingConfig())
.setQuery(query)
.setIgnoreRecordDetailsInResponse(true)
.addAllRecords(segments.stream()
.map(segment -> {
RankingRecord.Builder rankingBuilder = RankingRecord.newBuilder()
.setContent(segment.text());
// Ranker API takes into account titles in its score calculations
if (segment.metadata().getString(titleMetadataKey) != null) {
rankingBuilder.setTitle(segment.metadata().getString(titleMetadataKey));
}
// custom ID used to reorder the (sorted) results back into original segment order
rankingBuilder.setId(String.valueOf(counter.getAndIncrement()));
return rankingBuilder.build();
})
.collect(Collectors.toList()));

RankResponse rankResponse = rankServiceClient.rank(rankingRequestBuilder.build());

return Response.from(rankResponse.getRecordsList().stream()
// the API returns results sorted by relevance score, so reorder them back to original order
.sorted(comparing(rr -> Double.valueOf(rr.getId())))
.map(RankingRecord::getScore)
.map(Double::valueOf)
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private String model;
private String projectId;
private String projectNumber;
private String location;
private String titleMetadataKey;

public Builder model(String model) {
this.model = ensureNotBlank(model, "model");
return this;
}

public Builder projectId(String projectId) {
this.projectId = projectId;
return this;
}

public Builder projectNumber(String projectNumber) {
this.projectNumber = projectNumber;
return this;
}

public Builder location(String location) {
this.location = location;
return this;
}

public Builder titleMetadataKey(String titleMetadataKey) {
this.titleMetadataKey = ensureNotBlank(titleMetadataKey, "titleMetadataKey");
return this;
}

public VertexAiScoringModel build() {
return new VertexAiScoringModel(projectId, projectNumber, location, model, titleMetadataKey);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package dev.langchain4j.model.vertexai;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import dev.langchain4j.data.segment.TextSegment;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;

public class VertexAiScoringModelIT {
@Test
void should_rank_multiple() {
// given
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
.projectId(System.getenv("GCP_PROJECT_ID"))
.projectNumber(System.getenv("GCP_PROJECT_NUM"))
.location(System.getenv("GCP_LOCATION"))
.model("semantic-ranker-512")
.build();

// when
Response<List<Double>> score = scoringModel.scoreAll(Stream.of(
"The sky appears blue due to a phenomenon called Rayleigh scattering. " +
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
"wavelengths than other colors, and is thus scattered more easily.",

"A canvas stretched across the day,\n" +
"Where sunlight learns to dance and play.\n" +
"Blue, a hue of scattered light,\n" +
"A gentle whisper, soft and bright."
).map(TextSegment::from).collect(Collectors.toList()),
"Why is the sky blue?");

// then
assertThat(score.content().get(0)).isGreaterThan(score.content().get(1));
}

@Test
void should_rank_single() {
// given
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
.projectId(System.getenv("GCP_PROJECT_ID"))
.projectNumber(System.getenv("GCP_PROJECT_NUM"))
.location(System.getenv("GCP_LOCATION"))
.model("semantic-ranker-512")
.build();

// when
Response<Double> score = scoringModel.score(
"The sky appears blue due to a phenomenon called Rayleigh scattering. " +
"Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
"wavelengths than other colors, and is thus scattered more easily.",
"Why is the sky blue?");

// then
assertThat(score.content()).isPositive();
}

@Test
void should_use_text_segment_titles_into_account() {
// given
String customTitleKey = "customTitle";

VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
.projectId(System.getenv("GCP_PROJECT_ID"))
.projectNumber(System.getenv("GCP_PROJECT_NUM"))
.location(System.getenv("GCP_LOCATION"))
.model("semantic-ranker-512")
.titleMetadataKey(customTitleKey)
.build();

List<TextSegment> segments = Arrays.asList(
new TextSegment(
"Your Cymbal Starlight 2024 is not equipped to tow a trailer.",
new Metadata().put(customTitleKey, "trailer")),
new TextSegment(
"The Cymbal Starlight 2024 has a cargo capacity of 13.5 cubic feet.",
new Metadata().put(customTitleKey, "capacity")),
new TextSegment(
"The cargo area is located in the trunk of the vehicle.",
new Metadata().put(customTitleKey, "trunk")),
new TextSegment(
"To access the cargo area, open the trunk lid using the trunk release lever located in the driver's footwell.",
new Metadata().put(customTitleKey, "lever")),
new TextSegment(
"When loading cargo into the trunk, be sure to distribute the weight evenly.",
new Metadata().put(customTitleKey, "weight")),
new TextSegment(
"Do not overload the trunk, as this could affect the vehicle's handling and stability.",
new Metadata().put(customTitleKey, "overload"))
);

// when
Response<List<Double>> score = scoringModel.scoreAll(segments, "What is the cargo capacity of the car?");

// then
double maxScore = score.content().stream().mapToDouble(Double::doubleValue).max().getAsDouble();

assertThat(score.content().get(1)).isEqualTo(maxScore);
}
}

0 comments on commit 9e2ee93

Please sign in to comment.