From 5619d9aa8e9e5a86de637cdf552c3d0a613685eb Mon Sep 17 00:00:00 2001
From: Jerry Liu
Date: Mon, 5 Feb 2024 15:26:36 -0800
Subject: [PATCH] add local RAG CLI variant (#916)

---
 llama_hub/llama_packs/library.json            |   5 +
 llama_hub/llama_packs/rag_cli_local/README.md |  60 +++++++++
 .../llama_packs/rag_cli_local/__init__.py     |   0
 llama_hub/llama_packs/rag_cli_local/base.py   | 117 ++++++++++++++++++
 .../rag_cli_local/requirements.txt            |   0
 poetry.lock                                   |  22 +++-
 pyproject.toml                                |   2 +-
 7 files changed, 200 insertions(+), 6 deletions(-)
 create mode 100644 llama_hub/llama_packs/rag_cli_local/README.md
 create mode 100644 llama_hub/llama_packs/rag_cli_local/__init__.py
 create mode 100644 llama_hub/llama_packs/rag_cli_local/base.py
 create mode 100644 llama_hub/llama_packs/rag_cli_local/requirements.txt

diff --git a/llama_hub/llama_packs/library.json b/llama_hub/llama_packs/library.json
index cd8ba8bcaa..186dddb6c4 100644
--- a/llama_hub/llama_packs/library.json
+++ b/llama_hub/llama_packs/library.json
@@ -287,5 +287,10 @@
     "id": "llama_packs/research/infer_retrieve_rerank",
     "author": "jerryjliu",
     "keywords": ["infer", "retrieve", "rerank", "retriever", "rag"]
+  },
+  "LocalRAGCLIPack": {
+    "id": "llama_packs/rag_cli_local",
+    "author": "jerryjliu",
+    "keywords": ["rag", "cli", "local"]
   }
 }
diff --git a/llama_hub/llama_packs/rag_cli_local/README.md b/llama_hub/llama_packs/rag_cli_local/README.md
new file mode 100644
index 0000000000..5b9f05588c
--- /dev/null
+++ b/llama_hub/llama_packs/rag_cli_local/README.md
@@ -0,0 +1,60 @@
+# RAG Local CLI Pack
+
+This LlamaPack implements a fully local version of our [RAG CLI](https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/rag_cli.html),
+with Mistral (through Ollama) and [BGE-M3](https://huggingface.co/BAAI/bge-m3).
+
+## CLI Usage
+
+You can download LlamaPacks directly using `llamaindex-cli`, which is installed with the `llama-index` Python package:
+
+```bash
+llamaindex-cli download-llamapack LocalRAGCLIPack --download-dir ./local_rag_cli_pack
+```
+
+You can then inspect the files at `./local_rag_cli_pack` and use them as a template for your own project!
+
+## Code Usage
+
+You can download the pack to a directory. **NOTE**: You must specify `skip_load=True` - the pack contains multiple files,
+which means it cannot be loaded directly.
+
+We will show you how to use the pack from these files!
+
+```python
+from llama_index.llama_pack import download_llama_pack
+
+# download and install dependencies
+download_llama_pack(
+    "LocalRAGCLIPack", "./local_rag_cli_pack", skip_load=True
+)
+```
+
+From here, you can use the pack. The most straightforward way is through the CLI: you can either run `base.py` directly or run the `setup_cli.sh` script.
+
+```bash
+cd local_rag_cli_pack
+
+# option 1
+python base.py rag -h
+
+# option 2 - you may need sudo
+# default name is lcli_local
+sudo sh setup_cli.sh
+lcli_local rag -h
+
+```
+
+You can also directly get modules from the pack.
+
+```python
+from local_rag_cli_pack.base import LocalRAGCLIPack
+
+pack = LocalRAGCLIPack(verbose=True, llm_model_name="mistral", embed_model_name="BAAI/bge-m3")
+# will spin up the CLI
+pack.run()
+
+# get modules
+rag_cli = pack.get_modules()["rag_cli"]
+rag_cli.cli()
+
+```
diff --git a/llama_hub/llama_packs/rag_cli_local/__init__.py b/llama_hub/llama_packs/rag_cli_local/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/llama_hub/llama_packs/rag_cli_local/base.py b/llama_hub/llama_packs/rag_cli_local/base.py
new file mode 100644
index 0000000000..8f6f45df84
--- /dev/null
+++ b/llama_hub/llama_packs/rag_cli_local/base.py
@@ -0,0 +1,117 @@
+"""Local RAG CLI Pack."""
+
+from llama_index.ingestion import IngestionPipeline, IngestionCache
+from llama_index.query_pipeline.query import QueryPipeline
+from llama_index.storage.docstore import SimpleDocumentStore
+from llama_index.command_line.rag import RagCLI
+from llama_index.text_splitter import SentenceSplitter
+from llama_index.embeddings import HuggingFaceEmbedding
+from llama_index.llms import Ollama
+from llama_index.vector_stores import ChromaVectorStore
+from llama_index.utils import get_cache_dir
+from llama_index import ServiceContext, VectorStoreIndex
+from llama_index.response_synthesizers import CompactAndRefine
+from llama_index.query_pipeline import InputComponent
+from llama_index.llama_pack.base import BaseLlamaPack
+from typing import Optional, Dict, Any
+from pathlib import Path
+import chromadb
+
+
+def default_ragcli_persist_dir() -> str:
+    """Get default RAG CLI persist dir."""
+    return str(Path(get_cache_dir()) / "rag_cli_local")
+
+
+def init_local_rag_cli(
+    persist_dir: Optional[str] = None,
+    verbose: bool = False,
+    llm_model_name: str = "mistral",
+    embed_model_name: str = "BAAI/bge-m3",
+) -> RagCLI:
+    """Init local RAG CLI."""
+
+    docstore = SimpleDocumentStore()
+    persist_dir = persist_dir or default_ragcli_persist_dir()
+    chroma_client = chromadb.PersistentClient(path=persist_dir)
+    chroma_collection = chroma_client.create_collection("default", get_or_create=True)
+    vector_store = ChromaVectorStore(
+        chroma_collection=chroma_collection, persist_dir=persist_dir
+    )
+    print("> Chroma collection initialized")
+    llm = Ollama(model=llm_model_name, request_timeout=30.0)
+    print("> LLM initialized")
+    embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
+    print("> Embedding model initialized")
+
+    ingestion_pipeline = IngestionPipeline(
+        transformations=[SentenceSplitter(), embed_model],
+        vector_store=vector_store,
+        docstore=docstore,
+        cache=IngestionCache(),
+    )
+
+    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
+    retriever = VectorStoreIndex.from_vector_store(
+        ingestion_pipeline.vector_store, service_context=service_context
+    ).as_retriever(similarity_top_k=8)
+    response_synthesizer = CompactAndRefine(
+        service_context=service_context, streaming=True, verbose=True
+    )
+    # define query pipeline
+    query_pipeline = QueryPipeline(verbose=verbose)
+    query_pipeline.add_modules(
+        {
+            "input": InputComponent(),
+            "retriever": retriever,
+            "summarizer": response_synthesizer,
+        }
+    )
+    query_pipeline.add_link("input", "retriever")
+    query_pipeline.add_link("retriever", "summarizer", dest_key="nodes")
+    query_pipeline.add_link("input", "summarizer", dest_key="query_str")
+
+    rag_cli_instance = RagCLI(
+        ingestion_pipeline=ingestion_pipeline,
+        llm=llm,  # optional
+        persist_dir=persist_dir,
+        query_pipeline=query_pipeline,
+        verbose=False,
+    )
+    return rag_cli_instance
+
+
+class LocalRAGCLIPack(BaseLlamaPack):
+    """Local RAG CLI Pack."""
+
+    def __init__(
+        self,
+        verbose: bool = False,
+        persist_dir: Optional[str] = None,
+        llm_model_name: str = "mistral",
+        embed_model_name: str = "BAAI/bge-m3",
+    ) -> None:
+        """Init params."""
+        self.verbose = verbose
+        self.persist_dir = persist_dir or default_ragcli_persist_dir()
+        self.llm_model_name = llm_model_name
+        self.embed_model_name = embed_model_name
+        self.rag_cli = init_local_rag_cli(
+            persist_dir=self.persist_dir,
+            verbose=self.verbose,
+            llm_model_name=self.llm_model_name,
+            embed_model_name=self.embed_model_name,
+        )
+
+    def get_modules(self) -> Dict[str, Any]:
+        """Get modules."""
+        return {"rag_cli": self.rag_cli}
+
+    def run(self, *args: Any, **kwargs: Any) -> Any:
+        """Run the pipeline."""
+        return self.rag_cli.cli(*args, **kwargs)
+
+
+if __name__ == "__main__":
+    rag_cli_instance = init_local_rag_cli()
+    rag_cli_instance.cli()
diff --git a/llama_hub/llama_packs/rag_cli_local/requirements.txt b/llama_hub/llama_packs/rag_cli_local/requirements.txt
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/poetry.lock b/poetry.lock
index 395d0911d2..d39d68177a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -378,6 +378,17 @@ wrapt = ">=1.10,<2"
 [package.extras]
 dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
 
+[[package]]
+name = "dirtyjson"
+version = "1.0.8"
+description = "JSON decoder for Python that can extract data from the muck"
+optional = false
+python-versions = "*"
+files = [
+    {file = "dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53"},
+    {file = "dirtyjson-1.0.8.tar.gz", hash = "sha256:90ca4a18f3ff30ce849d100dcf4a003953c79d3a2348ef056f1d9c22231a25fd"},
+]
+
 [[package]]
 name = "distro"
 version = "1.9.0"
@@ -714,19 +725,20 @@ files = [
 
 [[package]]
 name = "llama-index"
-version = "0.9.39"
+version = "0.9.44"
 description = "Interface between LLMs and your data"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "llama_index-0.9.39-py3-none-any.whl", hash = "sha256:73e19bf664b0643e3c1b88229d4bcaad841f4c6e882a63b27f637386c54d5353"},
-    {file = "llama_index-0.9.39.tar.gz", hash = "sha256:c0d4093cd1c6d6056275f96d6acba56f383ef98925c9ce3fc8cde9fb4dee1f75"},
+    {file = "llama_index-0.9.44-py3-none-any.whl", hash = "sha256:678ee2cdbc95b718f474ce475dcae1d7e4fcb7acd71dccb8581f7a396d46c02c"},
+    {file = "llama_index-0.9.44.tar.gz", hash = "sha256:9c5261e7016f5c5898970c3f8932e01daeaa80f62e9efa1405f0d7364f3aea33"},
 ]
 
 [package.dependencies]
 aiohttp = ">=3.8.6,<4.0.0"
 dataclasses-json = "*"
 deprecated = ">=1.2.9.3"
+dirtyjson = ">=1.0.8,<2.0.0"
 fsspec = ">=2023.5.0"
 httpx = "*"
 nest-asyncio = ">=1.5.8,<2.0.0"
@@ -1095,8 +1107,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.20.3", markers = "python_version < \"3.10\""},
-    {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
     {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+    {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -2145,4 +2157,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<3.12"
-content-hash = "32db2f1fe0f22f0387ca768712f877fdd723f6d2d12ad4a5d01efad154ff5646"
+content-hash = "cdee3348a99cb5fc26020c1f818f1950c8027301d6a65c37497dd1ebb0d0baf6"
diff --git a/pyproject.toml b/pyproject.toml
index 6f2e657d42..f4852224b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ include = [
 [tool.poetry.dependencies]
 # Updated Python version
 python = ">=3.8.1,<3.12"
-llama-index = ">=0.9.39"
+llama-index = ">=0.9.41"
 html2text = "*"
 psutil = "*"
 retrying = "*"
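
For reviewers trying the pack out locally: the constructor arguments added in `base.py` above make it easy to swap in a different local model or storage location without editing the code. Below is a minimal usage sketch (not part of this patch); the model names and persist path are illustrative and assume the corresponding Ollama and HuggingFace models are available on your machine.

```python
# Illustrative sketch only - assumes `ollama pull llama2` has been run and that
# the chosen HuggingFace embedding model can be downloaded; paths are examples.
from local_rag_cli_pack.base import LocalRAGCLIPack

pack = LocalRAGCLIPack(
    verbose=True,
    persist_dir="./rag_cli_storage",            # example scratch directory
    llm_model_name="llama2",                    # any model served by Ollama
    embed_model_name="BAAI/bge-small-en-v1.5",  # any HuggingFace embedding model
)
pack.run()  # spins up the CLI
```

If `persist_dir` is omitted, the pack falls back to `default_ragcli_persist_dir()` (a `rag_cli_local` folder under the llama-index cache directory), so repeated runs reuse the same Chroma collection and docstore.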