Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] generate IDs during .add() if not provided #2582

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand All @@ -31,7 +32,7 @@
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -42,10 +43,10 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
Expand Down Expand Up @@ -74,6 +75,7 @@ async def add(
)

await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
return {"ids": ids}

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand Down
8 changes: 5 additions & 3 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand Down Expand Up @@ -39,7 +40,7 @@ def count(self) -> int:

def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -50,10 +51,10 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
ids: The ids of the embeddings you wish to add. If None, ids will be generated for you. Optional.
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
Expand Down Expand Up @@ -82,6 +83,7 @@ def add(
)

self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
return {"ids": ids}

def get(
self,
Expand Down
30 changes: 27 additions & 3 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cast,
)
import numpy as np
from uuid import UUID
from uuid import UUID, uuid4

import chromadb.utils.embedding_functions as ef
from chromadb.api.types import (
Expand Down Expand Up @@ -243,7 +243,7 @@ def _validate_embedding_set(

def _validate_and_prepare_embedding_set(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -261,6 +261,30 @@ def _validate_and_prepare_embedding_set(
Optional[Documents],
Optional[URIs],
]:
def count_one_or_many(value: OneOrMany[Any]) -> int:
if isinstance(value, list):
return len(value)
return 1

if ids is None:
if embeddings:
if isinstance(embeddings[0], list) and len(embeddings[0]) > 1:
set_size = len(embeddings)
else:
set_size = 1
elif documents:
set_size = count_one_or_many(documents)
elif images:
set_size = count_one_or_many(images)
elif uris:
set_size = count_one_or_many(uris)
else:
raise ValueError(
"You must provide either ids, embeddings, documents, images, or uris."
)

ids = [str(uuid4()) for _ in range(set_size)]

(
ids,
embeddings,
Expand All @@ -269,7 +293,7 @@ def _validate_and_prepare_embedding_set(
images,
uris,
) = self._validate_embedding_set(
ids, embeddings, metadatas, documents, images, uris
cast(OneOrMany[ID], ids), embeddings, metadatas, documents, images, uris
)

# We need to compute the embeddings if they're not provided
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ class IncludeEnum(str, Enum):
L = TypeVar("L", covariant=True, bound=Loadable)


class AddResult(TypedDict):
ids: List[ID]


class GetResult(TypedDict):
ids: List[ID]
embeddings: Optional[List[Embedding]]
Expand Down
6 changes: 4 additions & 2 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,14 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
if normalized_record_set["embeddings"]
else None,
}
self.collection.add(**normalized_record_set) # type: ignore[arg-type]
result = self.collection.add(**normalized_record_set) # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't currently generate RecordSet without IDs. Do we want to do that in property tests? This might get more complicated if we want to allow some IDs to be None.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in the tests when we first do an add with IDs, then one without? The comparison here would fail I think.

assert result["ids"] == normalized_record_set["ids"]
self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set))
return multiple(*filtered_record_set["ids"])

else:
self.collection.add(**normalized_record_set) # type: ignore[arg-type]
result = self.collection.add(**normalized_record_set) # type: ignore[arg-type]
assert result["ids"] == normalized_record_set["ids"]
self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set))
return multiple(*normalized_record_set["ids"])

Expand Down
16 changes: 16 additions & 0 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# type: ignore
import traceback
import httpx
from hypothesis import given
import hypothesis.strategies as st

import chromadb
from chromadb.errors import ChromaError
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import QueryResult, EmbeddingFunction, Document
from chromadb.config import Settings
from chromadb.errors import InvalidCollectionException
import chromadb.test.property.strategies as strategies
import chromadb.server.fastapi
import pytest
import tempfile
Expand Down Expand Up @@ -225,6 +228,19 @@ def test_add(client):
assert collection.count() == 2


collection_st = st.shared(strategies.collections(), key="coll")


@given(
record_set=strategies.recordsets(collection_st, min_size=1, max_size=5),
)
def test_add_without_ids(client, record_set):
client.reset()
collection = client.create_collection("testspace")
result = collection.add(**{k: v for k, v in record_set.items() if k != "ids"})
assert collection.count() == len(result["ids"])


def test_collection_add_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
Expand Down
Loading