Skip to content

Commit 559f7c0

Browse files
committed
improve bm25 sorting
1 parent eece02a commit 559f7c0

File tree

2 files changed

+39
-34
lines changed
  • libs/langchain-community/src

2 files changed

+39
-34
lines changed

libs/langchain-community/src/retrievers/bm25.ts

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,28 @@ export class BM25Retriever extends BaseRetriever {
4747

4848
async _getRelevantDocuments(query: string) {
4949
const processedQuery = this.preprocessFunc(query);
50-
const documents = this.docs.map((doc) => doc.pageContent);
51-
const scores = BM25(documents, processedQuery) as number[];
52-
53-
const scoredDocs = this.docs.map((doc, index) => ({
54-
document: doc,
55-
score: scores[index],
56-
}));
57-
58-
scoredDocs.sort((a, b) => b.score - a.score);
50+
const scoredDocs = BM25<Document>(
51+
this.docs.map((d) => ({
52+
text: d.pageContent,
53+
docs: d,
54+
})),
55+
processedQuery,
56+
undefined,
57+
(a, b) => b.score - a.score
58+
);
5959

6060
return scoredDocs.slice(0, this.k).map((item) => {
6161
if (this.includeScore) {
6262
return new Document({
63-
...(item.document.id && { id: item.document.id }),
64-
pageContent: item.document.pageContent,
63+
...(item.docs.id && { id: item.docs.id }),
64+
pageContent: item.docs.pageContent,
6565
metadata: {
6666
bm25Score: item.score,
67-
...item.document.metadata,
67+
...item.docs.metadata,
6868
},
6969
});
7070
} else {
71-
return item.document;
71+
return item.docs;
7272
}
7373
});
7474
}

libs/langchain-community/src/utils/@furkantoprak/bm25/BM25.ts

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,26 @@ export const getTermFrequency = (term: string, corpus: string) => {
1616
};
1717

1818
/** Inverse document frequency. */
19-
export const getIDF = (term: string, documents: string[]) => {
19+
export const getIDF = <T>(term: string, documents: BMInputDocument<T>[]) => {
2020
// Number of relevant documents.
21-
const relevantDocuments = documents.filter((document: string) =>
22-
document.includes(term)
21+
const relevantDocuments = documents.filter((document) =>
22+
document.text.includes(term)
2323
).length;
2424
return Math.log(
2525
(documents.length - relevantDocuments + 0.5) / (relevantDocuments + 0.5) + 1
2626
);
2727
};
2828

29+
export interface BMInputDocument<T> {
30+
text: string;
31+
docs: T;
32+
}
33+
2934
/** Represents a document; useful when sorting results.
3035
*/
31-
export interface BMDocument {
32-
/** The document is originally scoreed. */
33-
document: string;
36+
export interface BMOutputDocument<T> {
37+
/** The original source */
38+
docs: T;
3439
/** The score that the document recieves. */
3540
score: number;
3641
}
@@ -44,7 +49,10 @@ export interface BMConstants {
4449
}
4550

4651
/** If returns positive, the sorting results in secondEl coming before firstEl, else, firstEl comes before secondEL */
47-
export type BMSorter = (firstEl: BMDocument, secondEl: BMDocument) => number;
52+
export type BMSorter<T> = (
53+
firstEl: BMOutputDocument<T>,
54+
secondEl: BMOutputDocument<T>
55+
) => number;
4856

4957
/** Implementation of Okapi BM25 algorithm.
5058
* @param documents: Collection of documents.
@@ -53,16 +61,16 @@ export type BMSorter = (firstEl: BMDocument, secondEl: BMDocument) => number;
5361
* @param sort: A function that allows you to sort queries by a given rule. If not provided, returns results corresponding to the original order.
5462
* If this option is provided, the return type will not be an array of scores but an array of documents with their scores.
5563
*/
56-
export function BM25(
57-
documents: string[],
64+
export function BM25<T>(
65+
documents: BMInputDocument<T>[],
5866
keywords: string[],
5967
constants?: BMConstants,
60-
sorter?: BMSorter
61-
): number[] | BMDocument[] {
68+
sorter?: BMSorter<T>
69+
): BMOutputDocument<T>[] {
6270
const b = constants && constants.b ? constants.b : 0.75;
6371
const k1 = constants && constants.k1 ? constants.k1 : 1.2;
64-
const documentLengths = documents.map((document: string) =>
65-
getWordCount(document)
72+
const documentLengths = documents.map((document) =>
73+
getWordCount(document.text)
6674
);
6775
const averageDocumentLength =
6876
documentLengths.reduce((a, b) => a + b, 0) / documents.length;
@@ -71,14 +79,14 @@ export function BM25(
7179
return obj;
7280
}, new Map<string, number>());
7381

74-
const scores = documents.map((document: string, index: number) => {
82+
const scoredDocs = documents.map(({ text, docs }, index) => {
7583
const score = keywords
7684
.map((keyword: string) => {
7785
const inverseDocumentFrequency = idfByKeyword.get(keyword);
7886
if (inverseDocumentFrequency === undefined) {
7987
throw new Error("Missing keyword.");
8088
}
81-
const termFrequency = getTermFrequency(keyword, document);
89+
const termFrequency = getTermFrequency(keyword, text);
8290
const documentLength = documentLengths[index];
8391
return (
8492
(inverseDocumentFrequency * (termFrequency * (k1 + 1))) /
@@ -87,14 +95,11 @@ export function BM25(
8795
);
8896
})
8997
.reduce((a: number, b: number) => a + b, 0);
90-
if (sorter) {
91-
return { score, document } as BMDocument;
92-
}
93-
return score;
98+
return { score, docs } as BMOutputDocument<T>;
9499
});
95100
// sort the results
96101
if (sorter) {
97-
return (scores as BMDocument[]).sort(sorter);
102+
return scoredDocs.sort(sorter);
98103
}
99-
return scores as number[];
104+
return scoredDocs;
100105
}

0 commit comments

Comments
 (0)