From 4a8fe37eadadd5456436d0be0049b0560a4adece Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 12 Feb 2024 22:18:21 -0800 Subject: [PATCH] cohere[minor]: Add cohere rerank (#4380) * cohere[minor]: Add cohere rerank * cr * docs * chore: lint files * cr * Add sidebar * add rerank method * cr * chore: lint files --- .../document_compressors/cohere_rerank.mdx | 35 +++++ docs/core_docs/sidebars.js | 15 +++ .../src/document_compressors/cohere_rerank.ts | 46 +++++++ .../cohere_rerank_compressor.ts | 52 +++++++ libs/langchain-cohere/src/index.ts | 1 + libs/langchain-cohere/src/rerank.ts | 127 ++++++++++++++++++ .../src/tests/rerank.int.test.ts | 44 ++++++ 7 files changed, 320 insertions(+) create mode 100644 docs/core_docs/docs/integrations/document_compressors/cohere_rerank.mdx create mode 100644 examples/src/document_compressors/cohere_rerank.ts create mode 100644 examples/src/document_compressors/cohere_rerank_compressor.ts create mode 100644 libs/langchain-cohere/src/rerank.ts create mode 100644 libs/langchain-cohere/src/tests/rerank.int.test.ts diff --git a/docs/core_docs/docs/integrations/document_compressors/cohere_rerank.mdx b/docs/core_docs/docs/integrations/document_compressors/cohere_rerank.mdx new file mode 100644 index 000000000000..15bf2ea743dc --- /dev/null +++ b/docs/core_docs/docs/integrations/document_compressors/cohere_rerank.mdx @@ -0,0 +1,35 @@ +# Cohere Rerank + +Reranking documents can greatly improve any RAG application and document retrieval system. + +At a high level, a rerank API is a language model which analyzes documents and reorders them based on their relevance to a given query. + +Cohere offers an API for reranking documents. In this example we'll show you how to use it. 
+ +## Setup + +import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx"; + +<IntegrationInstallTooltip></IntegrationInstallTooltip> + +```bash npm2yarn +npm install @langchain/cohere +``` + +import CodeBlock from "@theme/CodeBlock"; + +import Example from "@examples/document_compressors/cohere_rerank.ts"; + +<CodeBlock language="typescript">{Example}</CodeBlock> + +Here, we can see the `.rerank()` method returns just the index of the documents (matching the indexes of the input documents) and their relevancy scores. + +If we'd like to have the documents returned from the method itself, we can use the `.compressDocuments()` method. + +import ExampleCompressor from "@examples/document_compressors/cohere_rerank_compressor.ts"; + +<CodeBlock language="typescript">{ExampleCompressor}</CodeBlock> + +From the results, we can see it returned the top 3 documents, and assigned a `relevanceScore` to each. + +As expected, the document with the highest `relevanceScore` is the one that references Washington, D.C., with a score of `98.7%`! diff --git a/docs/core_docs/sidebars.js b/docs/core_docs/sidebars.js index e62aa12d5dfd..01c76b7ebde3 100644 --- a/docs/core_docs/sidebars.js +++ b/docs/core_docs/sidebars.js @@ -254,6 +254,21 @@ module.exports = { slug: "integrations/document_transformers", }, }, + { + type: "category", + label: "Document compressors", + collapsed: true, + items: [ + { + type: "autogenerated", + dirName: "integrations/document_compressors", + }, + ], + link: { + type: "generated-index", + slug: "integrations/document_compressors", + }, + }, { type: "category", label: "Text embedding models", diff --git a/examples/src/document_compressors/cohere_rerank.ts b/examples/src/document_compressors/cohere_rerank.ts new file mode 100644 index 000000000000..45443b728b57 --- /dev/null +++ b/examples/src/document_compressors/cohere_rerank.ts @@ -0,0 +1,46 @@ +import { CohereRerank } from "@langchain/cohere"; +import { Document } from "@langchain/core/documents"; + +const query = "What is the capital of the United States?"; +const docs = [ + new Document({ + pageContent: + "Carson City
is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.", + }), + new Document({ + pageContent: + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.", + }), + new Document({ + pageContent: + "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.", + }), + new Document({ + pageContent: + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.", + }), + new Document({ + pageContent: + "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. 
The federal government (including the United States military) also uses capital punishment.", + }), +]; + +const cohereRerank = new CohereRerank({ + apiKey: process.env.COHERE_API_KEY, // Default + model: "rerank-english-v2.0", // Default +}); + +const rerankedDocuments = await cohereRerank.rerank(docs, query, { + topN: 5, +}); + +console.log(rerankedDocuments); +/** +[ + { index: 3, relevanceScore: 0.9871293 }, + { index: 1, relevanceScore: 0.29961726 }, + { index: 4, relevanceScore: 0.27542195 }, + { index: 0, relevanceScore: 0.08977329 }, + { index: 2, relevanceScore: 0.041462272 } +] + */ diff --git a/examples/src/document_compressors/cohere_rerank_compressor.ts b/examples/src/document_compressors/cohere_rerank_compressor.ts new file mode 100644 index 000000000000..941e28f32e91 --- /dev/null +++ b/examples/src/document_compressors/cohere_rerank_compressor.ts @@ -0,0 +1,52 @@ +import { CohereRerank } from "@langchain/cohere"; +import { Document } from "@langchain/core/documents"; + +const query = "What is the capital of the United States?"; +const docs = [ + new Document({ + pageContent: + "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.", + }), + new Document({ + pageContent: + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.", + }), + new Document({ + pageContent: + "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.", + }), + new Document({ + pageContent: + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. 
This makes it the political center of the United States of America.", + }), + new Document({ + pageContent: + "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.", + }), +]; + +const cohereRerank = new CohereRerank({ + apiKey: process.env.COHERE_API_KEY, // Default + topN: 3, // Default + model: "rerank-english-v2.0", // Default +}); + +const rerankedDocuments = await cohereRerank.compressDocuments(docs, query); + +console.log(rerankedDocuments); +/** +[ + Document { + pageContent: 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.', + metadata: { relevanceScore: 0.9871293 } + }, + Document { + pageContent: 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.', + metadata: { relevanceScore: 0.29961726 } + }, + Document { + pageContent: 'Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. 
The federal government (including the United States military) also uses capital punishment.',
+    metadata: { relevanceScore: 0.27542195 }
+  }
+]
+ */
diff --git a/libs/langchain-cohere/src/index.ts b/libs/langchain-cohere/src/index.ts
index 7f420a4ed6d0..f453c82911ff 100644
--- a/libs/langchain-cohere/src/index.ts
+++ b/libs/langchain-cohere/src/index.ts
@@ -1,3 +1,4 @@
 export * from "./chat_models.js";
 export * from "./llms.js";
 export * from "./embeddings.js";
+export * from "./rerank.js";
diff --git a/libs/langchain-cohere/src/rerank.ts b/libs/langchain-cohere/src/rerank.ts
new file mode 100644
index 000000000000..4fbb86a3b8bd
--- /dev/null
+++ b/libs/langchain-cohere/src/rerank.ts
@@ -0,0 +1,145 @@
+import { DocumentInterface } from "@langchain/core/documents";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { CohereClient } from "cohere-ai";
+
+export interface CohereRerankArgs {
+  /**
+   * The API key to use.
+   * @default {process.env.COHERE_API_KEY}
+   */
+  apiKey?: string;
+  /**
+   * The name of the model to use.
+   * @default {"rerank-english-v2.0"}
+   */
+  model?: string;
+  /**
+   * How many documents to return.
+   * @default {3}
+   */
+  topN?: number;
+  /**
+   * The maximum number of chunks per document.
+   */
+  maxChunksPerDoc?: number;
+}
+
+/**
+ * Document compressor that uses `Cohere Rerank API`.
+ */
+export class CohereRerank {
+  model = "rerank-english-v2.0";
+
+  topN = 3;
+
+  client: CohereClient;
+
+  maxChunksPerDoc: number | undefined;
+
+  /**
+   * @param {CohereRerankArgs} fields Optional settings. When `apiKey` is
+   * omitted, the `COHERE_API_KEY` environment variable is used instead.
+   *
+   * @throws {Error} If no API key is given and none is found in the environment.
+   */
+  constructor(fields?: CohereRerankArgs) {
+    const token = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY");
+    if (!token) {
+      throw new Error("No API key provided for CohereRerank.");
+    }
+
+    this.client = new CohereClient({
+      token,
+    });
+    this.model = fields?.model ?? this.model;
+    this.topN = fields?.topN ?? this.topN;
+    this.maxChunksPerDoc = fields?.maxChunksPerDoc;
+  }
+
+  /**
+   * Compress documents using Cohere's rerank API.
+   *
+   * The input documents are not modified: each returned document is a
+   * shallow copy of the matching input whose metadata is a fresh object
+   * that additionally holds `relevanceScore`.
+   *
+   * @param {Array<DocumentInterface>} documents A sequence of documents to compress.
+   * @param {string} query The query to use for compressing the documents.
+   *
+   * @returns {Promise<Array<DocumentInterface>>} A sequence of compressed documents.
+   */
+  async compressDocuments(
+    documents: Array<DocumentInterface>,
+    query: string
+  ): Promise<Array<DocumentInterface>> {
+    // Nothing to rank, so skip the API call entirely.
+    if (documents.length === 0) {
+      return [];
+    }
+    const { results } = await this.client.rerank({
+      model: this.model,
+      query,
+      documents: documents.map((doc) => doc.pageContent),
+      topN: this.topN,
+      maxChunksPerDoc: this.maxChunksPerDoc,
+    });
+    // `result.index` points into `documents`. Each hit is shallow-cloned
+    // (keeping its prototype, e.g. `Document`) and given its own metadata
+    // object, so the caller's documents are never mutated.
+    return results.map((result) => {
+      const doc = documents[result.index];
+      return Object.assign(Object.create(Object.getPrototypeOf(doc)), doc, {
+        metadata: { ...doc.metadata, relevanceScore: result.relevanceScore },
+      });
+    });
+  }
+
+  /**
+   * Returns an ordered list of documents ordered by their relevance to the provided query.
+   *
+   * @param {Array<string | DocumentInterface | Record<string, string>>} documents A list of documents as strings, DocumentInterfaces or objects with a `pageContent` key.
+   * @param {string} query The query to use for reranking the documents.
+   * @param options
+   * @param {string} options.model The name of the model to use.
+   * @param {number} options.topN How many documents to return.
+   * @param {number} options.maxChunksPerDoc The maximum number of chunks per document.
+   *
+   * @returns {Promise<Array<{ index: number; relevanceScore: number }>>} An ordered list of documents with relevance scores.
+   */
+  async rerank(
+    documents: Array<string | DocumentInterface | Record<string, string>>,
+    query: string,
+    options?: {
+      model?: string;
+      topN?: number;
+      maxChunksPerDoc?: number;
+    }
+  ): Promise<Array<{ index: number; relevanceScore: number }>> {
+    // Nothing to rank, so skip the API call entirely.
+    if (documents.length === 0) {
+      return [];
+    }
+    const docs = documents.map((doc) => {
+      if (typeof doc === "string") {
+        return doc;
+      }
+      return doc.pageContent;
+    });
+    const model = options?.model ?? this.model;
+    const topN = options?.topN ?? this.topN;
+    const maxChunksPerDoc = options?.maxChunksPerDoc ?? this.maxChunksPerDoc;
+    const { results } = await this.client.rerank({
+      model,
+      query,
+      documents: docs,
+      topN,
+      maxChunksPerDoc,
+    });
+
+    const resultObjects = results.map((result) => ({
+      index: result.index,
+      relevanceScore: result.relevanceScore,
+    }));
+    return resultObjects;
+  }
+}
diff --git a/libs/langchain-cohere/src/tests/rerank.int.test.ts b/libs/langchain-cohere/src/tests/rerank.int.test.ts
new file mode 100644
index 000000000000..93270f36e373
--- /dev/null
+++ b/libs/langchain-cohere/src/tests/rerank.int.test.ts
@@ -0,0 +1,44 @@
+/* eslint-disable no-process-env */
+import { Document } from "@langchain/core/documents";
+import { CohereRerank } from "../rerank.js";
+
+const query = "What is the capital of France?";
+
+const documents = [
+  new Document({
+    pageContent: "Paris is the capital of France.",
+  }),
+  new Document({
+    pageContent: "Build context-aware reasoning applications",
+  }),
+  new Document({
+    pageContent:
+      "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274",
+  }),
+];
+
+test("CohereRerank can indeed rerank documents with compressDocuments method", async () => {
+  const cohereRerank = new CohereRerank({
+    apiKey: process.env.COHERE_API_KEY,
+  });
+
+  const rerankedDocuments = await cohereRerank.compressDocuments(
+    documents,
+    query
+  );
+  console.log(rerankedDocuments);
+  expect(rerankedDocuments).toHaveLength(3);
+});
+
+test("CohereRerank can indeed rerank documents with rerank method", async () => {
+  const cohereRerank = new CohereRerank({
+    apiKey: process.env.COHERE_API_KEY,
+  });
+
+  const rerankedDocuments = await cohereRerank.rerank(
+    documents.map((doc) => doc.pageContent),
+    query
+  );
+  console.log(rerankedDocuments);
+  expect(rerankedDocuments).toHaveLength(3);
+});