Skip to content

Commit

Permalink
Add retriever types' configuration
Browse files Browse the repository at this point in the history
BM25 retriever can now be selected on its own, without the need for
a vector retriever. This could be extended to use any type of
llama-index retriever in the future. Also, a pre-commit hook is added to
remove unused imports.
  • Loading branch information
Dedalo314 committed Jan 23, 2024
1 parent 6a5945f commit 03c062c
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 7 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ repos:
- id: check-added-large-files
exclude: poetry.lock

# remove unused imports
- repo: https://github.com/hadialqattan/pycln
rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]

# python code formatting
- repo: https://github.com/psf/black
rev: 23.1.0
Expand Down
1 change: 1 addition & 0 deletions openbb_chat/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .auto_llama_index import AutoLlamaIndex
20 changes: 16 additions & 4 deletions openbb_chat/kernels/auto_llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.vector_stores import SimpleVectorStore
from pydantic import BaseModel

from openbb_chat.retrievers.hybrid_or_retriever import HybridORRetriever


class AutoLlamaIndex:
class AutoLlamaIndex(BaseModel):
"""Wrapper around `llama-index` that fixes its possibilities to the ones needed for `openbb-
chat`.
Expand Down Expand Up @@ -53,6 +54,8 @@ class AutoLlamaIndex:
String representation of the LlamaIndex's QA template to use.
refine_template_str (`Optional[str]`):
String representation of the LlamaIndex's refine template to use.
retriever_type (`str`):
One of 'hybrid', 'vector' or 'bm25'. Default 'hybrid'.
other_llama_index_llm_kwargs (`dict`):
Overrides the default values in LlamaIndex's `LLM`
other_llama_index_simple_directory_reader_kwargs (`dict`):
Expand Down Expand Up @@ -83,7 +86,7 @@ def __init__(
model_kwargs: Optional[dict] = None,
qa_template_str: Optional[str] = None,
refine_template_str: Optional[str] = None,
use_hybrid_retriever: bool = True,
retriever_type: str = "hybrid",
other_llama_index_llm_kwargs: dict = {},
other_llama_index_simple_directory_reader_kwargs: dict = {},
other_llama_index_service_context_kwargs: dict = {},
Expand All @@ -95,6 +98,7 @@ def __init__(
other_llama_index_retriever_query_engine_kwargs: dict = {},
):
"""Init method."""
super().__init__()

# create LLM from configuration
self._llm = self._create_llama_index_llm(
Expand All @@ -119,18 +123,26 @@ def __init__(
)

# configure retriever
if use_hybrid_retriever:
if retriever_type == "hybrid":
vector_retriever = VectorIndexRetriever(
index=self._index, **other_llama_index_vector_index_retriever_kwargs
)
bm25_retriever = BM25Retriever.from_defaults(
index=self._index, **other_llama_index_bm25_retriever_kwargs
)
self._retriever = HybridORRetriever(vector_retriever, bm25_retriever)
else:
elif retriever_type == "vector":
self._retriever = VectorIndexRetriever(
index=self._index, **other_llama_index_vector_index_retriever_kwargs
)
elif retriever_type == "bm25":
self._retriever = BM25Retriever.from_defaults(
index=self._index, **other_llama_index_bm25_retriever_kwargs
)
else:
raise ValueError(
f"`retriever_type` must be 'hybrid', 'vector' or 'bm25'. Current value: {retriever_type}"
)

self._qa_template_str = (
PromptTemplate(qa_template_str) if qa_template_str is not None else None
Expand Down
2 changes: 2 additions & 0 deletions openbb_chat/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .chat_model_llm_iface import ChatModelWithLLMIface
from .guidance_wrapper import GuidanceWrapper
52 changes: 49 additions & 3 deletions tests/kernels/test_auto_llama_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import tempfile
from unittest.mock import patch

import pytest
Expand All @@ -11,7 +9,7 @@
)
from llama_index.llms import HuggingFaceLLM
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.retrievers import BM25Retriever, VectorIndexRetriever

from openbb_chat.kernels.auto_llama_index import AutoLlamaIndex

Expand Down Expand Up @@ -125,3 +123,51 @@ def test_auto_llama_index_synth(mocked_synthesize, mocked_retrieve):
node_list = autollamaindex.retrieve(query)
response = autollamaindex.synth(query, node_list)
mocked_synthesize.assert_called_once()


@patch.object(VectorIndexRetriever, "retrieve")
@patch.object(RetrieverQueryEngine, "query")
def test_auto_llama_index_vector_retriever(mocked_query, mocked_retrieve):
    """Exercise `AutoLlamaIndex` built with `retriever_type="vector"`.

    `VectorIndexRetriever.retrieve` and `RetrieverQueryEngine.query` are patched.
    Stacked `patch` decorators inject their mocks bottom-up, so `mocked_query`
    belongs to `RetrieverQueryEngine.query` and `mocked_retrieve` to
    `VectorIndexRetriever.retrieve`. The test asserts each mock is called
    exactly once.
    """
    # load testing models
    autollamaindex = AutoLlamaIndex(
        "./docs",
        "local:sentence-transformers/all-MiniLM-L6-v2",
        "hf:sshleifer/tiny-gpt2",
        retriever_type="vector",
        context_window=100,
        other_llama_index_response_synthesizer_kwargs={"response_mode": "simple_summarize"},
    )

    query = "What is the purpose of Index.md"

    # test retrieval: the return value comes from a mock, only the call matters
    autollamaindex.retrieve(query)
    mocked_retrieve.assert_called_once()

    # test query: same here, the mocked response is not inspected
    autollamaindex.query(query)
    mocked_query.assert_called_once()


@patch.object(BM25Retriever, "retrieve")
@patch.object(RetrieverQueryEngine, "query")
def test_auto_llama_index_bm25_retriever(mocked_query, mocked_retrieve):
    """Exercise `AutoLlamaIndex` built with `retriever_type="bm25"`.

    `BM25Retriever.retrieve` and `RetrieverQueryEngine.query` are patched.
    Stacked `patch` decorators inject their mocks bottom-up, so `mocked_query`
    belongs to `RetrieverQueryEngine.query` and `mocked_retrieve` to
    `BM25Retriever.retrieve`. The test asserts each mock is called exactly once.
    """
    # load testing models
    autollamaindex = AutoLlamaIndex(
        "./docs",
        "local:sentence-transformers/all-MiniLM-L6-v2",
        "hf:sshleifer/tiny-gpt2",
        retriever_type="bm25",
        context_window=100,
        other_llama_index_response_synthesizer_kwargs={"response_mode": "simple_summarize"},
    )

    query = "What is the purpose of Index.md"

    # test retrieval: the return value comes from a mock, only the call matters
    autollamaindex.retrieve(query)
    mocked_retrieve.assert_called_once()

    # test query: same here, the mocked response is not inspected
    autollamaindex.query(query)
    mocked_query.assert_called_once()

0 comments on commit 03c062c

Please sign in to comment.