-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cohere[minor]: Add cohere rerank (#4380)
* cohere[minor]: Add cohere rerank * cr * docs * chore: lint files * cr * Add sidebar * add rerank method * cr * chore: lint files
- Loading branch information
1 parent
ae11d0a
commit 4a8fe37
Showing
7 changed files
with
320 additions
and
0 deletions.
There are no files selected for viewing
35 changes: 35 additions & 0 deletions
35
docs/core_docs/docs/integrations/document_compressors/cohere_rerank.mdx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Cohere Rerank | ||
|
||
Reranking documents can greatly improve any RAG application and document retrieval system. | ||
|
||
At a high level, a rerank API is a language model which analyzes documents and reorders them based on their relevance to a given query. | ||
|
||
Cohere offers an API for reranking documents. In this example we'll show you how to use it. | ||
|
||
## Setup | ||
|
||
import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx"; | ||
|
||
<IntegrationInstallTooltip></IntegrationInstallTooltip> | ||
|
||
```bash npm2yarn | ||
npm install @langchain/cohere | ||
``` | ||
|
||
import CodeBlock from "@theme/CodeBlock"; | ||
|
||
import Example from "@examples/document_compressors/cohere_rerank.ts"; | ||
|
||
<CodeBlock language="typescript">{Example}</CodeBlock> | ||
|
||
Here, we can see the `.rerank()` method returns just the index of the documents (matching the indexes of the input documents) and their relevancy scores. | ||
|
||
If we'd like to have the documents returned from the method itself, we can use the `.compressDocuments()` method. | ||
|
||
import ExampleCompressor from "@examples/document_compressors/cohere_rerank_compressor.ts"; | ||
|
||
<CodeBlock language="typescript">{ExampleCompressor}</CodeBlock> | ||
|
||
From the results, we can see it returned the top 3 documents, and assigned a `relevanceScore` to each. | ||
|
||
As expected, the document with the highest `relevanceScore` is the one that references Washington, D.C., with a score of `98.7%`! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import { CohereRerank } from "@langchain/cohere"; | ||
import { Document } from "@langchain/core/documents"; | ||
|
||
const query = "What is the capital of the United States?"; | ||
const docs = [ | ||
new Document({ | ||
pageContent: | ||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.", | ||
}), | ||
]; | ||
|
||
const cohereRerank = new CohereRerank({ | ||
apiKey: process.env.COHERE_API_KEY, // Default | ||
model: "rerank-english-v2.0", // Default | ||
}); | ||
|
||
const rerankedDocuments = await cohereRerank.rerank(docs, query, { | ||
topN: 5, | ||
}); | ||
|
||
console.log(rerankedDocuments); | ||
/** | ||
[ | ||
{ index: 3, relevanceScore: 0.9871293 }, | ||
{ index: 1, relevanceScore: 0.29961726 }, | ||
{ index: 4, relevanceScore: 0.27542195 }, | ||
{ index: 0, relevanceScore: 0.08977329 }, | ||
{ index: 2, relevanceScore: 0.041462272 } | ||
] | ||
*/ |
52 changes: 52 additions & 0 deletions
52
examples/src/document_compressors/cohere_rerank_compressor.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import { CohereRerank } from "@langchain/cohere"; | ||
import { Document } from "@langchain/core/documents"; | ||
|
||
const query = "What is the capital of the United States?"; | ||
const docs = [ | ||
new Document({ | ||
pageContent: | ||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.", | ||
}), | ||
]; | ||
|
||
const cohereRerank = new CohereRerank({ | ||
apiKey: process.env.COHERE_API_KEY, // Default | ||
topN: 3, // Default | ||
model: "rerank-english-v2.0", // Default | ||
}); | ||
|
||
const rerankedDocuments = await cohereRerank.compressDocuments(docs, query); | ||
|
||
console.log(rerankedDocuments); | ||
/** | ||
[ | ||
Document { | ||
pageContent: 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.', | ||
metadata: { relevanceScore: 0.9871293 } | ||
}, | ||
Document { | ||
pageContent: 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.', | ||
metadata: { relevanceScore: 0.29961726 } | ||
}, | ||
Document { | ||
pageContent: 'Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.', | ||
metadata: { relevanceScore: 0.27542195 } | ||
} | ||
] | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
export * from "./chat_models.js"; | ||
export * from "./llms.js"; | ||
export * from "./embeddings.js"; | ||
export * from "./rerank.js"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import { DocumentInterface } from "@langchain/core/documents"; | ||
import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||
import { CohereClient } from "cohere-ai"; | ||
|
||
export interface CohereRerankArgs { | ||
/** | ||
* The API key to use. | ||
* @default {process.env.COHERE_API_KEY} | ||
*/ | ||
apiKey?: string; | ||
/** | ||
* The name of the model to use. | ||
* @default {"rerank-english-v2.0"} | ||
*/ | ||
model?: string; | ||
/** | ||
* How many documents to return. | ||
* @default {3} | ||
*/ | ||
topN?: number; | ||
/** | ||
* The maximum number of chunks per document. | ||
*/ | ||
maxChunksPerDoc?: number; | ||
} | ||
|
||
/** | ||
* Document compressor that uses `Cohere Rerank API`. | ||
*/ | ||
export class CohereRerank { | ||
model = "rerank-english-v2.0"; | ||
|
||
topN = 3; | ||
|
||
client: CohereClient; | ||
|
||
maxChunksPerDoc: number | undefined; | ||
|
||
constructor(fields?: CohereRerankArgs) { | ||
const token = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY"); | ||
if (!token) { | ||
throw new Error("No API key provided for CohereRerank."); | ||
} | ||
|
||
this.client = new CohereClient({ | ||
token, | ||
}); | ||
this.model = fields?.model ?? this.model; | ||
this.topN = fields?.topN ?? this.topN; | ||
this.maxChunksPerDoc = fields?.maxChunksPerDoc; | ||
} | ||
|
||
/** | ||
* Compress documents using Cohere's rerank API. | ||
* | ||
* @param {Array<DocumentInterface>} documents A sequence of documents to compress. | ||
* @param {string} query The query to use for compressing the documents. | ||
* | ||
* @returns {Promise<Array<DocumentInterface>>} A sequence of compressed documents. | ||
*/ | ||
async compressDocuments( | ||
documents: Array<DocumentInterface>, | ||
query: string | ||
): Promise<Array<DocumentInterface>> { | ||
const _docs = documents.map((doc) => doc.pageContent); | ||
const { results } = await this.client.rerank({ | ||
model: this.model, | ||
query, | ||
documents: _docs, | ||
topN: this.topN, | ||
maxChunksPerDoc: this.maxChunksPerDoc, | ||
}); | ||
const finalResults: Array<DocumentInterface> = []; | ||
for (let i = 0; i < results.length; i += 1) { | ||
const result = results[i]; | ||
const doc = documents[result.index]; | ||
doc.metadata.relevanceScore = result.relevanceScore; | ||
finalResults.push(doc); | ||
} | ||
return finalResults; | ||
} | ||
|
||
/** | ||
* Returns an ordered list of documents ordered by their relevance to the provided query. | ||
* | ||
* @param {Array<DocumentInterface | string | Record<string, string>>} documents A list of documents as strings, DocumentInterfaces or objects with a `pageContent` key. | ||
* @param {string} query The query to use for reranking the documents. | ||
* @param options | ||
* @param {string} options.model The name of the model to use. | ||
* @param {number} options.topN How many documents to return. | ||
* @param {number} options.maxChunksPerDoc The maximum number of chunks per document. | ||
* | ||
* @returns {Promise<Array<{ index: number; relevanceScore: number }>>} An ordered list of documents with relevance scores. | ||
*/ | ||
async rerank( | ||
documents: Array<DocumentInterface | string | Record<string, string>>, | ||
query: string, | ||
options?: { | ||
model?: string; | ||
topN?: number; | ||
maxChunksPerDoc?: number; | ||
} | ||
): Promise<Array<{ index: number; relevanceScore: number }>> { | ||
const docs = documents.map((doc) => { | ||
if (typeof doc === "string") { | ||
return doc; | ||
} | ||
return doc.pageContent; | ||
}); | ||
const model = options?.model ?? this.model; | ||
const topN = options?.topN ?? this.topN; | ||
const maxChunksPerDoc = options?.maxChunksPerDoc ?? this.maxChunksPerDoc; | ||
const { results } = await this.client.rerank({ | ||
model, | ||
query, | ||
documents: docs, | ||
topN, | ||
maxChunksPerDoc, | ||
}); | ||
|
||
const resultObjects = results.map((result) => ({ | ||
index: result.index, | ||
relevanceScore: result.relevanceScore, | ||
})); | ||
return resultObjects; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* eslint-disable no-process-env */ | ||
import { Document } from "@langchain/core/documents"; | ||
import { CohereRerank } from "../rerank.js"; | ||
|
||
const query = "What is the capital of France?"; | ||
|
||
const documents = [ | ||
new Document({ | ||
pageContent: "Paris is the capital of France.", | ||
}), | ||
new Document({ | ||
pageContent: "Build context-aware reasoning applications", | ||
}), | ||
new Document({ | ||
pageContent: | ||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274", | ||
}), | ||
]; | ||
|
||
test("CohereRerank can indeed rerank documents with compressDocuments method", async () => { | ||
const cohereRerank = new CohereRerank({ | ||
apiKey: process.env.COHERE_API_KEY, | ||
}); | ||
|
||
const rerankedDocuments = await cohereRerank.compressDocuments( | ||
documents, | ||
query | ||
); | ||
console.log(rerankedDocuments); | ||
expect(rerankedDocuments).toHaveLength(3); | ||
}); | ||
|
||
test("CohereRerank can indeed rerank documents with rerank method", async () => { | ||
const cohereRerank = new CohereRerank({ | ||
apiKey: process.env.COHERE_API_KEY, | ||
}); | ||
|
||
const rerankedDocuments = await cohereRerank.rerank( | ||
documents.map((doc) => doc.pageContent), | ||
query | ||
); | ||
console.log(rerankedDocuments); | ||
expect(rerankedDocuments).toHaveLength(3); | ||
}); |