cohere[minor]: Add cohere rerank #4380

Merged
merged 9 commits into from
Feb 13, 2024
Changes from 5 commits
@@ -0,0 +1,26 @@
# Cohere Rerank

Reranking documents can significantly improve the relevance of results in RAG applications and document retrieval systems.

At a high level, a rerank API uses a language model to analyze a set of documents and reorder them by their relevance to a given query.

Cohere offers an API for reranking documents. In this example we'll show you how to use it.
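To make the idea of "reorder by relevance" concrete, here is a toy sketch. It is not how Cohere scores documents: a real rerank API relies on a language model, while this stand-in uses simple word overlap. The names `overlapScore` and `toyRerank` are hypothetical and exist only for illustration.

```typescript
// Toy illustration only: a real rerank API scores relevance with a language
// model, not word overlap. `overlapScore` and `toyRerank` are hypothetical.
type Scored = { text: string; score: number };

function tokenize(text: string): Set<string> {
  return new Set<string>(text.toLowerCase().match(/[a-z0-9]+/g) ?? []);
}

// Fraction of query words that also appear in the document.
function overlapScore(query: string, doc: string): number {
  const queryWords = Array.from(tokenize(query));
  const docWords = tokenize(doc);
  if (queryWords.length === 0) return 0;
  const hits = queryWords.filter((word) => docWords.has(word)).length;
  return hits / queryWords.length;
}

// Score every document, sort by descending score, and keep the top N.
function toyRerank(query: string, docs: string[], topN: number): Scored[] {
  return docs
    .map((text) => ({ text, score: overlapScore(query, text) }))
    .sort((a, b) => b.score - a.score)
    .slice(0, topN);
}

const ranked = toyRerank(
  "capital of France",
  [
    "Carson City is the capital of Nevada.",
    "Paris is the capital of France.",
    "Bananas are yellow.",
  ],
  2
);
console.log(ranked.map((r) => r.text));
```

The shape is the same as the real thing: each document gets a score, the list is sorted by that score, and only the top `topN` entries are returned.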

## Setup

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";

<IntegrationInstallTooltip></IntegrationInstallTooltip>

```bash npm2yarn
npm install @langchain/cohere
```

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/document_compressors/cohere_rerank.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>

From the results, we can see it returned the top 3 documents, and assigned a `relevanceScore` to each.

As expected, the document with the highest `relevanceScore` is the one that references Washington, D.C., with a score of `98.7%`!
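Because the score is stored in each document's metadata, downstream code can filter or format on it. The sketch below uses plain objects in place of LangChain `Document` instances so it runs without an API key. The `keepRelevant` helper and its `minScore` cutoff are hypothetical, application-level choices and not Cohere settings.

```typescript
// Minimal stand-in for reranked documents; the real objects are LangChain
// `Document` instances with the score under `metadata.relevanceScore`.
type RerankedDoc = {
  pageContent: string;
  metadata: { relevanceScore?: number };
};

// `minScore` is a hypothetical cutoff chosen by the application.
function keepRelevant(docs: RerankedDoc[], minScore: number): RerankedDoc[] {
  return docs.filter((doc) => (doc.metadata.relevanceScore ?? 0) >= minScore);
}

// Scores copied from the example output; page contents are shortened.
const sample: RerankedDoc[] = [
  { pageContent: "Washington, D.C.", metadata: { relevanceScore: 0.9871293 } },
  { pageContent: "Northern Mariana Islands", metadata: { relevanceScore: 0.29961726 } },
  { pageContent: "Capital punishment", metadata: { relevanceScore: 0.27542195 } },
];

const confident = keepRelevant(sample, 0.5);
const topScore = confident[0].metadata.relevanceScore ?? 0;
const topScorePercent = (topScore * 100).toFixed(1) + "%";
console.log(confident.length, topScorePercent);
```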
52 changes: 52 additions & 0 deletions examples/src/document_compressors/cohere_rerank.ts
@@ -0,0 +1,52 @@
import { CohereRerank } from "@langchain/cohere";
Great work on the PR! I've flagged a specific change for the maintainers to review, as it accesses the COHERE_API_KEY environment variable using process.env. Keep up the good work!

import { Document } from "@langchain/core/documents";

const query = "What is the capital of the United States?";
const docs = [
  new Document({
    pageContent:
      "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.",
  }),
  new Document({
    pageContent:
      "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.",
  }),
  new Document({
    pageContent:
      "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.",
  }),
  new Document({
    pageContent:
      "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.",
  }),
  new Document({
    pageContent:
      "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.",
  }),
];

const cohereRerank = new CohereRerank({
  apiKey: process.env.COHERE_API_KEY, // Default
  topN: 3, // Default
  model: "rerank-english-v2.0", // Default
});

const rerankedDocuments = await cohereRerank.compressDocuments(docs, query);

console.log(rerankedDocuments);
/**
[
  Document {
    pageContent: 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.',
    metadata: { relevanceScore: 0.9871293 }
  },
  Document {
    pageContent: 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.',
    metadata: { relevanceScore: 0.29961726 }
  },
  Document {
    pageContent: 'Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.',
    metadata: { relevanceScore: 0.27542195 }
  }
]
*/
1 change: 1 addition & 0 deletions libs/langchain-cohere/src/index.ts
@@ -1,3 +1,4 @@
export * from "./chat_models.js";
export * from "./llms.js";
export * from "./embeddings.js";
export * from "./rerank.js";
74 changes: 74 additions & 0 deletions libs/langchain-cohere/src/rerank.ts
@@ -0,0 +1,74 @@
import { DocumentInterface } from "@langchain/core/documents";
Hey there! I've flagged this PR for your review because it introduces changes related to accessing and using environment variables via getEnvironmentVariable. Please take a look and ensure everything is handled correctly. Let me know if you have any questions!

import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { CohereClient } from "cohere-ai";

export interface CohereRerankArgs {
  /**
   * The API key to use.
   * @default {process.env.COHERE_API_KEY}
   */
  apiKey?: string;
  /**
   * The name of the model to use.
   * @default {"rerank-english-v2.0"}
   */
  model?: string;
  /**
   * How many documents to return.
   * @default {3}
   */
  topN?: number;
}

/**
 * Document compressor that uses `Cohere Rerank API`.
 */
export class CohereRerank {
  model = "rerank-english-v2.0";

  topN = 3;

  client: CohereClient;

  // `fields` is optional so that every access below can safely use `?.`;
  // the original mixed `fields?.apiKey` with unguarded `fields.model`.
  constructor(fields?: CohereRerankArgs) {
    const token = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY");
    if (!token) {
      throw new Error("No API key provided for CohereRerank.");
    }

    this.client = new CohereClient({
      token,
    });
    this.model = fields?.model ?? this.model;
    this.topN = fields?.topN ?? this.topN;
  }

  /**
   * Compress documents using Cohere's rerank API.
   *
   * @param {Array<DocumentInterface>} documents A sequence of documents to compress.
   * @param {string} query The query to use for compressing the documents.
   *
   * @returns {Promise<Array<DocumentInterface>>} A sequence of compressed documents.
   */
  async compressDocuments(
    documents: Array<DocumentInterface>,
    query: string
  ): Promise<Array<DocumentInterface>> {
    const _docs = documents.map((doc) => doc.pageContent);
    const { results } = await this.client.rerank({
      model: this.model,
      query,
      documents: _docs,
      topN: this.topN,
    });
    const finalResults: Array<DocumentInterface> = [];
    for (let i = 0; i < results.length; i += 1) {
      const result = results[i];
      const doc = documents[result.index];
      doc.metadata.relevanceScore = result.relevanceScore;
      finalResults.push(doc);
    }
    return finalResults;
  }
}
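Note that the loop in `compressDocuments` writes `relevanceScore` onto the metadata of the caller's input documents. If that side effect is unwanted, one possible alternative is to return copies instead. The sketch below is a suggestion and not part of this PR: `attachScores`, `Doc`, and `RerankResult` are hypothetical names, and plain objects stand in for `DocumentInterface`.

```typescript
// Hypothetical local types standing in for DocumentInterface and the
// per-result shape (`index`, `relevanceScore`) used in rerank.ts.
type Doc = { pageContent: string; metadata: Record<string, unknown> };
type RerankResult = { index: number; relevanceScore: number };

// Map each result back to its source document by index and return a copy
// with the score attached, leaving the input documents untouched.
function attachScores(documents: Doc[], results: RerankResult[]): Doc[] {
  return results.map((result) => {
    const doc = documents[result.index];
    return {
      pageContent: doc.pageContent,
      metadata: { ...doc.metadata, relevanceScore: result.relevanceScore },
    };
  });
}

const inputDocs: Doc[] = [
  { pageContent: "first", metadata: { source: "a" } },
  { pageContent: "second", metadata: { source: "b" } },
];
const scored = attachScores(inputDocs, [
  { index: 1, relevanceScore: 0.9 },
  { index: 0, relevanceScore: 0.2 },
]);
console.log(scored.map((d) => d.pageContent));
console.log("relevanceScore" in inputDocs[0].metadata);
```

The trade-off is an extra shallow copy per returned document in exchange for inputs that can be safely reused across queries.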
31 changes: 31 additions & 0 deletions libs/langchain-cohere/src/tests/rerank.int.test.ts
@@ -0,0 +1,31 @@
/* eslint-disable no-process-env */
Hey there! I've reviewed the code and noticed that the new test file is accessing an environment variable via process.env. I've flagged this for your review to ensure proper handling of environment variables. Let me know if you need any further assistance with this.

import { Document } from "@langchain/core/documents";
import { CohereRerank } from "../rerank.js";

test("CohereRerank can indeed rerank documents!", async () => {
  const query = "What is the capital of France?";

  const documents = [
    new Document({
      pageContent: "Paris is the capital of France.",
    }),
    new Document({
      pageContent: "Build context-aware reasoning applications",
    }),
    new Document({
      pageContent:
        "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274",
    }),
  ];

  const cohereRerank = new CohereRerank({
    apiKey: process.env.COHERE_API_KEY,
  });

  const rerankedDocuments = await cohereRerank.compressDocuments(
    documents,
    query
  );
  console.log(rerankedDocuments);
  expect(rerankedDocuments).toHaveLength(3);
});
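The integration test above needs a real `COHERE_API_KEY` and network access. As a complement, the same reranking logic could be exercised offline against a stub. This is a minimal sketch under assumptions: the `RerankClient` interface is hypothetical and only mirrors the `client.rerank` call made in rerank.ts (it is not the `cohere-ai` client type), and `rerankWithClient` is an illustrative helper, not a function in this PR.

```typescript
type Doc = { pageContent: string; metadata: Record<string, unknown> };

// Hypothetical interface modeled on the `client.rerank` call in rerank.ts.
interface RerankClient {
  rerank(request: {
    model: string;
    query: string;
    documents: string[];
    topN: number;
  }): Promise<{ results: { index: number; relevanceScore: number }[] }>;
}

// Same flow as compressDocuments, but with the client passed in so a stub
// can replace the network call.
async function rerankWithClient(
  client: RerankClient,
  documents: Doc[],
  query: string,
  topN: number
): Promise<Doc[]> {
  const { results } = await client.rerank({
    model: "rerank-english-v2.0",
    query,
    documents: documents.map((doc) => doc.pageContent),
    topN,
  });
  return results.map((result) => ({
    pageContent: documents[result.index].pageContent,
    metadata: {
      ...documents[result.index].metadata,
      relevanceScore: result.relevanceScore,
    },
  }));
}

// Stub that always ranks the first document highest, then the third.
const stubClient: RerankClient = {
  rerank: async () => ({
    results: [
      { index: 0, relevanceScore: 0.99 },
      { index: 2, relevanceScore: 0.12 },
    ],
  }),
};

const stubDocs: Doc[] = [
  { pageContent: "Paris is the capital of France.", metadata: {} },
  { pageContent: "Build context-aware reasoning applications", metadata: {} },
  { pageContent: "Carson City is the capital city of Nevada.", metadata: {} },
];

const pending = rerankWithClient(
  stubClient,
  stubDocs,
  "What is the capital of France?",
  2
);
pending.then((out) => console.log(out.map((doc) => doc.pageContent)));
```

A stub like this checks the index-to-document mapping deterministically, while the integration test continues to verify the live API contract.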