@@ -7,6 +7,9 @@
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.chroma import Chroma

from langchain_benchmarks.rag.utils._downloading import (
fetch_remote_file,
)
from langchain_benchmarks.rag.utils.indexing import (
get_hyde_retriever,
get_parent_document_retriever,
@@ -30,6 +33,10 @@ def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]
"Please install pandas to use the langchain docs benchmarking task.\n"
"pip install pandas"
)
if filename is None:
filename = DOCS_FILE
if not os.path.exists(filename):
fetch_remote_file(REMOTE_DOCS_FILE, filename)
df = pd.read_parquet(filename)
docs_transformed = [Document(**row) for row in df.to_dict(orient="records")]
for doc in docs_transformed:
@@ -43,11 +50,13 @@ def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]

def _chroma_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transform_docs: Optional[Callable] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-classic-{embedding_name}",
@@ -67,9 +76,11 @@ def _chroma_retriever_factory(

def _chroma_parent_document_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-parent-doc-{embedding_name}",
@@ -87,9 +98,11 @@ def _chroma_parent_document_retriever_factory(

def _chroma_hyde_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-hyde-{embedding_name}",
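Taken together, the retriever-registry changes make the docs corpus lazy: the `docs` argument becomes keyword-only, each factory loads the parquet file only when the caller does not pass docs in, and `load_docs_from_parquet` falls back to the cached `DOCS_FILE`, fetching it from `REMOTE_DOCS_FILE` on first use. Below is a minimal sketch of that fallback pattern; the path and URL values are placeholders and `fetch_remote_file` is a simplified stand-in for the real helper.

import os
from typing import Iterable, Optional
from urllib.request import urlretrieve

from langchain.schema.document import Document

# Placeholder values -- the real constants are defined in the retriever registry module.
DOCS_FILE = "/tmp/langchain_docs.parquet"
REMOTE_DOCS_FILE = "https://example.com/langchain_docs.parquet"


def fetch_remote_file(remote_url: str, local_path: str) -> None:
    """Simplified stand-in for langchain_benchmarks.rag.utils._downloading.fetch_remote_file."""
    urlretrieve(remote_url, local_path)


def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]:
    import pandas as pd

    # Fall back to the cached file, downloading it the first time it is needed.
    if filename is None:
        filename = DOCS_FILE
    if not os.path.exists(filename):
        fetch_remote_file(REMOTE_DOCS_FILE, filename)
    df = pd.read_parquet(filename)
    for row in df.to_dict(orient="records"):
        yield Document(**row)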
12 changes: 10 additions & 2 deletions langchain_benchmarks/rag/tasks/langchain_docs/task.py
@@ -1,4 +1,6 @@
from functools import partial
from typing import Iterable

from langchain.schema.document import Document

from langchain_benchmarks.rag.tasks.langchain_docs import architectures, indexing
from langchain_benchmarks.rag.tasks.langchain_docs.indexing.retriever_registry import (
@@ -11,12 +13,18 @@
"452ccafc-18e1-4314-885b-edd735f17b9d" # ID of public LangChain Docs dataset
)


def load_cached_docs() -> Iterable[Document]:
"""Load the docs from the cached file."""
return load_docs_from_parquet(DOCS_FILE)


LANGCHAIN_DOCS_TASK = RetrievalTask(
name="LangChain Docs Q&A",
dataset_id=DATASET_ID,
retriever_factories=indexing.RETRIEVER_FACTORIES,
architecture_factories=architectures.ARCH_FACTORIES,
get_docs=partial(load_docs_from_parquet, DOCS_FILE),
get_docs=load_cached_docs,
description=(
"""\
Questions and answers based on a snapshot of the LangChain python docs.
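In task.py, binding `get_docs` to the new zero-argument `load_cached_docs` (rather than `partial(load_docs_from_parquet, DOCS_FILE)`) means callers never need to know the parquet path; the download-if-missing behavior sketched above kicks in automatically. A hedged usage sketch, using the import path of the file shown in this diff:

# Import path taken from the file in this diff.
from langchain_benchmarks.rag.tasks.langchain_docs.task import LANGCHAIN_DOCS_TASK

# get_docs is a zero-argument callable; it resolves (and, if needed, downloads)
# the cached parquet file internally before yielding Documents.
docs = list(LANGCHAIN_DOCS_TASK.get_docs())
print(len(docs), docs[0].metadata)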
@@ -126,6 +126,7 @@ def _chroma_parent_document_retriever_factory(
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = docs or load_docs()
embedding_name = embedding.__class__.__name__
@@ -140,6 +141,7 @@ def _chroma_parent_document_retriever_factory(
vectorstore,
collection_name="semi-structured-earnings",
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
transformation_name=transformation_name,
)


@@ -148,6 +150,7 @@ def _chroma_hyde_retriever_factory(
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = docs or load_docs()
embedding_name = embedding.__class__.__name__
@@ -162,6 +165,7 @@ def _chroma_hyde_retriever_factory(
vectorstore,
collection_name="semi-structured-earnings",
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
transformation_name=transformation_name,
)


27 changes: 27 additions & 0 deletions langchain_benchmarks/schema.py
@@ -1,5 +1,8 @@
"""Schema for the Langchain Benchmarks."""
from __future__ import annotations

import dataclasses
import inspect
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from langchain.prompts import ChatPromptTemplate
@@ -104,6 +107,7 @@ def _table(self) -> List[List[str]]:
return table + [
["Retriever Factories", ", ".join(self.retriever_factories.keys())],
["Architecture Factories", ", ".join(self.architecture_factories.keys())],
["get_docs", self.get_docs],
]


@@ -149,6 +153,29 @@ def _repr_html_(self) -> str:
]
return tabulate(table, headers=headers, tablefmt="html")

def filter(
self,
Type: Optional[str],
dataset_id: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
) -> Registry:
"""Filter the tasks in the registry."""
tasks = self.tasks
if Type is not None:
tasks = [task for task in tasks if task.__class__.__name__ == Type]
if dataset_id is not None:
tasks = [task for task in tasks if task.dataset_id == dataset_id]
if name is not None:
tasks = [task for task in tasks if task.name == name]
if description is not None:
tasks = [
task
for task in tasks
if description.lower() in task.description.lower()
]
return Registry(tasks=tasks)

def __getitem__(self, key: Union[int, str]) -> BaseTask:
"""Get an environment from the registry."""
if isinstance(key, slice):
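The new `Registry.filter` returns another `Registry`, so the result can be indexed, tabulated, or filtered again. A usage sketch follows; the top-level `registry` import is an assumption about how the package exposes its task registry, and `Type` matches on the task's class name (e.g. "RetrievalTask").

from langchain_benchmarks import registry  # assumed public export of the Registry instance

# Keep retrieval tasks whose description mentions the LangChain python docs.
docs_tasks = registry.filter(Type="RetrievalTask", description="langchain python docs")
for task in docs_tasks.tasks:
    print(task.name, task.dataset_id)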