Skip to content

Chroma in LangChain #1010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions langchain/vectorstores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrappers on top of vector stores."""
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
from langchain.vectorstores.faiss import FAISS
from langchain.vectorstores.milvus import Milvus
Expand All @@ -15,4 +16,5 @@
"Weaviate",
"Qdrant",
"Milvus",
"Chroma",
]
176 changes: 176 additions & 0 deletions langchain/vectorstores/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Wrapper around ChromaDB embeddings platform."""
from __future__ import annotations

import logging
import uuid
from typing import Any, Dict, Iterable, List, Optional

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore

logger = logging.getLogger()


class Chroma(VectorStore):
"""Wrapper around ChromaDB embeddings platform.

To use, you should have the ``chromadb`` python package installed.

Example:
.. code-block:: python

from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
vectorstore = Chroma("langchain_store", embeddings.embed_query)
"""

def __init__(
self, collection_name: str, embedding_function: Optional[Embeddings] = None
) -> None:
"""Initialize with Chroma client."""
try:
import chromadb
except ImportError:
raise ValueError(
"Could not import chromadb python package. "
"Please it install it with `pip install chromadb`."
)

# TODO: Add support for custom client. For now this is in-memory only.
self._client = chromadb.Client()
self._embedding_function = embedding_function

# Check if the collection exists, create it if not
if collection_name in [col.name for col in self._client.list_collections()]:
self._collection = self._client.get_collection(name=collection_name)
if embedding_function is not None:
logger.warning(
f"Collection {collection_name} already exists,"
" embedding function will not be updated."
)
else:
self._collection = self._client.create_collection(
name=collection_name,
embedding_fn=self._embedding_function.embed_documents
if self._embedding_function is not None
else None,
)

def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.

Args:
texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs.

Returns:
List[str]: List of IDs of the added texts.
"""
# TODO: Handle the case where the user doesn't provide ids on the Collection
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
self._collection.add(metadatas=metadatas, documents=texts, ids=ids)
return ids

def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Run similarity search with Chroma.

Args:
query (str): Query text to search for.
k (int): Number of results to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

Returns:
List[Document]: List of documents most simmilar to the query text.
"""
if self._embedding_function is None:
results = self._collection.query(
query_texts=[query], n_results=k, where=filter
)
else:
query_embedding = self._embedding_function.embed_query(query)
results = self._collection.query(
query_embeddings=[query_embedding], n_results=k, where=filter
)

print(results)

docs = [
# TODO: Chroma can do batch querying,
# we shouldn't hard code to the 1st result
Document(page_content=result[0], metadata=result[1])
for result in zip(results["documents"][0], results["metadatas"][0])
]
return docs

@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
collection_name: str = "langchain",
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a raw documents.

Args:
collection_name (str): Name of the collection to create.
documents (List[Document]): List of documents to add.
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
ids (Optional[List[str]]): List of document IDs. Defaults to None.

Returns:
Chroma: Chroma vectorstore.
"""
chroma_collection = cls(
collection_name=collection_name, embedding_function=embedding
)
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return chroma_collection

@classmethod
def from_documents(
cls,
documents: List[Document],
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,
collection_name: str = "langchain",
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a list of documents.

Args:
collection_name (str): Name of the collection to create.
documents (List[Document]): List of documents to add to the vectorstore.
embedding (Optional[Embeddings]): Embedding function. Defaults to None.

Returns:
Chroma: Chroma vectorstore.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
collection_name=collection_name,
texts=texts,
embedding=embedding,
metadatas=metadatas,
ids=ids,
)
28 changes: 28 additions & 0 deletions tests/integration_tests/vectorstores/test_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Test Chroma functionality."""
from langchain.docstore.document import Document
from langchain.vectorstores import Chroma
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings


def test_chroma() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = Chroma.from_texts(
collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]


def test_chroma_with_metadatas() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",
texts=texts,
embedding=FakeEmbeddings(),
metadatas=metadatas,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"})]