Skip to content

Commit

Permalink
Add support to custom text spliter (microsoft#270)
Browse files Browse the repository at this point in the history
* Add support to custom text spliter function and a list of files or urls

* Add parameter to retrieve_config, add tests

* Fix tests

* Fix tests
  • Loading branch information
thinkall authored Oct 17, 2023
1 parent 294e006 commit f2d7553
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
4 changes: 4 additions & 0 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def __init__(
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
- custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
Example of overriding retrieve_docs:
Expand Down Expand Up @@ -175,6 +177,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
Expand Down Expand Up @@ -364,6 +367,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
embedding_model=self._embedding_model,
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
custom_text_split_function=self.custom_text_split_function,
)
self._collection = True
self._get_or_create = False
Expand Down
36 changes: 31 additions & 5 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def extract_text_from_pdf(file: str) -> str:


def split_files_to_chunks(
files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True
files: list,
max_tokens: int = 4000,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
):
"""Split a list of files into chunks of max_tokens."""

Expand All @@ -200,18 +204,33 @@ def split_files_to_chunks(
logger.warning(f"No text available in file: {file}")
continue # Skip to the next file if no text is available

chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
if custom_text_split_function is not None:
chunks += custom_text_split_function(text)
else:
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)

return chunks


def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: bool = True):
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
"""Return a list of all the files in a given directory."""
if len(types) == 0:
raise ValueError("types cannot be empty.")
types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
types += [t.upper() for t in types]

files = []
# If the path is a list of files or urls, process and return them
if isinstance(dir_path, list):
for item in dir_path:
if os.path.isfile(item):
files.append(item)
elif is_url(item):
files.append(get_file_from_url(item))
else:
logger.warning(f"File {item} does not exist. Skipping.")
return files

# If the path is a file, return it
if os.path.isfile(dir_path):
return [dir_path]
Expand All @@ -220,7 +239,6 @@ def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: boo
if is_url(dir_path):
return [get_file_from_url(dir_path)]

files = []
if os.path.exists(dir_path):
for type in types:
if recursive:
Expand Down Expand Up @@ -265,6 +283,7 @@ def create_vector_db_from_dir(
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
custom_text_split_function: Callable = None,
):
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
Expand Down Expand Up @@ -304,7 +323,14 @@ def create_vector_db_from_dir(
metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine
)

chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
if custom_text_split_function is not None:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
)
else:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
for i in range(0, len(chunks), min(40000, len(chunks))):
Expand Down
22 changes: 22 additions & 0 deletions test/test_retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def test_split_files_to_chunks(self):
def test_get_files_from_dir(self):
files = get_files_from_dir(test_dir)
assert all(os.path.isfile(file) for file in files)
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
files = get_files_from_dir([pdf_file_path, txt_file_path])
assert all(os.path.isfile(file) for file in files)

def test_is_url(self):
assert is_url("https://www.example.com")
Expand Down Expand Up @@ -164,6 +168,24 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
assert ragragproxyagent._results["ids"] == [3, 1, 5]

def test_custom_text_split_function(self):
def custom_text_split_function(text):
return [text[: len(text) // 2], text[len(text) // 2 :]]

db_path = "/tmp/test_retrieve_utils_chromadb.db"
client = chromadb.PersistentClient(path=db_path)
create_vector_db_from_dir(
os.path.join(test_dir, "example.txt"),
client=client,
collection_name="mytestcollection",
custom_text_split_function=custom_text_split_function,
)
results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
assert (
results.get("documents")[0][0]
== "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
)


if __name__ == "__main__":
pytest.main()
Expand Down

0 comments on commit f2d7553

Please sign in to comment.