diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py
index d9d42bef3e7..636017dc511 100644
--- a/chromadb/api/models/AsyncCollection.py
+++ b/chromadb/api/models/AsyncCollection.py
@@ -7,6 +7,7 @@
 
 from chromadb.api.types import (
     URI,
+    AddResult,
     CollectionMetadata,
     Embedding,
     Include,
@@ -31,7 +32,7 @@
 class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
     async def add(
         self,
-        ids: OneOrMany[ID],
+        ids: Optional[OneOrMany[ID]] = None,
         embeddings: Optional[
             Union[
                 OneOrMany[Embedding],
@@ -42,10 +43,10 @@ async def add(
         documents: Optional[OneOrMany[Document]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
-    ) -> None:
+    ) -> AddResult:
         """Add embeddings to the data store.
         Args:
-            ids: The ids of the embeddings you wish to add
+            ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
             embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
             metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
             documents: The documents to associate with the embeddings. Optional.
@@ -74,6 +75,7 @@ async def add(
         )
 
         await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
+        return {"ids": ids}
 
     async def count(self) -> int:
         """The total number of embeddings added to the database
diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py
index eb8b601e5cc..fdd7bb3cf0e 100644
--- a/chromadb/api/models/Collection.py
+++ b/chromadb/api/models/Collection.py
@@ -4,6 +4,7 @@
 from chromadb.api.models.CollectionCommon import CollectionCommon
 from chromadb.api.types import (
     URI,
+    AddResult,
     CollectionMetadata,
     Embedding,
     Include,
@@ -39,7 +40,7 @@ def count(self) -> int:
 
     def add(
         self,
-        ids: OneOrMany[ID],
+        ids: Optional[OneOrMany[ID]] = None,
         embeddings: Optional[  # type: ignore[type-arg]
             Union[
                 OneOrMany[Embedding],
@@ -50,10 +51,10 @@ def add(
         documents: Optional[OneOrMany[Document]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
-    ) -> None:
+    ) -> AddResult:
         """Add embeddings to the data store.
         Args:
-            ids: The ids of the embeddings you wish to add
+            ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
             embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
             metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
             documents: The documents to associate with the embeddings. Optional.
@@ -82,6 +83,7 @@ def add(
         )
 
         self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
+        return {"ids": ids}
 
     def get(
         self,
diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py
index 233bcf4773b..287def5b02f 100644
--- a/chromadb/api/models/CollectionCommon.py
+++ b/chromadb/api/models/CollectionCommon.py
@@ -10,7 +10,7 @@
     cast,
 )
 import numpy as np
-from uuid import UUID
+from uuid import UUID, uuid4
 
 import chromadb.utils.embedding_functions as ef
 from chromadb.api.types import (
@@ -243,7 +243,7 @@ def _validate_embedding_set(
 
     def _validate_and_prepare_embedding_set(
         self,
-        ids: OneOrMany[ID],
+        ids: Optional[OneOrMany[ID]],
         embeddings: Optional[  # type: ignore[type-arg]
             Union[
                 OneOrMany[Embedding],
@@ -261,6 +261,27 @@ def _validate_and_prepare_embedding_set(
         Optional[Documents],
         Optional[URIs],
     ]:
+        if ids is None:
+            # Generate one id per record; the first non-empty input decides the
+            # record count. "is not None and len(...) > 0" replaces bare
+            # truthiness: inputs may be numpy arrays (assumed from the np
+            # import - confirm), whose bool() raises when multi-element.
+            set_size: Optional[int] = None
+            for candidate in (embeddings, documents, images, uris):
+                if candidate is not None and len(candidate) > 0:
+                    # A bare str is one document/uri, not a sequence of records,
+                    # so it counts as 1 instead of len(str).
+                    # NOTE(review): an un-wrapped single embedding or image is
+                    # still sized by its outer length - confirm whether callers
+                    # may pass those without a list.
+                    set_size = 1 if isinstance(candidate, str) else len(candidate)
+                    break
+            if set_size is None:
+                raise ValueError(
+                    "You must provide either ids, embeddings, documents, images, or uris."
+                )
+
+            ids = [str(uuid4()) for _ in range(set_size)]
+
         (
             ids,
             embeddings,
@@ -269,7 +290,7 @@ def _validate_and_prepare_embedding_set(
             images,
             uris,
         ) = self._validate_embedding_set(
-            ids, embeddings, metadatas, documents, images, uris
+            cast(OneOrMany[ID], ids), embeddings, metadatas, documents, images, uris
         )
 
         # We need to compute the embeddings if they're not provided
diff --git a/chromadb/api/types.py b/chromadb/api/types.py
index f0ffc1e6ca0..272ae8ba4fa 100644
--- a/chromadb/api/types.py
+++ b/chromadb/api/types.py
@@ -165,6 +165,10 @@ class IncludeEnum(str, Enum):
 L = TypeVar("L", covariant=True, bound=Loadable)
 
 
+class AddResult(TypedDict):
+    ids: List[ID]
+
+
 class GetResult(TypedDict):
     ids: List[ID]
     embeddings: Optional[List[Embedding]]
diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py
index b97440d9a84..bdb6d413d05 100644
--- a/chromadb/test/property/test_embeddings.py
+++ b/chromadb/test/property/test_embeddings.py
@@ -134,12 +134,14 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
                 if normalized_record_set["embeddings"]
                 else None,
             }
-            self.collection.add(**normalized_record_set)  # type: ignore[arg-type]
+            result = self.collection.add(**normalized_record_set)  # type: ignore[arg-type]
+            assert result["ids"] == normalized_record_set["ids"]
             self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set))
             return multiple(*filtered_record_set["ids"])
 
         else:
-            self.collection.add(**normalized_record_set)  # type: ignore[arg-type]
+            result = self.collection.add(**normalized_record_set)  # type: ignore[arg-type]
+            assert result["ids"] == normalized_record_set["ids"]
             self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set))
             return multiple(*normalized_record_set["ids"])
 
diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py
index 7b9bc763fff..5da1d8f8239 100644
--- a/chromadb/test/test_api.py
+++ b/chromadb/test/test_api.py
@@ -225,6 +225,22 @@ def test_add(client):
     assert collection.count() == 2
 
 
+def test_add_embeddings_without_ids(client):
+    client.reset()
+    collection = client.create_collection("testspace")
+    result = collection.add(embeddings=[[0, 0], [1, 1]])
+    assert len(result["ids"]) == 2
+    assert collection.count() == 2
+
+
+def test_add_documents_without_ids(client):
+    client.reset()
+    collection = client.create_collection("testspace")
+    result = collection.add(documents=["hello", "world"])
+    assert len(result["ids"]) == 2
+    assert collection.count() == 2
+
+
 def test_collection_add_with_invalid_collection_throws(client):
     client.reset()
     collection = client.create_collection("test")