Skip to content

Commit

Permalink
feat(embeddings): Integrate Google Vertex AI PaLM Embeddings (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Aug 13, 2023
1 parent 3897595 commit d777ecc
Show file tree
Hide file tree
Showing 14 changed files with 280 additions and 38 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ LangChain.dart has a modular design where the core [langchain](https://pub.dev/p
package provides the LangChain API and each integration with a model provider, data store, etc. is
provided by a separate package.

| Package | Version | Description | Models | Data conn. | Memory | Agents & Tools |
|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------|--------------------|--------|------------|--------|----------------|
| [langchain](https://pub.dev/packages/langchain) | [![langchain](https://img.shields.io/pub/v/langchain.svg)](https://pub.dev/packages/langchain) | Core LangChain API |||||
| [langchain_openai](https://pub.dev/packages/langchain_openai) | [![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai) | OpenAI integration ||| ||
| Package | Version | Description | Models | Data conn. | Chains | Agents & Tools |
|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------|-------------------------------------|--------|------------|--------|----------------|
| [langchain](https://pub.dev/packages/langchain) | [![langchain](https://img.shields.io/pub/v/langchain.svg)](https://pub.dev/packages/langchain) | Core LangChain API |||||
| [langchain_openai](https://pub.dev/packages/langchain_openai) | [![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai) | OpenAI integration |||||
| [langchain_google](https://pub.dev/packages/langchain_google) | [![langchain_google](https://img.shields.io/pub/v/langchain_google.svg)](https://pub.dev/packages/langchain_google) | Google integration (VertexAI, PaLM) ||| | |

## Getting started

Expand Down
20 changes: 19 additions & 1 deletion packages/langchain_google/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
# 🦜️🔗 LangChain.dart
# 🦜️🔗 LangChain.dart / Google

[![tests](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/test.yaml?logo=github&label=tests)](https://github.com/davidmigloz/langchain_dart/actions/workflows/test.yaml)
[![docs](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/pages%2Fpages-build-deployment?logo=github&label=docs)](https://github.com/davidmigloz/langchain_dart/actions/workflows/pages/pages-build-deployment)
[![langchain_google](https://img.shields.io/pub/v/langchain_google.svg)](https://pub.dev/packages/langchain_google)
[![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR)
[![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE)

Google module for [LangChain.dart](https://github.com/davidmigloz/langchain_dart).

## Features

- LLMs:
* `VertexAI`: wrapper around GCP Vertex AI text models API (aka PaLM API for
text).
- Chat models:
* `ChatVertexAI`: wrapper around GCP Vertex AI text chat models API (aka PaLM
API for chat).
- Embeddings:
* `VertexAIEmbeddings`: wrapper around GCP Vertex AI text embedding models
API.

## License

LangChain.dart is licensed under the
Expand Down
1 change: 1 addition & 0 deletions packages/langchain_google/lib/langchain_google.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ library;

export 'src/chat_models/chat_models.dart';
export 'src/doc_loaders/doc_loaders.dart';
export 'src/embeddings/embeddings.dart';
export 'src/llms/llms.dart';
10 changes: 5 additions & 5 deletions packages/langchain_google/lib/src/chat_models/vertex_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import 'models/mappers.dart';
import 'models/models.dart';

/// {@template chat_vertex_ai}
/// Wrapper around GCP Vertex AI text chat models API (aka PaLM API).
/// Wrapper around GCP Vertex AI text chat models API (aka PaLM API for chat).
///
/// Example:
/// ```dart
Expand All @@ -30,10 +30,10 @@ import 'models/models.dart';
///
/// ### Authentication
///
/// The `VertexAI` wrapper delegates authentication to the
/// The `ChatVertexAI` wrapper delegates authentication to the
/// [googleapis_auth](https://pub.dev/packages/googleapis_auth) package.
///
/// To create an instance of `VertexAI` you need to provide an
/// To create an instance of `ChatVertexAI` you need to provide an
/// [`AuthClient`](https://pub.dev/documentation/googleapis_auth/latest/googleapis_auth/AuthClient-class.html)
/// instance.
///
Expand Down Expand Up @@ -99,9 +99,9 @@ class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
/// The text model to use.
///
/// To use the latest model version, specify the model name without a version
/// number (e.g. `text-bison`).
/// number (e.g. `chat-bison`).
/// To use a stable model version, specify the model version number
/// (e.g. `text-bison@001`).
/// (e.g. `chat-bison@001`).
///
/// You can find a list of available models here:
/// https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export 'vertex_ai.dart';
142 changes: 142 additions & 0 deletions packages/langchain_google/lib/src/embeddings/vertex_ai.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import 'package:googleapis_auth/googleapis_auth.dart';
import 'package:langchain/langchain.dart';
import 'package:vertex_ai/vertex_ai.dart';

/// {@template vertex_ai_embeddings}
/// Wrapper around GCP Vertex AI text embedding models API
///
/// Example:
/// ```dart
/// final embeddings = VertexAIEmbeddings(
/// authHttpClient: authClient,
/// project: 'your-project-id',
/// );
/// final result = await embeddings.embedQuery('Hello world');
/// ```
///
/// Vertex AI documentation:
/// https://cloud.google.com/vertex-ai/docs/generative-ai/language-model-overview
///
/// ### Set up your Google Cloud Platform project
///
/// 1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager).
/// 2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).
/// 3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).
/// 4. [Configure the Vertex AI location](https://cloud.google.com/vertex-ai/docs/general/locations).
///
/// ### Authentication
///
/// The `VertexAIEmbeddings` wrapper delegates authentication to the
/// [googleapis_auth](https://pub.dev/packages/googleapis_auth) package.
///
/// To create an instance of `VertexAIEmbeddings` you need to provide an
/// [`AuthClient`](https://pub.dev/documentation/googleapis_auth/latest/googleapis_auth/AuthClient-class.html)
/// instance.
///
/// There are several ways to obtain an `AuthClient` depending on your use case.
/// Check out the [googleapis_auth](https://pub.dev/packages/googleapis_auth)
/// package documentation for more details.
///
/// Example using a service account JSON:
///
/// ```dart
/// final serviceAccountCredentials = ServiceAccountCredentials.fromJson(
/// json.decode(serviceAccountJson),
/// );
/// final authClient = await clientViaServiceAccount(
/// serviceAccountCredentials,
/// [VertexAIEmbeddings.cloudPlatformScope],
/// );
/// final vertexAi = VertexAIEmbeddings(
/// authHttpClient: authClient,
/// project: 'your-project-id',
/// );
/// ```
///
/// The service account should have the following
/// [permission](https://cloud.google.com/vertex-ai/docs/general/iam-permissions):
/// - `aiplatform.endpoints.predict`
///
/// The required[OAuth2 scope](https://developers.google.com/identity/protocols/oauth2/scopes)
/// is:
/// - `https://www.googleapis.com/auth/cloud-platform` (you can use the
/// constant `VertexAIEmbeddings.cloudPlatformScope`)
///
/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control
/// {@endtemplate}
class VertexAIEmbeddings implements Embeddings {
/// {@macro vertex_ai_embeddings}
VertexAIEmbeddings({
required final AuthClient authHttpClient,
required final String project,
final String location = 'us-central1',
final String rootUrl = 'https://us-central1-aiplatform.googleapis.com/',
this.publisher = 'google',
this.model = 'textembedding-gecko',
this.batchSize = 5,
}) : client = VertexAIGenAIClient(
authHttpClient: authHttpClient,
project: project,
location: location,
rootUrl: rootUrl,
);

/// A client for interacting with Vertex AI API.
final VertexAIGenAIClient client;

/// The publisher of the model.
///
/// Use `google` for first-party models.
final String publisher;

/// The text model to use.
///
/// To use the latest model version, specify the model name without a version
/// number (e.g. `textembedding-gecko`).
/// To use a stable model version, specify the model version number
/// (e.g. `textembedding-gecko@001`).
///
/// You can find a list of available models here:
/// https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
final String model;

/// The maximum number of documents to embed in a single request.
///
/// `textembedding-gecko` has a limit of up to 5 input texts per request.
final int batchSize;

/// Scope required for Vertex AI API calls.
static const cloudPlatformScope = VertexAIGenAIClient.cloudPlatformScope;

@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
) async {
final subDocs = chunkArray(documents, chunkSize: batchSize);

final embeddings = await Future.wait(
subDocs.map((final docsBatch) async {
final data = await client.textEmbeddings.predict(
content: docsBatch,
publisher: publisher,
model: model,
);
return data.predictions
.map((final p) => p.values)
.toList(growable: false);
}),
);

return embeddings.expand((final e) => e).toList(growable: false);
}

@override
Future<List<double>> embedQuery(final String query) async {
final data = await client.textEmbeddings.predict(
content: [query],
publisher: publisher,
model: model,
);
return data.predictions.first.values;
}
}
2 changes: 1 addition & 1 deletion packages/langchain_google/lib/src/llms/vertex_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import 'models/mappers.dart';
import 'models/models.dart';

/// {@template vertex_ai}
/// Wrapper around GCP Vertex AI text models API (aka PaLM API).
/// Wrapper around GCP Vertex AI text models API (aka PaLM API for text).
///
/// Example:
/// ```dart
Expand Down
16 changes: 3 additions & 13 deletions packages/langchain_google/test/chat_models/vertex_ai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:convert';
import 'dart:io';

import 'package:googleapis_auth/auth_io.dart';
import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

import '../utils/auth.dart';

void main() async {
final authHttpClient = await _getAuthHttpClient();
final authHttpClient = await getAuthHttpClient();
group('ChatVertexAI tests', () {
test('Test ChatVertexAI parameters', () async {
final llm = ChatVertexAI(
Expand Down Expand Up @@ -140,13 +140,3 @@ void main() async {
});
});
}

Future<AuthClient> _getAuthHttpClient() async {
final serviceAccountCredentials = ServiceAccountCredentials.fromJson(
json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!),
);
return clientViaServiceAccount(
serviceAccountCredentials,
[ChatVertexAI.cloudPlatformScope],
);
}
35 changes: 35 additions & 0 deletions packages/langchain_google/test/embeddings/vertex_ai_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:io';

import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

import '../utils/auth.dart';

void main() async {
final authHttpClient = await getAuthHttpClient();
group('VertexAIEmbeddings tests', () {
test('Test VertexAIEmbeddings.embedQuery', () async {
final embeddings = VertexAIEmbeddings(
authHttpClient: authHttpClient,
project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
);
final res = await embeddings.embedQuery('Hello world');
expect(res.length, 768);
});

test('Test VertexAIEmbeddings.embedDocuments', () async {
final embeddings = VertexAIEmbeddings(
authHttpClient: authHttpClient,
project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
batchSize: 1,
);
final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']);
expect(res.length, 2);
expect(res[0].length, 768);
expect(res[1].length, 768);
});
});
}
16 changes: 3 additions & 13 deletions packages/langchain_google/test/llms/vertex_ai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:convert';
import 'dart:io';

import 'package:googleapis_auth/auth_io.dart';
import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

import '../utils/auth.dart';

Future<void> main() async {
final authHttpClient = await _getAuthHttpClient();
final authHttpClient = await getAuthHttpClient();
group('VertexAI tests', () {
test('Test VertexAI parameters', () async {
final llm = VertexAI(
Expand Down Expand Up @@ -93,13 +93,3 @@ Future<void> main() async {
});
});
}

Future<AuthClient> _getAuthHttpClient() async {
final serviceAccountCredentials = ServiceAccountCredentials.fromJson(
json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!),
);
return clientViaServiceAccount(
serviceAccountCredentials,
[VertexAI.cloudPlatformScope],
);
}
15 changes: 15 additions & 0 deletions packages/langchain_google/test/utils/auth.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import 'dart:convert';
import 'dart:io';

import 'package:googleapis_auth/auth_io.dart';
import 'package:langchain_google/langchain_google.dart';

Future<AuthClient> getAuthHttpClient() async {
final serviceAccountCredentials = ServiceAccountCredentials.fromJson(
json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!),
);
return clientViaServiceAccount(
serviceAccountCredentials,
[VertexAI.cloudPlatformScope],
);
}
23 changes: 22 additions & 1 deletion packages/langchain_openai/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
# 🦜️🔗 LangChain.dart
# 🦜️🔗 LangChain.dart / OpenAI

[![tests](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/test.yaml?logo=github&label=tests)](https://github.com/davidmigloz/langchain_dart/actions/workflows/test.yaml)
[![docs](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/pages%2Fpages-build-deployment?logo=github&label=docs)](https://github.com/davidmigloz/langchain_dart/actions/workflows/pages/pages-build-deployment)
[![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai)
[![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR)
[![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE)

OpenAI module for [LangChain.dart](https://github.com/davidmigloz/langchain_dart).

## Features

- LLMs:
* `OpenAI`: wrapper around OpenAI Completions API.
- Chat models:
* `ChatOpenAI`: wrapper around OpenAI Chat API.
- Embeddings:
* `OpenAIEmbeddings`: wrapper around OpenAI Embeddings API.
- Chains:
* `OpenAIQAWithStructureChain` a chain that answer questions in the specified
structure.
* `OpenAIQAWithSourcesChain`: a chain that answer questions providing sources.
- Agents:
* `OpenAIFunctionsAgent`: an agent driven by OpenAIs Functions powered API.

## License

LangChain.dart is licensed under the
Expand Down
Loading

0 comments on commit d777ecc

Please sign in to comment.