From d777eccc0c81c58b322f28e6e3c4a8763f3f84b7 Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Sun, 13 Aug 2023 11:37:35 +0200 Subject: [PATCH] feat(embeddings): Integrate Google Vertex AI PaLM Embeddings (#100) --- README.md | 9 +- packages/langchain_google/README.md | 20 ++- .../lib/langchain_google.dart | 1 + .../lib/src/chat_models/vertex_ai.dart | 10 +- .../lib/src/embeddings/embeddings.dart | 1 + .../lib/src/embeddings/vertex_ai.dart | 142 ++++++++++++++++++ .../lib/src/llms/vertex_ai.dart | 2 +- .../test/chat_models/vertex_ai_test.dart | 16 +- .../test/embeddings/vertex_ai_test.dart | 35 +++++ .../test/llms/vertex_ai_test.dart | 16 +- .../langchain_google/test/utils/auth.dart | 15 ++ packages/langchain_openai/README.md | 23 ++- .../lib/src/chains/qa_with_sources.dart | 23 +++ packages/vertex_ai/README.md | 5 + 14 files changed, 280 insertions(+), 38 deletions(-) create mode 100644 packages/langchain_google/lib/src/embeddings/embeddings.dart create mode 100644 packages/langchain_google/lib/src/embeddings/vertex_ai.dart create mode 100644 packages/langchain_google/test/embeddings/vertex_ai_test.dart create mode 100644 packages/langchain_google/test/utils/auth.dart diff --git a/README.md b/README.md index 8ef27b36..8b6cdec1 100644 --- a/README.md +++ b/README.md @@ -58,10 +58,11 @@ LangChain.dart has a modular design where the core [langchain](https://pub.dev/p package provides the LangChain API and each integration with a model provider, data store, etc. is provided by a separate package. -| Package | Version | Description | Models | Data conn. 
| Memory | Agents & Tools | -|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------|--------------------|--------|------------|--------|----------------| -| [langchain](https://pub.dev/packages/langchain) | [![langchain](https://img.shields.io/pub/v/langchain.svg)](https://pub.dev/packages/langchain) | Core LangChain API | ★ | ★ | ★ | ★ | -| [langchain_openai](https://pub.dev/packages/langchain_openai) | [![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai) | OpenAI integration | ✔ | ✔ | | ✔ | +| Package | Version | Description | Models | Data conn. | Chains | Agents & Tools | +|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------|-------------------------------------|--------|------------|--------|----------------| +| [langchain](https://pub.dev/packages/langchain) | [![langchain](https://img.shields.io/pub/v/langchain.svg)](https://pub.dev/packages/langchain) | Core LangChain API | ★ | ★ | ★ | ★ | +| [langchain_openai](https://pub.dev/packages/langchain_openai) | [![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai) | OpenAI integration | ✔ | ✔ | ✔ | ✔ | +| [langchain_google](https://pub.dev/packages/langchain_google) | [![langchain_google](https://img.shields.io/pub/v/langchain_google.svg)](https://pub.dev/packages/langchain_google) | Google integration (VertexAI, PaLM) | ✔ | ✔ | | | ## Getting started diff --git a/packages/langchain_google/README.md b/packages/langchain_google/README.md index 714096b5..d438060d 100644 --- a/packages/langchain_google/README.md +++ b/packages/langchain_google/README.md @@ -1,7 +1,25 @@ -# 🦜️🔗 LangChain.dart +# 🦜️🔗 LangChain.dart / Google + 
+[![tests](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/test.yaml?logo=github&label=tests)](https://github.com/davidmigloz/langchain_dart/actions/workflows/test.yaml) +[![docs](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/pages%2Fpages-build-deployment?logo=github&label=docs)](https://github.com/davidmigloz/langchain_dart/actions/workflows/pages/pages-build-deployment) +[![langchain_google](https://img.shields.io/pub/v/langchain_google.svg)](https://pub.dev/packages/langchain_google) +[![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR) +[![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE) Google module for [LangChain.dart](https://github.com/davidmigloz/langchain_dart). +## Features + +- LLMs: + * `VertexAI`: wrapper around GCP Vertex AI text models API (aka PaLM API for + text). +- Chat models: + * `ChatVertexAI`: wrapper around GCP Vertex AI text chat models API (aka PaLM + API for chat). +- Embeddings: + * `VertexAIEmbeddings`: wrapper around GCP Vertex AI text embedding models + API. 
+ ## License LangChain.dart is licensed under the diff --git a/packages/langchain_google/lib/langchain_google.dart b/packages/langchain_google/lib/langchain_google.dart index 8db9cbc0..6b260725 100644 --- a/packages/langchain_google/lib/langchain_google.dart +++ b/packages/langchain_google/lib/langchain_google.dart @@ -3,4 +3,5 @@ library; export 'src/chat_models/chat_models.dart'; export 'src/doc_loaders/doc_loaders.dart'; +export 'src/embeddings/embeddings.dart'; export 'src/llms/llms.dart'; diff --git a/packages/langchain_google/lib/src/chat_models/vertex_ai.dart b/packages/langchain_google/lib/src/chat_models/vertex_ai.dart index 1da56dbb..a5b21833 100644 --- a/packages/langchain_google/lib/src/chat_models/vertex_ai.dart +++ b/packages/langchain_google/lib/src/chat_models/vertex_ai.dart @@ -7,7 +7,7 @@ import 'models/mappers.dart'; import 'models/models.dart'; /// {@template chat_vertex_ai} -/// Wrapper around GCP Vertex AI text chat models API (aka PaLM API). +/// Wrapper around GCP Vertex AI text chat models API (aka PaLM API for chat). /// /// Example: /// ```dart @@ -30,10 +30,10 @@ import 'models/models.dart'; /// /// ### Authentication /// -/// The `VertexAI` wrapper delegates authentication to the +/// The `ChatVertexAI` wrapper delegates authentication to the /// [googleapis_auth](https://pub.dev/packages/googleapis_auth) package. /// -/// To create an instance of `VertexAI` you need to provide an +/// To create an instance of `ChatVertexAI` you need to provide an /// [`AuthClient`](https://pub.dev/documentation/googleapis_auth/latest/googleapis_auth/AuthClient-class.html) /// instance. /// @@ -99,9 +99,9 @@ class ChatVertexAI extends BaseChatModel { /// The text model to use. /// /// To use the latest model version, specify the model name without a version - /// number (e.g. `text-bison`). + /// number (e.g. `chat-bison`). /// To use a stable model version, specify the model version number - /// (e.g. `text-bison@001`). + /// (e.g. `chat-bison@001`). 
/// /// You can find a list of available models here: /// https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models diff --git a/packages/langchain_google/lib/src/embeddings/embeddings.dart b/packages/langchain_google/lib/src/embeddings/embeddings.dart new file mode 100644 index 00000000..792ff901 --- /dev/null +++ b/packages/langchain_google/lib/src/embeddings/embeddings.dart @@ -0,0 +1 @@ +export 'vertex_ai.dart'; diff --git a/packages/langchain_google/lib/src/embeddings/vertex_ai.dart b/packages/langchain_google/lib/src/embeddings/vertex_ai.dart new file mode 100644 index 00000000..6c384ba1 --- /dev/null +++ b/packages/langchain_google/lib/src/embeddings/vertex_ai.dart @@ -0,0 +1,142 @@ +import 'package:googleapis_auth/googleapis_auth.dart'; +import 'package:langchain/langchain.dart'; +import 'package:vertex_ai/vertex_ai.dart'; + +/// {@template vertex_ai_embeddings} +/// Wrapper around GCP Vertex AI text embedding models API +/// +/// Example: +/// ```dart +/// final embeddings = VertexAIEmbeddings( +/// authHttpClient: authClient, +/// project: 'your-project-id', +/// ); +/// final result = await embeddings.embedQuery('Hello world'); +/// ``` +/// +/// Vertex AI documentation: +/// https://cloud.google.com/vertex-ai/docs/generative-ai/language-model-overview +/// +/// ### Set up your Google Cloud Platform project +/// +/// 1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). +/// 2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project). +/// 3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). +/// 4. [Configure the Vertex AI location](https://cloud.google.com/vertex-ai/docs/general/locations). +/// +/// ### Authentication +/// +/// The `VertexAIEmbeddings` wrapper delegates authentication to the +/// [googleapis_auth](https://pub.dev/packages/googleapis_auth) package. 
+/// +/// To create an instance of `VertexAIEmbeddings` you need to provide an +/// [`AuthClient`](https://pub.dev/documentation/googleapis_auth/latest/googleapis_auth/AuthClient-class.html) +/// instance. +/// +/// There are several ways to obtain an `AuthClient` depending on your use case. +/// Check out the [googleapis_auth](https://pub.dev/packages/googleapis_auth) +/// package documentation for more details. +/// +/// Example using a service account JSON: +/// +/// ```dart +/// final serviceAccountCredentials = ServiceAccountCredentials.fromJson( +/// json.decode(serviceAccountJson), +/// ); +/// final authClient = await clientViaServiceAccount( +/// serviceAccountCredentials, +/// [VertexAIEmbeddings.cloudPlatformScope], +/// ); +/// final vertexAi = VertexAIEmbeddings( +/// authHttpClient: authClient, +/// project: 'your-project-id', +/// ); +/// ``` +/// +/// The service account should have the following +/// [permission](https://cloud.google.com/vertex-ai/docs/general/iam-permissions): +/// - `aiplatform.endpoints.predict` +/// +/// The required [OAuth2 scope](https://developers.google.com/identity/protocols/oauth2/scopes) +/// is: +/// - `https://www.googleapis.com/auth/cloud-platform` (you can use the +/// constant `VertexAIEmbeddings.cloudPlatformScope`) +/// +/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control +/// {@endtemplate} +class VertexAIEmbeddings implements Embeddings { + /// {@macro vertex_ai_embeddings} + VertexAIEmbeddings({ + required final AuthClient authHttpClient, + required final String project, + final String location = 'us-central1', + final String rootUrl = 'https://us-central1-aiplatform.googleapis.com/', + this.publisher = 'google', + this.model = 'textembedding-gecko', + this.batchSize = 5, + }) : client = VertexAIGenAIClient( + authHttpClient: authHttpClient, + project: project, + location: location, + rootUrl: rootUrl, + ); + + /// A client for interacting with Vertex AI API. 
+ final VertexAIGenAIClient client; + + /// The publisher of the model. + /// + /// Use `google` for first-party models. + final String publisher; + + /// The text model to use. + /// + /// To use the latest model version, specify the model name without a version + /// number (e.g. `textembedding-gecko`). + /// To use a stable model version, specify the model version number + /// (e.g. `textembedding-gecko@001`). + /// + /// You can find a list of available models here: + /// https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models + final String model; + + /// The maximum number of documents to embed in a single request. + /// + /// `textembedding-gecko` has a limit of up to 5 input texts per request. + final int batchSize; + + /// Scope required for Vertex AI API calls. + static const cloudPlatformScope = VertexAIGenAIClient.cloudPlatformScope; + + @override + Future<List<List<double>>> embedDocuments( + final List<String> documents, + ) async { + final subDocs = chunkArray(documents, chunkSize: batchSize); + + final embeddings = await Future.wait( + subDocs.map((final docsBatch) async { + final data = await client.textEmbeddings.predict( + content: docsBatch, + publisher: publisher, + model: model, + ); + return data.predictions + .map((final p) => p.values) + .toList(growable: false); + }), + ); + + return embeddings.expand((final e) => e).toList(growable: false); + } + + @override + Future<List<double>> embedQuery(final String query) async { + final data = await client.textEmbeddings.predict( + content: [query], + publisher: publisher, + model: model, + ); + return data.predictions.first.values; + } +} diff --git a/packages/langchain_google/lib/src/llms/vertex_ai.dart b/packages/langchain_google/lib/src/llms/vertex_ai.dart index 2167ac16..8b1bca27 100644 --- a/packages/langchain_google/lib/src/llms/vertex_ai.dart +++ b/packages/langchain_google/lib/src/llms/vertex_ai.dart @@ -7,7 +7,7 @@ import 'models/mappers.dart'; import 'models/models.dart'; /// {@template vertex_ai} -/// Wrapper 
around GCP Vertex AI text models API (aka PaLM API). +/// Wrapper around GCP Vertex AI text models API (aka PaLM API for text). /// /// Example: /// ```dart diff --git a/packages/langchain_google/test/chat_models/vertex_ai_test.dart b/packages/langchain_google/test/chat_models/vertex_ai_test.dart index 9dd06c86..d7a9f228 100644 --- a/packages/langchain_google/test/chat_models/vertex_ai_test.dart +++ b/packages/langchain_google/test/chat_models/vertex_ai_test.dart @@ -2,16 +2,16 @@ @TestOn('vm') library; // Uses dart:io -import 'dart:convert'; import 'dart:io'; -import 'package:googleapis_auth/auth_io.dart'; import 'package:langchain/langchain.dart'; import 'package:langchain_google/langchain_google.dart'; import 'package:test/test.dart'; +import '../utils/auth.dart'; + void main() async { - final authHttpClient = await _getAuthHttpClient(); + final authHttpClient = await getAuthHttpClient(); group('ChatVertexAI tests', () { test('Test ChatVertexAI parameters', () async { final llm = ChatVertexAI( @@ -140,13 +140,3 @@ void main() async { }); }); } - -Future _getAuthHttpClient() async { - final serviceAccountCredentials = ServiceAccountCredentials.fromJson( - json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!), - ); - return clientViaServiceAccount( - serviceAccountCredentials, - [ChatVertexAI.cloudPlatformScope], - ); -} diff --git a/packages/langchain_google/test/embeddings/vertex_ai_test.dart b/packages/langchain_google/test/embeddings/vertex_ai_test.dart new file mode 100644 index 00000000..c9bee55b --- /dev/null +++ b/packages/langchain_google/test/embeddings/vertex_ai_test.dart @@ -0,0 +1,35 @@ +@TestOn('vm') +library; // Uses dart:io + +import 'dart:io'; + +import 'package:langchain_google/langchain_google.dart'; +import 'package:test/test.dart'; + +import '../utils/auth.dart'; + +void main() async { + final authHttpClient = await getAuthHttpClient(); + group('VertexAIEmbeddings tests', () { + test('Test VertexAIEmbeddings.embedQuery', () async { + 
final embeddings = VertexAIEmbeddings( + authHttpClient: authHttpClient, + project: Platform.environment['VERTEX_AI_PROJECT_ID']!, + ); + final res = await embeddings.embedQuery('Hello world'); + expect(res.length, 768); + }); + + test('Test VertexAIEmbeddings.embedDocuments', () async { + final embeddings = VertexAIEmbeddings( + authHttpClient: authHttpClient, + project: Platform.environment['VERTEX_AI_PROJECT_ID']!, + batchSize: 1, + ); + final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']); + expect(res.length, 2); + expect(res[0].length, 768); + expect(res[1].length, 768); + }); + }); +} diff --git a/packages/langchain_google/test/llms/vertex_ai_test.dart b/packages/langchain_google/test/llms/vertex_ai_test.dart index 2f595e9b..1991ca5f 100644 --- a/packages/langchain_google/test/llms/vertex_ai_test.dart +++ b/packages/langchain_google/test/llms/vertex_ai_test.dart @@ -2,16 +2,16 @@ @TestOn('vm') library; // Uses dart:io -import 'dart:convert'; import 'dart:io'; -import 'package:googleapis_auth/auth_io.dart'; import 'package:langchain/langchain.dart'; import 'package:langchain_google/langchain_google.dart'; import 'package:test/test.dart'; +import '../utils/auth.dart'; + Future main() async { - final authHttpClient = await _getAuthHttpClient(); + final authHttpClient = await getAuthHttpClient(); group('VertexAI tests', () { test('Test VertexAI parameters', () async { final llm = VertexAI( @@ -93,13 +93,3 @@ Future main() async { }); }); } - -Future _getAuthHttpClient() async { - final serviceAccountCredentials = ServiceAccountCredentials.fromJson( - json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!), - ); - return clientViaServiceAccount( - serviceAccountCredentials, - [VertexAI.cloudPlatformScope], - ); -} diff --git a/packages/langchain_google/test/utils/auth.dart b/packages/langchain_google/test/utils/auth.dart new file mode 100644 index 00000000..45349e1a --- /dev/null +++ b/packages/langchain_google/test/utils/auth.dart @@ 
-0,0 +1,15 @@ +import 'dart:convert'; +import 'dart:io'; + +import 'package:googleapis_auth/auth_io.dart'; +import 'package:langchain_google/langchain_google.dart'; + +Future getAuthHttpClient() async { + final serviceAccountCredentials = ServiceAccountCredentials.fromJson( + json.decode(Platform.environment['VERTEX_AI_SERVICE_ACCOUNT']!), + ); + return clientViaServiceAccount( + serviceAccountCredentials, + [VertexAI.cloudPlatformScope], + ); +} diff --git a/packages/langchain_openai/README.md b/packages/langchain_openai/README.md index da9cba0d..5dea5a5b 100644 --- a/packages/langchain_openai/README.md +++ b/packages/langchain_openai/README.md @@ -1,7 +1,28 @@ -# 🦜️🔗 LangChain.dart +# 🦜️🔗 LangChain.dart / OpenAI + +[![tests](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/test.yaml?logo=github&label=tests)](https://github.com/davidmigloz/langchain_dart/actions/workflows/test.yaml) +[![docs](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/pages%2Fpages-build-deployment?logo=github&label=docs)](https://github.com/davidmigloz/langchain_dart/actions/workflows/pages/pages-build-deployment) +[![langchain_openai](https://img.shields.io/pub/v/langchain_openai.svg)](https://pub.dev/packages/langchain_openai) +[![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR) +[![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE) OpenAI module for [LangChain.dart](https://github.com/davidmigloz/langchain_dart). +## Features + +- LLMs: + * `OpenAI`: wrapper around OpenAI Completions API. +- Chat models: + * `ChatOpenAI`: wrapper around OpenAI Chat API. +- Embeddings: + * `OpenAIEmbeddings`: wrapper around OpenAI Embeddings API. +- Chains: + * `OpenAIQAWithStructureChain` a chain that answer questions in the specified + structure. 
+ * `OpenAIQAWithSourcesChain`: a chain that answers questions providing sources. +- Agents: + * `OpenAIFunctionsAgent`: an agent driven by OpenAI's Functions powered API. + ## License LangChain.dart is licensed under the diff --git a/packages/langchain_openai/lib/src/chains/qa_with_sources.dart b/packages/langchain_openai/lib/src/chains/qa_with_sources.dart index 9a63fde1..ce814fce 100644 --- a/packages/langchain_openai/lib/src/chains/qa_with_sources.dart +++ b/packages/langchain_openai/lib/src/chains/qa_with_sources.dart @@ -5,6 +5,29 @@ import 'qa_with_structure.dart'; /// {@template openai_qa_with_sources_chain} /// A chain that answers questions returning a [QAWithSources] object /// containing the answers with the sources used to answer the question. +/// +/// Example: +/// ```dart +/// final llm = ChatOpenAI( +/// apiKey: openaiApiKey, +/// model: 'gpt-3.5-turbo-0613', +/// temperature: 0, +/// ); +/// final qaChain = OpenAIQAWithSourcesChain(llm: llm); +/// final docPrompt = PromptTemplate.fromTemplate( +/// 'Content: {page_content}\nSource: {source}', +/// ); +/// final finalQAChain = StuffDocumentsChain( +/// llmChain: qaChain, +/// documentPrompt: docPrompt, +/// ); +/// final retrievalQA = RetrievalQAChain( +/// retriever: vectorStore.asRetriever(), +/// combineDocumentsChain: finalQAChain, +/// ); +/// const query = 'What did President Biden say about Russia?'; +/// final res = await retrievalQA(query); +/// ``` /// {@endtemplate} class OpenAIQAWithSourcesChain extends OpenAIQAWithStructureChain { OpenAIQAWithSourcesChain({ diff --git a/packages/vertex_ai/README.md b/packages/vertex_ai/README.md index 773890f9..118a0a04 100644 --- a/packages/vertex_ai/README.md +++ b/packages/vertex_ai/README.md @@ -1,5 +1,10 @@ # Vertex AI API Client +[![tests](https://img.shields.io/github/actions/workflow/status/davidmigloz/langchain_dart/test.yaml?logo=github&label=tests)](https://github.com/davidmigloz/langchain_dart/actions/workflows/test.yaml) 
+[![vertex_ai](https://img.shields.io/pub/v/vertex_ai.svg)](https://pub.dev/packages/vertex_ai) +[![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR) +[![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE) + Dart client for the [Vertex AI](https://cloud.google.com/vertex-ai) API. ## Features