-
-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(embeddings): Integrate Google Vertex AI PaLM Embeddings (#100)
- Loading branch information
1 parent
3897595
commit d777ecc
Showing
14 changed files
with
280 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export 'vertex_ai.dart'; |
142 changes: 142 additions & 0 deletions
142
packages/langchain_google/lib/src/embeddings/vertex_ai.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import 'package:googleapis_auth/googleapis_auth.dart'; | ||
import 'package:langchain/langchain.dart'; | ||
import 'package:vertex_ai/vertex_ai.dart'; | ||
|
||
/// {@template vertex_ai_embeddings}
/// Wrapper around GCP Vertex AI text embedding models API
///
/// Example:
/// ```dart
/// final embeddings = VertexAIEmbeddings(
///   authHttpClient: authClient,
///   project: 'your-project-id',
/// );
/// final result = await embeddings.embedQuery('Hello world');
/// ```
///
/// Vertex AI documentation:
/// https://cloud.google.com/vertex-ai/docs/generative-ai/language-model-overview
///
/// ### Set up your Google Cloud Platform project
///
/// 1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager).
/// 2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).
/// 3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).
/// 4. [Configure the Vertex AI location](https://cloud.google.com/vertex-ai/docs/general/locations).
///
/// ### Authentication
///
/// The `VertexAIEmbeddings` wrapper delegates authentication to the
/// [googleapis_auth](https://pub.dev/packages/googleapis_auth) package.
///
/// To create an instance of `VertexAIEmbeddings` you need to provide an
/// [`AuthClient`](https://pub.dev/documentation/googleapis_auth/latest/googleapis_auth/AuthClient-class.html)
/// instance.
///
/// There are several ways to obtain an `AuthClient` depending on your use case.
/// Check out the [googleapis_auth](https://pub.dev/packages/googleapis_auth)
/// package documentation for more details.
///
/// Example using a service account JSON:
///
/// ```dart
/// final serviceAccountCredentials = ServiceAccountCredentials.fromJson(
///   json.decode(serviceAccountJson),
/// );
/// final authClient = await clientViaServiceAccount(
///   serviceAccountCredentials,
///   [VertexAIEmbeddings.cloudPlatformScope],
/// );
/// final vertexAi = VertexAIEmbeddings(
///   authHttpClient: authClient,
///   project: 'your-project-id',
/// );
/// ```
///
/// The service account should have the following
/// [permission](https://cloud.google.com/vertex-ai/docs/general/iam-permissions):
/// - `aiplatform.endpoints.predict`
///
/// The required [OAuth2 scope](https://developers.google.com/identity/protocols/oauth2/scopes)
/// is:
/// - `https://www.googleapis.com/auth/cloud-platform` (you can use the
///   constant `VertexAIEmbeddings.cloudPlatformScope`)
///
/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control
/// {@endtemplate}
class VertexAIEmbeddings implements Embeddings {
  /// {@macro vertex_ai_embeddings}
  ///
  /// - [authHttpClient] is the authenticated client used for every request.
  /// - [project] is the Google Cloud project ID.
  /// - [location] is the Vertex AI location (defaults to `us-central1`).
  /// - [rootUrl] overrides the API host. When omitted, it is derived from
  ///   [location] as `https://{location}-aiplatform.googleapis.com/`.
  /// - [publisher], [model] and [batchSize] are stored on the instance; see
  ///   the corresponding fields.
  VertexAIEmbeddings({
    required final AuthClient authHttpClient,
    required final String project,
    final String location = 'us-central1',
    final String? rootUrl,
    this.publisher = 'google',
    this.model = 'textembedding-gecko',
    this.batchSize = 5,
  }) : client = VertexAIGenAIClient(
          authHttpClient: authHttpClient,
          project: project,
          location: location,
          // Vertex AI endpoints are regional, so the default host has to
          // follow the configured location instead of being pinned to
          // us-central1.
          rootUrl: rootUrl ?? 'https://$location-aiplatform.googleapis.com/',
        );

  /// A client for interacting with Vertex AI API.
  final VertexAIGenAIClient client;

  /// The publisher of the model.
  ///
  /// Use `google` for first-party models.
  final String publisher;

  /// The text model to use.
  ///
  /// To use the latest model version, specify the model name without a version
  /// number (e.g. `textembedding-gecko`).
  /// To use a stable model version, specify the model version number
  /// (e.g. `textembedding-gecko@001`).
  ///
  /// You can find a list of available models here:
  /// https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
  final String model;

  /// The maximum number of documents to embed in a single request.
  ///
  /// `textembedding-gecko` has a limit of up to 5 input texts per request.
  final int batchSize;

  /// Scope required for Vertex AI API calls.
  static const cloudPlatformScope = VertexAIGenAIClient.cloudPlatformScope;

  /// Embeds [documents] and returns one vector per document, in input order.
  ///
  /// The documents are split into batches of at most [batchSize] items and
  /// all batches are requested concurrently.
  @override
  Future<List<List<double>>> embedDocuments(
    final List<String> documents,
  ) async {
    final subDocs = chunkArray(documents, chunkSize: batchSize);

    // `Future.wait` keeps the results in the same order as the batches, so
    // flattening below restores the original document order.
    final embeddings = await Future.wait(
      subDocs.map((final docsBatch) async {
        final data = await client.textEmbeddings.predict(
          content: docsBatch,
          publisher: publisher,
          model: model,
        );
        return data.predictions
            .map((final p) => p.values)
            .toList(growable: false);
      }),
    );

    return embeddings.expand((final e) => e).toList(growable: false);
  }

  /// Embeds a single [query] and returns its vector.
  ///
  /// Sends [query] as a one-element batch and returns the values of the first
  /// prediction in the response.
  @override
  Future<List<double>> embedQuery(final String query) async {
    final data = await client.textEmbeddings.predict(
      content: [query],
      publisher: publisher,
      model: model,
    );
    return data.predictions.first.values;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
packages/langchain_google/test/embeddings/vertex_ai_test.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
@TestOn('vm') | ||
library; // Uses dart:io | ||
|
||
import 'dart:io'; | ||
|
||
import 'package:langchain_google/langchain_google.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
import '../utils/auth.dart'; | ||
|
||
/// Integration tests for `VertexAIEmbeddings`.
///
/// Requires the `VERTEX_AI_PROJECT_ID` environment variable, plus whatever
/// `getAuthHttpClient` needs to build an authenticated client.
void main() async {
  final authHttpClient = await getAuthHttpClient();

  // Looked up lazily inside each test so that a missing variable fails the
  // individual test rather than the whole suite at load time.
  String projectId() => Platform.environment['VERTEX_AI_PROJECT_ID']!;

  group('VertexAIEmbeddings tests', () {
    test('Test VertexAIEmbeddings.embedQuery', () async {
      final embedder = VertexAIEmbeddings(
        authHttpClient: authHttpClient,
        project: projectId(),
      );

      final vector = await embedder.embedQuery('Hello world');

      expect(vector, hasLength(768));
    });

    test('Test VertexAIEmbeddings.embedDocuments', () async {
      // A batch size of 1 forces one request per document.
      final embedder = VertexAIEmbeddings(
        authHttpClient: authHttpClient,
        project: projectId(),
        batchSize: 1,
      );

      final vectors = await embedder.embedDocuments(
        ['Hello world', 'Bye bye'],
      );

      expect(vectors, hasLength(2));
      expect(vectors[0], hasLength(768));
      expect(vectors[1], hasLength(768));
    });
  });
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import 'dart:convert'; | ||
import 'dart:io'; | ||
|
||
import 'package:googleapis_auth/auth_io.dart'; | ||
import 'package:langchain_google/langchain_google.dart'; | ||
|
||
/// Builds an authenticated HTTP client for the Vertex AI tests.
///
/// Reads the service account JSON from the `VERTEX_AI_SERVICE_ACCOUNT`
/// environment variable and requests the `VertexAI.cloudPlatformScope` scope.
Future<AuthClient> getAuthHttpClient() async {
  final rawJson = Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!;
  final credentials = ServiceAccountCredentials.fromJson(jsonDecode(rawJson));
  return clientViaServiceAccount(credentials, [VertexAI.cloudPlatformScope]);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.