Skip to content

Commit d7ba6f0

Browse files
committed
Add Semantic Text Splitter
1 parent 9709279 commit d7ba6f0

File tree

3 files changed

+693
-0
lines changed

3 files changed

+693
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
// Public entry point: re-export every available text splitter.
export * from "./text_splitter.js";
export * from "./semantic_splitter.js";
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
/**
2+
* Semantic text splitter based on semantic similarity.
3+
*
4+
* Inspired by Greg Kamradt's semantic chunking approach:
5+
* https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
6+
*/
7+
8+
import { Document, BaseDocumentTransformer } from "@langchain/core/documents";
9+
import type { Embeddings } from "@langchain/core/embeddings";
10+
11+
/**
 * Working record for one sentence while chunking.
 *
 * The optional fields are filled in progressively: `combinedSentence` by
 * `combineSentences`, then `combinedSentenceEmbedding` once the combined
 * sentences have been embedded.
 */
interface SentenceDict {
  // Raw sentence text produced by the sentence-split regex.
  sentence: string;
  // Position of the sentence in the original ordering.
  index: number;
  // Sentence joined with up to `bufferSize` neighbors on each side.
  combinedSentence?: string;
  // Embedding vector of `combinedSentence`.
  combinedSentenceEmbedding?: number[];
}
17+
18+
export function combineSentences(sentences: SentenceDict[], bufferSize: number = 1): SentenceDict[] {
19+
// Go through each sentence dict
20+
for (let i = 0; i < sentences.length; i++) {
21+
const sentence = sentences[i];
22+
23+
// Create the combined sentence by combining the sentences within the buffer
24+
let combinedSentence = "";
25+
26+
// Add sentences before the current one
27+
for (let j = Math.max(0, i - bufferSize); j < i; j++) {
28+
combinedSentence += sentences[j].sentence + " ";
29+
}
30+
31+
// Add the current sentence
32+
combinedSentence += sentence.sentence;
33+
34+
// Add sentences after the current one
35+
for (let j = i + 1; j <= Math.min(sentences.length - 1, i + bufferSize); j++) {
36+
combinedSentence += " " + sentences[j].sentence;
37+
}
38+
39+
// Assign the combined sentence to the dict
40+
sentence.combinedSentence = combinedSentence.trim();
41+
}
42+
43+
return sentences;
44+
}
45+
46+
function cosineSimilarity(a: number[], b: number[]): number {
47+
const dotProduct = a.reduce((sum, ai, i) => sum + ai * b[i], 0);
48+
const magnitudeA = Math.sqrt(a.reduce((sum, ai) => sum + ai * ai, 0));
49+
const magnitudeB = Math.sqrt(b.reduce((sum, bi) => sum + bi * bi, 0));
50+
51+
if (magnitudeA === 0 || magnitudeB === 0) {
52+
return 0;
53+
}
54+
55+
return dotProduct / (magnitudeA * magnitudeB);
56+
}
57+
58+
export function calculateCosineDistances(sentences: SentenceDict[]): [number[], SentenceDict[]] {
59+
const distances: number[] = [];
60+
61+
for (let i = 0; i < sentences.length - 1; i++) {
62+
const embeddingCurrent = sentences[i].combinedSentenceEmbedding!;
63+
const embeddingNext = sentences[i + 1].combinedSentenceEmbedding!;
64+
65+
// Calculate cosine similarity
66+
const similarity = cosineSimilarity(embeddingCurrent, embeddingNext);
67+
68+
// Convert to cosine distance (1 - cosine similarity)
69+
const distance = 1 - similarity;
70+
distances.push(distance);
71+
}
72+
73+
return [distances, sentences];
74+
}
75+
76+
/** Strategies for deciding where a semantic breakpoint occurs. */
type BreakpointThresholdType = "percentile" | "standard_deviation" | "interquartile" | "gradient";

// Default threshold amount per strategy. For "percentile" and "gradient" the
// value is a percentile (0-100); for "standard_deviation" and "interquartile"
// it is a multiplier on the spread added to the mean.
const BREAKPOINT_DEFAULTS: Record<BreakpointThresholdType, number> = {
  percentile: 95,
  standard_deviation: 3,
  interquartile: 1.5,
  gradient: 95,
};
84+
85+
/** Constructor options for {@link SemanticChunker}. */
export interface SemanticChunkerParams {
  // Embedding model used to embed combined sentence groups.
  embeddings: Embeddings;
  // Neighbor sentences included on each side when embedding (default 1).
  bufferSize?: number;
  // If true, add `start_index` to each chunk's metadata (default false).
  addStartIndex?: boolean;
  // Breakpoint detection strategy (default "percentile").
  breakpointThresholdType?: BreakpointThresholdType;
  // Threshold amount; defaults to the strategy's entry in BREAKPOINT_DEFAULTS.
  breakpointThresholdAmount?: number;
  // If set, derive the threshold so roughly this many chunks are produced,
  // overriding breakpointThresholdType/-Amount.
  numberOfChunks?: number;
  // Regex (as string) used to split text into sentences.
  sentenceSplitRegex?: string;
  // Chunks shorter than this are merged into the following chunk.
  minChunkSize?: number;
}
95+
96+
/**
97+
* Split the text based on semantic similarity.
98+
*
99+
* Taken from Greg Kamradt's wonderful notebook:
100+
* https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
101+
*
102+
* All credits to him.
103+
*
104+
* At a high level, this splits into sentences, then groups into groups of 3
105+
* sentences, and then merges one that are similar in the embedding space.
106+
*/
107+
export class SemanticChunker extends BaseDocumentTransformer {
108+
lc_namespace = ["langchain", "document_transformers", "text_splitters"];
109+
110+
private embeddings: Embeddings;
111+
private bufferSize: number;
112+
private addStartIndex: boolean;
113+
private breakpointThresholdType: BreakpointThresholdType;
114+
private breakpointThresholdAmount: number;
115+
private numberOfChunks?: number;
116+
private sentenceSplitRegex: string;
117+
private minChunkSize?: number;
118+
119+
constructor(params: SemanticChunkerParams) {
120+
super();
121+
this.embeddings = params.embeddings;
122+
this.bufferSize = params.bufferSize ?? 1;
123+
this.addStartIndex = params.addStartIndex ?? false;
124+
this.breakpointThresholdType = params.breakpointThresholdType ?? "percentile";
125+
this.numberOfChunks = params.numberOfChunks;
126+
this.sentenceSplitRegex = params.sentenceSplitRegex ?? "(?<=[.?!])\\s+";
127+
this.minChunkSize = params.minChunkSize;
128+
129+
if (params.breakpointThresholdAmount === undefined) {
130+
this.breakpointThresholdAmount = BREAKPOINT_DEFAULTS[this.breakpointThresholdType];
131+
} else {
132+
this.breakpointThresholdAmount = params.breakpointThresholdAmount;
133+
}
134+
}
135+
136+
private calculateBreakpointThreshold(distances: number[]): [number, number[]] {
137+
if (this.breakpointThresholdType === "percentile") {
138+
const sorted = [...distances].sort((a, b) => a - b);
139+
const index = Math.ceil((this.breakpointThresholdAmount / 100) * sorted.length) - 1;
140+
return [sorted[Math.max(0, index)], distances];
141+
} else if (this.breakpointThresholdType === "standard_deviation") {
142+
const mean = distances.reduce((sum, d) => sum + d, 0) / distances.length;
143+
const variance = distances.reduce((sum, d) => sum + Math.pow(d - mean, 2), 0) / distances.length;
144+
const stdDev = Math.sqrt(variance);
145+
return [mean + this.breakpointThresholdAmount * stdDev, distances];
146+
} else if (this.breakpointThresholdType === "interquartile") {
147+
const sorted = [...distances].sort((a, b) => a - b);
148+
const q1Index = Math.floor(0.25 * sorted.length);
149+
const q3Index = Math.floor(0.75 * sorted.length);
150+
const q1 = sorted[q1Index];
151+
const q3 = sorted[q3Index];
152+
const iqr = q3 - q1;
153+
const mean = distances.reduce((sum, d) => sum + d, 0) / distances.length;
154+
return [mean + this.breakpointThresholdAmount * iqr, distances];
155+
} else if (this.breakpointThresholdType === "gradient") {
156+
const distanceGradient: number[] = [];
157+
for (let i = 0; i < distances.length - 1; i++) {
158+
distanceGradient.push(distances[i + 1] - distances[i]);
159+
}
160+
const sortedGradient = [...distanceGradient].sort((a, b) => a - b);
161+
const index = Math.ceil((this.breakpointThresholdAmount / 100) * sortedGradient.length) - 1;
162+
return [sortedGradient[Math.max(0, index)], distanceGradient];
163+
} else {
164+
throw new Error(`Got unexpected breakpointThresholdType: ${this.breakpointThresholdType}`);
165+
}
166+
}
167+
168+
private thresholdFromClusters(distances: number[]): number {
169+
if (this.numberOfChunks === undefined) {
170+
throw new Error("This should never be called if numberOfChunks is undefined.");
171+
}
172+
173+
const x1 = distances.length;
174+
const y1 = 0.0;
175+
const x2 = 1.0;
176+
const y2 = 100.0;
177+
178+
const x = Math.max(Math.min(this.numberOfChunks, x1), x2);
179+
180+
let y: number;
181+
if (x2 === x1) {
182+
y = y2;
183+
} else {
184+
y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1);
185+
}
186+
187+
y = Math.min(Math.max(y, 0), 100);
188+
189+
const sorted = [...distances].sort((a, b) => a - b);
190+
const index = Math.ceil((y / 100) * sorted.length) - 1;
191+
return sorted[Math.max(0, index)];
192+
}
193+
194+
private async calculateSentenceDistances(singleSentencesList: string[]): Promise<[number[], SentenceDict[]]> {
195+
const sentences: SentenceDict[] = singleSentencesList.map((sentence, index) => ({
196+
sentence,
197+
index,
198+
}));
199+
200+
const sentencesWithCombined = combineSentences(sentences, this.bufferSize);
201+
202+
const combinedSentences = sentencesWithCombined.map(s => s.combinedSentence!);
203+
const embeddings = await this.embeddings.embedDocuments(combinedSentences);
204+
205+
for (let i = 0; i < sentencesWithCombined.length; i++) {
206+
sentencesWithCombined[i].combinedSentenceEmbedding = embeddings[i];
207+
}
208+
209+
return calculateCosineDistances(sentencesWithCombined);
210+
}
211+
212+
private getSingleSentencesList(text: string): string[] {
213+
return text.split(new RegExp(this.sentenceSplitRegex)).filter(sentence => sentence.trim().length > 0);
214+
}
215+
216+
async splitText(text: string): Promise<string[]> {
217+
const singleSentencesList = this.getSingleSentencesList(text);
218+
219+
if (singleSentencesList.length === 1) {
220+
return singleSentencesList;
221+
}
222+
223+
if (this.breakpointThresholdType === "gradient" && singleSentencesList.length === 2) {
224+
return singleSentencesList;
225+
}
226+
227+
const [distances, sentences] = await this.calculateSentenceDistances(singleSentencesList);
228+
229+
let breakpointDistanceThreshold: number;
230+
let breakpointArray: number[];
231+
232+
if (this.numberOfChunks !== undefined) {
233+
breakpointDistanceThreshold = this.thresholdFromClusters(distances);
234+
breakpointArray = distances;
235+
} else {
236+
[breakpointDistanceThreshold, breakpointArray] = this.calculateBreakpointThreshold(distances);
237+
}
238+
239+
const indicesAboveThresh = breakpointArray
240+
.map((x, i) => ({ value: x, index: i }))
241+
.filter(({ value }) => value > breakpointDistanceThreshold)
242+
.map(({ index }) => index);
243+
244+
const chunks: string[] = [];
245+
let startIndex = 0;
246+
247+
for (const index of indicesAboveThresh) {
248+
const endIndex = index;
249+
const group = sentences.slice(startIndex, endIndex + 1);
250+
const combinedText = group.map((d: SentenceDict) => d.sentence).join(" ");
251+
252+
if (this.minChunkSize !== undefined && combinedText.length < this.minChunkSize) {
253+
continue;
254+
}
255+
256+
chunks.push(combinedText);
257+
startIndex = index + 1;
258+
}
259+
260+
if (startIndex < sentences.length) {
261+
const combinedText = sentences.slice(startIndex).map((d: SentenceDict) => d.sentence).join(" ");
262+
chunks.push(combinedText);
263+
}
264+
265+
return chunks;
266+
}
267+
268+
async createDocuments(
269+
texts: string[],
270+
metadatas: Record<string, any>[] = []
271+
): Promise<Document[]> {
272+
const _metadatas = metadatas.length > 0 ? metadatas : texts.map(() => ({}));
273+
const documents: Document[] = [];
274+
275+
for (let i = 0; i < texts.length; i++) {
276+
const text = texts[i];
277+
let startIndex = 0;
278+
279+
for (const chunk of await this.splitText(text)) {
280+
const metadata: Record<string, any> = { ..._metadatas[i] };
281+
if (this.addStartIndex) {
282+
metadata.start_index = startIndex;
283+
}
284+
285+
const newDoc = new Document({
286+
pageContent: chunk,
287+
metadata,
288+
});
289+
290+
documents.push(newDoc);
291+
startIndex += chunk.length;
292+
}
293+
}
294+
295+
return documents;
296+
}
297+
298+
async splitDocuments(documents: Document[]): Promise<Document[]> {
299+
const texts = documents.map(doc => doc.pageContent);
300+
const metadatas = documents.map(doc => doc.metadata);
301+
return this.createDocuments(texts, metadatas);
302+
}
303+
304+
async transformDocuments(documents: Document[]): Promise<Document[]> {
305+
return this.splitDocuments(documents);
306+
}
307+
}

0 commit comments

Comments
 (0)