Commit 40fd426

splitter class and inference implementation

1 parent: c84c36f

File tree: 7 files changed, +245 −223 lines

Binary files changed (not rendered): −258 KB, 258 KB, 748 Bytes, 2.87 KB

tutorials/semantic_split/semantic_split.md (+100 −223)

Large diffs are not rendered by default.
wordllama/algorithms/semantic_splitter.py (+102, new file)

import numpy as np
from typing import List
from itertools import chain
from .find_local_minima import window_average, find_local_minima
from .splitter import constrained_coalesce, split_sentences


class SemanticSplitter:
    """A class for semantically splitting and reconstructing text."""

    @staticmethod
    def flatten(nested_list: List[List]) -> List:
        """Flatten a list of lists into a single list."""
        return list(chain.from_iterable(nested_list))

    @staticmethod
    def constrained_split(text: str, target_size: int) -> List[str]:
        """
        Split text into chunks of approximately target_size.

        Parameters:
        - text (str): The text to split.
        - target_size (int): The target size for each chunk.

        Returns:
        - List[str]: List of text chunks.
        """
        sentences = split_sentences(text)
        return constrained_coalesce(sentences, target_size, separator=" ")

    @classmethod
    def split(cls, text: str, target_size: int, initial_split_size: int) -> List[str]:
        """
        Split the input text into chunks.

        Parameters:
        - text (str): The input text to split.
        - target_size (int): The target size for final chunks.
        - initial_split_size (int): The initial size for splitting on newlines.

        Returns:
        - List[str]: List of text chunks.
        """
        lines = constrained_coalesce(
            text.splitlines(), initial_split_size, separator="\n"
        )
        chunks = [
            cls.constrained_split(line, target_size)
            if len(line) > target_size
            else [line]
            for line in lines
        ]
        chunks = cls.flatten(chunks)
        return [chunk for chunk in chunks if chunk.strip()]

    @classmethod
    def reconstruct(
        cls,
        lines: List[str],
        x_sim: np.ndarray,
        target_size: int,
        window_size: int,
        poly_order: int,
        savgol_window: int,
        max_score_pct: float = 0.4,
    ) -> List[str]:
        """
        Reconstruct text chunks based on semantic similarity.

        Parameters:
        - lines (List[str]): List of text chunks to reconstruct.
        - x_sim (np.ndarray): Cross-similarity matrix of text chunks.
        - target_size (int): Target size for final chunks.
        - window_size (int): Window size for similarity matrix averaging.
        - poly_order (int): Polynomial order for Savitzky-Golay filter.
        - savgol_window (int): Window size for Savitzky-Golay filter.
        - max_score_pct (float): Quantile below which similarity minima
          are kept as split points.

        Returns:
        - List[str]: List of semantically split text chunks.
        """
        sim_avg = window_average(x_sim, window_size)
        x = np.arange(len(sim_avg))
        roots, y = find_local_minima(
            x, sim_avg, poly_order=poly_order, window_size=savgol_window
        )
        split_points = np.round(roots).astype(int).tolist()

        # keep only minima below the max_score_pct quantile of similarity
        # scores (bottom 40% by default)
        (x_idx,) = np.where(y < np.quantile(sim_avg, max_score_pct))
        split_points = [x for i, x in enumerate(split_points) if i in x_idx]

        # reconstruct using the minima as boundaries for coalesce;
        # this ensures that any semantic boundaries are respected
        chunks = []
        start = 0
        for end in split_points + [len(lines)]:
            chunk = constrained_coalesce(lines[start:end], target_size)
            chunks.extend(chunk)
            start = end

        chunks = constrained_coalesce(chunks, target_size)
        return chunks
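
The split_sentences and constrained_coalesce helpers are imported from the .splitter module, which is not part of this commit. A minimal sketch of the assumed behavior, not the repository's actual implementation: split_sentences breaks text at sentence boundaries, and constrained_coalesce greedily merges adjacent segments until the target character count is reached.

import re
from typing import List

# Hypothetical sketches of the .splitter helpers; the real module is not
# shown in this diff, so the behavior here is an assumption.

def split_sentences(text: str) -> List[str]:
    # naive split at terminal punctuation followed by whitespace
    return [s for s in re.split(r"(?<=[.!?])\s+", text) if s]

def constrained_coalesce(segments: List[str], target_size: int, separator: str = " ") -> List[str]:
    # greedily merge adjacent segments while the joined string stays
    # at or below target_size characters
    merged: List[str] = []
    buffer = ""
    for seg in segments:
        candidate = f"{buffer}{separator}{seg}" if buffer else seg
        if len(candidate) <= target_size:
            buffer = candidate
        else:
            if buffer:
                merged.append(buffer)
            buffer = seg
    if buffer:
        merged.append(buffer)
    return merged

For example, constrained_coalesce(["one", "two", "three"], 10) would return ["one two", "three"]: the first two segments fit within 10 characters when joined, the third does not.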

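window_average and find_local_minima come from the .find_local_minima module, also outside this diff. Judging from the call sites in reconstruct, window_average collapses the cross-similarity matrix into a 1-D signal of local neighborhood similarity, and find_local_minima smooths that signal with a Savitzky-Golay filter and returns the positions and values of its local minima. A hedged sketch under those assumptions:

import numpy as np
from scipy.signal import savgol_filter, argrelextrema

# Hypothetical sketches; signatures are inferred from how
# SemanticSplitter.reconstruct calls these functions.

def window_average(x_sim: np.ndarray, window_size: int) -> np.ndarray:
    # mean similarity of each chunk to its neighbors within a band of
    # +/- window_size around the diagonal, excluding self-similarity
    n = x_sim.shape[0]
    out = np.empty(n)
    for i in range(n):
        lo, hi = max(0, i - window_size), min(n, i + window_size + 1)
        out[i] = np.delete(x_sim[i, lo:hi], i - lo).mean()
    return out

def find_local_minima(x, y, poly_order=2, window_size=7):
    # smooth the signal, then return the positions and smoothed values
    # at indices where a local minimum occurs
    smoothed = savgol_filter(y, window_length=window_size, polyorder=poly_order)
    (idx,) = argrelextrema(smoothed, np.less)
    return x[idx].astype(float), smoothed[idx]

In reconstruct, the minima positions are rounded to integer indices and kept only where the smoothed value falls below the max_score_pct quantile (0.4 by default), so only comparatively strong dips in similarity become split points.
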
wordllama/inference.py (+43)

@@ -9,6 +9,7 @@
     binarize_and_packbits,
     process_batches_cy,
 )
+from .algorithms.semantic_splitter import SemanticSplitter
 from .config import WordLlamaConfig

 # Set up logging
@@ -370,3 +371,45 @@ def cluster(
             random_state=random_state,
         )
         return cluster_labels, inertia
+
+    def split(
+        self,
+        text: str,
+        target_size: int = 1536,
+        window_size: int = 3,
+        initial_split_size: int = 64,
+        poly_order: int = 2,
+        savgol_window: int = 7,
+    ) -> List[str]:
+        """
+        Perform semantic splitting on the input text.
+
+        Parameters:
+        - text (str): The input text to split.
+        - target_size (int): Target size for text chunks.
+        - window_size (int): Window size for similarity matrix averaging.
+        - initial_split_size (int): Initial size for splitting on newlines.
+        - poly_order (int): Polynomial order for Savitzky-Golay filter.
+        - savgol_window (int): Window size for Savitzky-Golay filter.
+
+        Returns:
+        - List[str]: List of semantically split text chunks.
+        """
+        # split text
+        lines = SemanticSplitter.split(
+            text, target_size=target_size, initial_split_size=initial_split_size
+        )
+
+        # compute cross similarity
+        embeddings = self.embed(lines)
+        cross_similarity = self.vector_similarity(embeddings, embeddings)
+
+        # reconstruct text with similarity signals
+        return SemanticSplitter.reconstruct(
+            lines,
+            cross_similarity,
+            target_size=target_size,
+            window_size=window_size,
+            poly_order=poly_order,
+            savgol_window=savgol_window,
+        )
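
End to end, the new split method chunks the text, embeds the chunks, scores neighboring similarity, and re-coalesces at the similarity minima. A usage sketch, assuming the package's WordLlama.load() entry point and a plain-text file of your own:

from wordllama import WordLlama

wl = WordLlama.load()  # load the default embedding model

with open("article.txt") as f:  # any long plain-text document
    text = f.read()

chunks = wl.split(text, target_size=1536)
for i, chunk in enumerate(chunks):
    print(f"chunk {i}: {len(chunk)} chars")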
