Skip to content

Commit

Permalink
[ENH] generate IDs during .add() if not provided
Browse files (browse the repository at this point in the history)
Closes #2286.
  • Loading branch information
codetheweb committed Jul 26, 2024
1 parent ea0ac35 commit 785defe
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 11 deletions.
8 changes: 5 additions & 3 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand All @@ -31,7 +32,7 @@
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -42,10 +43,10 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
Expand Down Expand Up @@ -74,6 +75,7 @@ async def add(
)

await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
return {"ids": ids}

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand Down
8 changes: 5 additions & 3 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand Down Expand Up @@ -39,7 +40,7 @@ def count(self) -> int:

def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -50,10 +51,10 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
Expand Down Expand Up @@ -82,6 +83,7 @@ def add(
)

self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
return {"ids": ids}

def get(
self,
Expand Down
22 changes: 19 additions & 3 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cast,
)
import numpy as np
from uuid import UUID
from uuid import UUID, uuid4

import chromadb.utils.embedding_functions as ef
from chromadb.api.types import (
Expand Down Expand Up @@ -243,7 +243,7 @@ def _validate_embedding_set(

def _validate_and_prepare_embedding_set(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -261,6 +261,22 @@ def _validate_and_prepare_embedding_set(
Optional[Documents],
Optional[URIs],
]:
if ids is None:
if embeddings:
set_size = len(embeddings)
elif documents:
set_size = len(documents)
elif images:
set_size = len(images)
elif uris:
set_size = len(uris)
else:
raise ValueError(
"You must provide either ids, embeddings, documents, images, or uris."
)

ids = [str(uuid4()) for _ in range(set_size)]

(
ids,
embeddings,
Expand All @@ -269,7 +285,7 @@ def _validate_and_prepare_embedding_set(
images,
uris,
) = self._validate_embedding_set(
ids, embeddings, metadatas, documents, images, uris
cast(OneOrMany[ID], ids), embeddings, metadatas, documents, images, uris
)

# We need to compute the embeddings if they're not provided
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ class IncludeEnum(str, Enum):
L = TypeVar("L", covariant=True, bound=Loadable)


class AddResult(TypedDict):
    """Return value of a collection's ``add()`` call."""

    # Ids of the records that were added: the caller's ids when supplied,
    # otherwise the ids generated on the caller's behalf.
    ids: List[ID]


class GetResult(TypedDict):
ids: List[ID]
embeddings: Optional[List[Embedding]]
Expand Down
6 changes: 4 additions & 2 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,14 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
if normalized_record_set["embeddings"]
else None,
}
self.collection.add(**normalized_record_set) # type: ignore[arg-type]
result = self.collection.add(**normalized_record_set) # type: ignore[arg-type]
assert result["ids"] == normalized_record_set["ids"]
self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set))
return multiple(*filtered_record_set["ids"])

else:
self.collection.add(**normalized_record_set) # type: ignore[arg-type]
result = self.collection.add(**normalized_record_set) # type: ignore[arg-type]
assert result["ids"] == normalized_record_set["ids"]
self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set))
return multiple(*normalized_record_set["ids"])

Expand Down
16 changes: 16 additions & 0 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,22 @@ def test_add(client):
assert collection.count() == 2


def test_add_embeddings_without_ids(client):
    """Adding embeddings with no ids yields one generated id per embedding."""
    client.reset()
    testspace = client.create_collection("testspace")

    vectors = [[0, 0], [1, 1]]
    added = testspace.add(embeddings=vectors)

    generated_ids = added["ids"]
    assert len(generated_ids) == 2
    assert testspace.count() == 2


def test_add_documents_without_ids(client):
    """Adding documents with no ids yields one generated id per document."""
    client.reset()
    testspace = client.create_collection("testspace")

    texts = ["hello", "world"]
    added = testspace.add(documents=texts)

    generated_ids = added["ids"]
    assert len(generated_ids) == 2
    assert testspace.count() == 2


def test_collection_add_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
Expand Down

0 comments on commit 785defe

Please sign in to comment.