Skip to content

Commit

Permalink
feat: optimize split rule when using a custom split segment identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost authored May 16, 2023
1 parent 3117619 commit 815f794
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
68 changes: 68 additions & 0 deletions api/core/index/spiltter/fixed_text_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Functionality for splitting text."""
from __future__ import annotations

from typing import (
Any,
List,
Optional,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter


class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Text splitter that first splits on a fixed, user-chosen separator.

    The text is split once on ``fixed_separator``; any resulting chunk that
    is still longer than the configured chunk size is broken down further by
    :meth:`recursive_split_text` using the fallback ``separators`` (tried in
    order, falling back to character-level splitting).
    """

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        :param fixed_separator: primary separator for the first-pass split.
        :param separators: ordered fallback separators for oversized chunks;
            defaults to paragraph, line, word, then character granularity.
        :param kwargs: forwarded to ``RecursiveCharacterTextSplitter``
            (e.g. ``chunk_size``, ``chunk_overlap``, ``length_function``).
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        Oversized chunks are recursively re-split; empty chunks (produced by
        consecutive, leading, or trailing fixed separators) are dropped.
        """
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = list(text)

        final_chunks = []
        for chunk in chunks:
            # Bug fix: str.split yields empty strings for adjacent/edge
            # separators; skip them so no empty segments are indexed.
            if not chunk:
                continue
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> List[str]:
        """Recursively split an oversized chunk and return sub-chunks.

        Mirrors langchain's RecursiveCharacterTextSplitter logic: pick the
        first fallback separator present in the text, split on it, then merge
        small pieces back together (via the parent's ``_merge_splits``) and
        recurse into pieces that are still too long.
        """
        final_chunks = []
        # Get appropriate separator to use: first one found in the text,
        # defaulting to the last (character-level) separator.
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # Flush accumulated small splits before recursing into the
                # oversized one, so ordering is preserved.
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.recursive_split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, separator)
            final_chunks.extend(merged_text)
        return final_chunks
11 changes: 5 additions & 6 deletions api/core/indexing_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.index.vector_index import VectorIndex
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
Expand Down Expand Up @@ -267,16 +268,14 @@ def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
raise ValueError("Custom segment length should be between 50 and 1000.")

separator = segmentation["separator"]
if not separator:
separators = ["\n\n", "。", ".", " ", ""]
else:
if separator:
separator = separator.replace('\\n', '\n')
separators = [separator, ""]

character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
separators=separators
fixed_separator=separator,
separators=["\n\n", "。", ".", " ", ""]
)
else:
# Automatic segmentation
Expand Down

0 comments on commit 815f794

Please sign in to comment.