
Commit 07399a4

adding kwargs for most parameters
1 parent 49c0f2a

File tree

2 files changed: +82 -52 lines

+69 -48
@@ -1,79 +1,95 @@
 import numpy as np
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Union
 from itertools import chain
 from .find_local_minima import find_local_minima, windowed_cross_similarity
-from .splitter import constrained_coalesce, split_sentences
+from .splitter import (
+    constrained_batches,
+    constrained_coalesce,
+    split_sentences,
+    reverse_merge,
+)
 
 
 class SemanticSplitter:
-    """A class for semantically splitting and reconstructing text."""
+    """
+    A class for semantically splitting text.
+
+    This class provides methods to split text into chunks based on semantic similarity
+    and reconstruct them while maintaining semantic coherence.
+    """
 
     @staticmethod
-    def flatten(nested_list: List[List]) -> List:
-        """Flatten a list of lists into a single list."""
+    def flatten(nested_list: List[List[any]]) -> List[any]:
+        """
+        Flatten a list of lists into a single list.
+
+        Args:
+            nested_list (List[List[any]]): A list of lists to be flattened.
+
+        Returns:
+            List[any]: A flattened list containing all elements from the nested lists.
+        """
         return list(chain.from_iterable(nested_list))
 
     @staticmethod
     def constrained_split(
         text: str,
         target_size: int,
-        coalesce_range: Tuple[int, int, int] = (256, 576, 64),
         separator: str = " ",
+        min_size: int = 24,
     ) -> List[str]:
         """
         Split text into chunks of approximately target_size.
 
-        Parameters:
-        - text (str): The text to split.
-        - target_size (int): The target size for each chunk.
+        Args:
+            text (str): The text to split.
+            target_size (int): The target size for each chunk.
+            separator (str, optional): The separator to use when joining text. Defaults to " ".
+            min_size (int, optional): The minimum size for each chunk. Defaults to 24.
 
         Returns:
-        - List[str]: List of text chunks.
+            List[str]: List of text chunks.
         """
         sentences = split_sentences(text)
-        for i in range(*coalesce_range):
-            sentences = constrained_coalesce(sentences, i, separator=separator)
+        sentences = constrained_coalesce(sentences, target_size, separator=separator)
+        sentences = reverse_merge(sentences, n=min_size, separator=separator)
        return sentences
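
Note: the coalesce_range sweep is gone. constrained_split now makes a single
constrained_coalesce pass at target_size, followed by a reverse_merge cleanup
pass that folds fragments shorter than min_size into a neighbor. A minimal
sketch of the assumed semantics of these two .splitter helpers (illustrative
only; the real implementations may differ):

# Sketch only: assumed behavior of the .splitter helpers, not wordllama's code.
from typing import List

def coalesce_sketch(parts: List[str], limit: int, separator: str = " ") -> List[str]:
    """Greedily join adjacent pieces while the joined text stays within `limit`."""
    out: List[str] = []
    for part in parts:
        if out and len(out[-1]) + len(separator) + len(part) <= limit:
            out[-1] += separator + part  # still under budget: extend current chunk
        else:
            out.append(part)  # joining would overflow: start a new chunk
    return out

def reverse_merge_sketch(parts: List[str], n: int, separator: str = " ") -> List[str]:
    """Fold any piece shorter than `n` characters into the piece before it."""
    out: List[str] = []
    for part in parts:
        if out and len(part) < n:
            out[-1] += separator + part  # too small to stand alone
        else:
            out.append(part)
    return out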

     @classmethod
     def split(
         cls,
         text: str,
         target_size: int,
-        paragraph_range: Tuple[int, int, int] = (16, 60, 8),
-        sentence_range: Tuple[int, int, int] = (256, 576, 64),
+        cleanup_size: int = 24,
+        intermediate_size: int = 96,
     ) -> List[str]:
         """
-        Split the input text into chunks.
+        Split the input text into chunks based on semantic coherence.
 
-        Parameters:
-        - text (str): The input text to split.
-        - target_size (int): The target size for final chunks.
-        - initial_split_size (int): The initial size for splitting on newlines.
+        Args:
+            text (str): The input text to split.
+            target_size (int): The target size for final chunks.
+            cleanup_size (int, optional): The minimum size for cleaning up small chunks. Defaults to 24.
+            intermediate_size (int, optional): The initial size for splitting on newlines. Defaults to 96.
 
         Returns:
-        - List[str]: List of text chunks.
+            List[str]: List of text chunks.
         """
-        # paragraph splitting
-        # split on newlines and coalesce to cleanup
         lines = text.splitlines()
-        for i in range(*paragraph_range):
-            lines = constrained_coalesce(lines, i, separator="\n")
+        lines = constrained_coalesce(lines, intermediate_size, separator="\n")
+        lines = reverse_merge(lines, n=cleanup_size, separator="\n")
 
-        # for paragraphs larger than target_size
-        # split to sentences and coalesce
         chunks = [
             cls.constrained_split(
-                line, target_size, coalesce_range=sentence_range, separator=" "
+                line, target_size, min_size=cleanup_size, separator=" "
             )
             if len(line) > target_size
             else [line]
             for line in lines
         ]
 
-        # flatten list of lists
         chunks = cls.flatten(chunks)
-        return list(filter(lambda x: True if x.strip() else False, chunks))
+        return list(filter(lambda x: bool(x.strip()), chunks))
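
With the range tuples replaced by plain sizes, split now runs a fixed two-stage
pre-split: coalesce newline-split lines up to intermediate_size, fold fragments
below cleanup_size back into a neighbor, then sentence-split only the lines
still longer than target_size. A hypothetical direct usage (the import path is
inferred from the relative imports above and may differ):

# Hypothetical usage of the updated classmethod; import path is an assumption.
from wordllama.algorithms.semantic_splitter import SemanticSplitter

with open("doc.txt") as f:
    text = f.read()

lines = SemanticSplitter.split(
    text,
    target_size=1536,      # sentence-split any line longer than this
    intermediate_size=96,  # newline-coalescing budget (new kwarg)
    cleanup_size=24,       # merge fragments shorter than this (new kwarg)
)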

     @classmethod
     def reconstruct(
@@ -85,46 +101,51 @@ def reconstruct(
         poly_order: int,
         savgol_window: int,
         max_score_pct: float = 0.4,
-    ) -> List[str]:
+        return_minima: bool = False,
+    ) -> Union[List[str], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
         """
         Reconstruct text chunks based on semantic similarity.
 
-        Parameters:
-        - lines (List[str]): List of text chunks to reconstruct.
-        - norm_embed (np.ndarray): Embeddings (normalized).
-        - target_size (int): Target size for final chunks.
-        - window_size (int): Window size for similarity matrix averaging.
-        - poly_order (int): Polynomial order for Savitzky-Golay filter.
-        - savgol_window (int): Window size for Savitzky-Golay filter.
+        Args:
+            lines (List[str]): List of text chunks to reconstruct.
+            norm_embed (np.ndarray): Normalized embeddings of the text chunks.
+            target_size (int): Target size for final chunks.
+            window_size (int): Window size for similarity matrix averaging.
+            poly_order (int): Polynomial order for Savitzky-Golay filter.
+            savgol_window (int): Window size for Savitzky-Golay filter.
+            max_score_pct (float, optional): Maximum percentile of similarity scores to consider. Defaults to 0.4.
+            return_minima (bool, optional): If True, return minima information instead of reconstructed text. Defaults to False.
 
         Returns:
-        - List[str]: List of semantically split text chunks.
+            Union[List[str], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+                If return_minima is False, returns a list of reconstructed text chunks.
+                If return_minima is True, returns a tuple of (roots, y, sim_avg).
+
+        Raises:
+            AssertionError: If the number of texts doesn't equal the number of embeddings.
         """
         assert (
             len(lines) == norm_embed.shape[0]
         ), "Number of texts must equal number of embeddings"
 
-        # calculate the similarity for the window
         sim_avg = windowed_cross_similarity(norm_embed, window_size)
-
-        # find the minima
         roots, y = find_local_minima(
             sim_avg, poly_order=poly_order, window_size=savgol_window
         )
-        split_points = np.round(roots).astype(int).tolist()
 
-        # filter to minima within bottom Nth percentile of similarity scores
+        if return_minima:
+            return roots, y, sim_avg
+
         (x_idx,) = np.where(y < np.quantile(sim_avg, max_score_pct))
-        split_points = [x for i, x in enumerate(split_points) if i in x_idx]
+        split_points = [int(x) for i, x in enumerate(roots.tolist()) if i in x_idx]
 
-        # reconstruct using the minima as boundaries for coalesce
-        # this ensures that any semantic boundaries are respected
         chunks = []
         start = 0
         for end in split_points + [len(lines)]:
             chunk = constrained_coalesce(lines[start:end], target_size)
             chunks.extend(chunk)
             start = end
 
-        chunks = constrained_coalesce(chunks, target_size)
-        return chunks
+        return list(
+            map("".join, constrained_batches(lines, max_size=target_size, strict=False))
+        )
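
reconstruct turns the windowed cross-similarity signal into split points:
smooth it with a Savitzky-Golay filter, locate local minima, keep only the
minima below the max_score_pct quantile, and pack the resulting pieces with
constrained_batches. The new return_minima flag exposes the raw signal instead
of the reconstructed text. A rough, self-contained sketch of the
minima-detection idea (assumed behavior of find_local_minima, written with
scipy; the library's implementation may differ):

# Rough sketch of boundary detection; not wordllama's find_local_minima.
import numpy as np
from scipy.signal import argrelmin, savgol_filter

def minima_sketch(sim_avg, savgol_window: int, poly_order: int):
    # Smooth the similarity signal so noise doesn't create spurious dips,
    # then return the indices and depths of interior local minima.
    smooth = savgol_filter(sim_avg, window_length=savgol_window, polyorder=poly_order)
    (idx,) = argrelmin(smooth)
    return idx, smooth[idx]

sim = np.sin(np.linspace(0, 6 * np.pi, 60)) + 1.0   # toy similarity signal
idx, depth = minima_sketch(sim, savgol_window=5, poly_order=2)
split_points = idx[depth < np.quantile(sim, 0.4)]   # keep only the deepest dips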

wordllama/inference.py

+13 -4
@@ -379,8 +379,11 @@ def split(
         text: str,
         target_size: int = 1536,
         window_size: int = 3,
-        poly_order: int = 3,
-        savgol_window: int = 5,
+        poly_order: int = 2,
+        savgol_window: int = 3,
+        cleanup_size: int = 24,
+        intermediate_size: int = 96,
+        return_minima: bool = False,
     ) -> List[str]:
         """
         Perform semantic splitting on the input text.
@@ -397,9 +400,14 @@ def split(
         - List[str]: List of semantically split text chunks.
         """
         # split text
-        lines = SemanticSplitter.split(text, target_size=target_size)
+        lines = SemanticSplitter.split(
+            text,
+            target_size=target_size,
+            intermediate_size=intermediate_size,
+            cleanup_size=cleanup_size,
+        )
 
-        # compute cross similarity
+        # embed lines and normalize
         embeddings = self.embed(lines, norm=True)
 
         # reconstruct text with similarity signals
@@ -410,4 +418,5 @@ def split(
             window_size=window_size,
             poly_order=poly_order,
             savgol_window=savgol_window,
+            return_minima=return_minima,
         )
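
End to end, the new kwargs are now reachable from the public API. The
Savitzky-Golay defaults also change to poly_order=2 and savgol_window=3 (both
valid settings, since the polynomial order stays below the window length). A
usage sketch against the top-level interface; the loader call follows the
project README, but treat the details as illustrative:

# Illustrative top-level usage of the new keyword arguments.
from wordllama import WordLlama

wl = WordLlama.load()
with open("doc.txt") as f:
    long_text = f.read()

chunks = wl.split(
    long_text,
    target_size=1536,       # desired final chunk size in characters
    cleanup_size=24,        # new: merge fragments shorter than this
    intermediate_size=96,   # new: paragraph-coalescing budget
)

# New debugging path: short-circuit reconstruction and inspect the signal.
roots, y, sim_avg = wl.split(long_text, return_minima=True)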
