diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index f304110b652..1c747b0d44c 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -24,6 +24,7 @@ URIs, Where, QueryResult, + AddResult, GetResult, WhereDocument, ) @@ -121,7 +122,7 @@ def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) -> bool: + ) -> AddResult: """[Internal] Add embeddings to a collection specified by UUID. If (some) ids already exist, only the new embeddings will be added. diff --git a/chromadb/api/async_api.py b/chromadb/api/async_api.py index d674a918319..1ceefa8d78e 100644 --- a/chromadb/api/async_api.py +++ b/chromadb/api/async_api.py @@ -24,6 +24,7 @@ Where, QueryResult, GetResult, + AddResult, WhereDocument, ) from chromadb.config import Component, Settings @@ -112,7 +113,7 @@ async def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) -> bool: + ) -> AddResult: """[Internal] Add embeddings to a collection specified by UUID. If (some) ids already exist, only the new embeddings will be added. 
diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py index 0eeb7a388f4..f3f523f2cf0 100644 --- a/chromadb/api/async_client.py +++ b/chromadb/api/async_client.py @@ -14,6 +14,7 @@ EmbeddingFunction, Embeddings, GetResult, + AddResult, IDs, Include, Loadable, @@ -266,7 +267,7 @@ async def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) -> bool: + ) -> AddResult: return await self._server._add( ids=ids, collection_id=collection_id, diff --git a/chromadb/api/async_fastapi.py b/chromadb/api/async_fastapi.py index 1e56eed6e66..1d867426b63 100644 --- a/chromadb/api/async_fastapi.py +++ b/chromadb/api/async_fastapi.py @@ -30,6 +30,7 @@ Where, WhereDocument, GetResult, + AddResult, QueryResult, CollectionMetadata, validate_batch, @@ -448,11 +449,25 @@ async def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) -> bool: + ) -> AddResult: batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()}) - await self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") - return True + + resp_json = await self._make_request( + "post", + "/collections/" + str(collection_id) + "/add", + json={ + "ids": batch[0], + "embeddings": batch[1], + "metadatas": batch[2], + "documents": batch[3], + "uris": batch[4], + }, + ) + + return AddResult( + ids=resp_json["ids"], + ) @trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL) @override diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 3cf76cd47b4..21e61c16e12 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -19,6 +19,7 @@ Loadable, Metadatas, QueryResult, + AddResult, URIs, ) from chromadb.config import Settings, System @@ -214,7 +215,7 @@ def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) 
-> bool: + ) -> AddResult: return self._server._add( ids=ids, collection_id=collection_id, diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index e5707513d8f..3773880a6a8 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -22,6 +22,7 @@ Where, WhereDocument, GetResult, + AddResult, QueryResult, CollectionMetadata, validate_batch, @@ -413,15 +414,29 @@ def _add( metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, uris: Optional[URIs] = None, - ) -> bool: + ) -> AddResult: """ Adds a batch of embeddings to the database - pass in column oriented data lists """ batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.get_max_batch_size()}) - self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") - return True + + resp_json = self._make_request( + "post", + "/collections/" + str(collection_id) + "/add", + json={ + "ids": batch[0], + "embeddings": batch[1], + "metadatas": batch[2], + "documents": batch[3], + "uris": batch[4], + }, + ) + + return AddResult( + ids=resp_json["ids"], + ) @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL) @override diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 61a00bdc6b6..5592f13327e 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -18,6 +18,7 @@ Where, IDs, GetResult, + AddResult, QueryResult, ID, OneOrMany, @@ -44,7 +45,7 @@ async def add( documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, - ) -> None: + ) -> AddResult: """Add embeddings to the data store. 
Args: ids: The ids of the embeddings you wish to add @@ -74,7 +75,7 @@ async def add( uris, ) - await self._client._add( + result = await self._client._add( record_set["ids"], self.id, cast(Embeddings, record_set["embeddings"]), @@ -83,6 +84,8 @@ async def add( record_set["uris"], ) + return result + async def count(self) -> int: """The total number of embeddings added to the database @@ -266,7 +269,6 @@ async def update( documents, images, uris, - require_embeddings_or_data=False, ) await self._client._update( @@ -310,7 +312,6 @@ async def upsert( documents, images, uris, - require_embeddings_or_data=True, ) await self._client._upsert( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 336f2633d97..feb73582efc 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -14,6 +14,7 @@ Where, IDs, GetResult, + AddResult, QueryResult, ID, OneOrMany, @@ -40,7 +41,7 @@ def count(self) -> int: def add( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]], embeddings: Optional[ # type: ignore[type-arg] Union[ OneOrMany[Embedding], @@ -51,7 +52,7 @@ def add( documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, - ) -> None: + ) -> AddResult: """Add embeddings to the data store. 
Args: ids: The ids of the embeddings you wish to add @@ -81,7 +82,7 @@ def add( uris, ) - self._client._add( + result = self._client._add( record_set["ids"], self.id, cast(Embeddings, record_set["embeddings"]), @@ -90,6 +91,8 @@ def add( record_set["uris"], ) + return result + def get( self, ids: Optional[OneOrMany[ID]] = None, @@ -264,7 +267,6 @@ def update( documents, images, uris, - require_embeddings_or_data=False, ) self._client._update( @@ -308,7 +310,6 @@ def upsert( documents, images, uris, - require_embeddings_or_data=True, ) self._client._upsert( diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 36e9f44f05d..8f04b590605 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -371,7 +371,7 @@ def _update_model_after_modify_success( def _process_add_request( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]], embeddings: Optional[ Union[ OneOrMany[Embedding], @@ -384,7 +384,7 @@ def _process_add_request( uris: Optional[OneOrMany[URI]] = None, ) -> RecordSet: unpacked_embedding_set = self._unpack_embedding_set( - ids, + ids if ids is not None else [], embeddings, metadatas, documents, @@ -446,7 +446,6 @@ def _process_upsert_or_update_request( documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], - require_embeddings_or_data: bool = True, ) -> RecordSet: unpacked_embedding_set = self._unpack_embedding_set( ids, embeddings, metadatas, documents, images, uris diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index bc2197d1f3f..02c7a697b25 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -15,7 +15,11 @@ from chromadb.ingest import Producer from chromadb.types import Collection as CollectionModel from chromadb import __version__ -from chromadb.errors import InvalidDimensionException, InvalidCollectionException +from chromadb.errors import ( + InvalidDimensionException, + 
@staticmethod
def generate_ids_when_not_present(
    ids: Optional[IDs],
    n_documents: int,
    n_uris: int,
    n_embeddings: int,
) -> IDs:
    """Return the supplied IDs, or generate UUID4 IDs when none were given.

    Args:
        ids: IDs supplied by the caller; ``None`` or an empty list means
            "generate them for me".
        n_documents: Number of documents in the request (0 if absent).
        n_uris: Number of URIs in the request (0 if absent).
        n_embeddings: Number of embeddings in the request (0 if absent).

    Returns:
        ``ids`` unchanged when it is non-empty. Otherwise a list of random
        UUID4 strings whose length is the first positive count among
        documents, uris and embeddings (checked in that order), or an empty
        list when all three counts are zero.
    """
    if ids:
        return ids

    # The first populated field decides how many records are being added;
    # the priority order is documents, then uris, then embeddings.
    n = next(
        (count for count in (n_documents, n_uris, n_embeddings) if count > 0),
        0,
    )
    return [str(uuid4()) for _ in range(n)]
# TODO: promote collection -> topic to a base class method so that it can be @@ -838,17 +874,20 @@ def _validate_embedding_record_set( add_attributes_to_current_span({"collection_id": str(collection["id"])}) - validate_ids(ids) - validate_embeddings(embeddings) if embeddings is not None else None - validate_metadatas(metadatas) if metadatas is not None else None - if ( require_embeddings_or_data - and embeddings is None - and documents is None - and uris is None + and (embeddings is None or len(embeddings) == 0) + and (documents is None or len(documents) == 0) + and (uris is None or len(uris) == 0) ): - raise ValueError("You must provide embeddings, documents, or uris.") + raise InvalidInputError("You must provide embeddings, documents, or uris.") + + try: + validate_ids(ids) + validate_embeddings(embeddings) if embeddings is not None else None + validate_metadatas(metadatas) if metadatas is not None else None + except ValueError as e: + raise InvalidInputError(str(e)) from e entities: List[Tuple[Any, str]] = [ (embeddings, "embeddings"), @@ -866,7 +905,7 @@ def _validate_embedding_record_set( n = len(entity[0]) if n != len(ids): - raise ValueError( + raise InvalidInputError( f"Number of {name} ({n}) does not match number of ids ({n_ids})" ) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 393cef7f090..6297bcb3362 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -68,7 +68,7 @@ def maybe_cast_one_to_many_embedding( if target is None: return None - if isinstance(target, List): + if isinstance(target, List) and len(target) != 0: # One Embedding if isinstance(target[0], (int, float)): return cast(Embeddings, [target]) @@ -211,6 +211,10 @@ class GetResult(TypedDict): included: Include +class AddResult(TypedDict): + ids: List[ID] + + class QueryResult(TypedDict): ids: List[IDs] embeddings: Optional[List[List[Embedding]]] @@ -282,11 +286,16 @@ def validate_ids(ids: IDs) -> IDs: raise ValueError(f"Expected IDs to be a list, got 
class InvalidInputError(ChromaError):
    """Error for caller-supplied input that fails validation.

    Reports HTTP status 400 and the short name ``"InvalidInput"``.
    """

    # NOTE(review): if this module keeps a registry that maps error names
    # back to exception classes, confirm "InvalidInput" is listed there.

    @classmethod
    @overrides
    def name(cls) -> str:
        return "InvalidInput"

    @overrides
    def code(self) -> int:
        return 400  # Bad Request
def test_ids_validation() -> None:
    """Check that validate_ids accepts a well-formed list and rejects bad input."""
    valid_ids = ["id1", "id2", "id3"]
    assert validate_ids(valid_ids) == valid_ids

    with pytest.raises(ValueError, match="Expected IDs to be a list"):
        validate_ids(cast(IDs, "not a list"))

    with pytest.raises(ValueError, match="Expected IDs to be a non-empty list"):
        validate_ids([])

    with pytest.raises(ValueError, match="Expected ID to be a str"):
        validate_ids(cast(IDs, ["id1", 123, "id3"]))

    with pytest.raises(ValueError, match="Expected ID to be a non-empty str"):
        validate_ids(["id1", "", "id3"])

    with pytest.raises(errors.DuplicateIDError, match="Expected IDs to be unique"):
        validate_ids(["id1", "id2", "id1"])

    # "id1".."id15", each appearing twice, so 15 IDs are reported as duplicated.
    duplicated_ids = [f"id{i}" for i in range(1, 16)] * 2
    with pytest.raises(errors.DuplicateIDError, match="found 15 duplicated IDs: "):
        validate_ids(duplicated_ids)
IDs, ) from chromadb.types import LiteralValue, WhereOperator, LogicalOperator @@ -461,14 +462,21 @@ def recordsets( ) -> RecordSet: collection = draw(collection_strategy) - ids = list( + ids: IDs = list( draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True)) ) + n_records = len(ids) + + if len(ids) == 0: + n_records += 1 + embeddings: Optional[Embeddings] = None if collection.has_embeddings: - embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype) - num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids) + embeddings = create_embeddings( + collection.dimension, n_records, collection.dtype + ) + num_metadata = num_unique_metadata if num_unique_metadata is not None else n_records generated_metadatas = draw( st.lists( metadata( @@ -479,20 +487,22 @@ def recordsets( ) ) metadatas = [] - for i in range(len(ids)): + for i in range(n_records): metadatas.append(generated_metadatas[i % len(generated_metadatas)]) documents: Optional[Documents] = None if collection.has_documents: documents = draw( - st.lists(document(collection), min_size=len(ids), max_size=len(ids)) + st.lists(document(collection), min_size=n_records, max_size=n_records) ) # in the case where we have a single record, sometimes exercise # the code that handles individual values rather than lists. # In this case, any field may be a list or a single value. 
def test_add_with_no_ids(client: ClientAPI) -> None:
    """Tests that add() generates IDs when the caller supplies none."""
    reset(client)

    coll = client.create_collection("test")
    # TODO: We need to clean up the api types to support this typing
    result = coll.add(
        ids=[],
        embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]],  # type: ignore
        metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
        documents=["a", "b", None],  # type: ignore
    )

    # One distinct ID per record must be generated and reported back.
    assert len(result["ids"]) == 3
    assert len(set(result["ids"])) == 3

    results = coll.get()
    assert len(results["ids"]) == 3
    # The IDs returned by add() are the ones actually stored.
    assert set(results["ids"]) == set(result["ids"])


def test_add_with_partial_ids(client: ClientAPI) -> None:
    """Tests that add() rejects empty-string IDs and too few IDs for the data."""
    reset(client)

    coll = client.create_collection("test")
    # TODO: We need to clean up the api types to support this typing

    # An empty string is not a valid ID, even alongside valid ones.
    with pytest.raises(Exception, match="Expected ID to be a non-empty str"):
        coll.add(
            ids=["1", ""],
            embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]],  # type: ignore
            metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
            documents=["a", "b", None],  # type: ignore
        )

    # Two IDs for three records: IDs are not generated for the remainder.
    with pytest.raises(Exception, match="does not match number of ids"):
        coll.add(
            ids=["1", "2"],
            embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]],  # type: ignore
            metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
            documents=["a", "b", None],  # type: ignore
        )


def test_add_with_no_data(client: ClientAPI) -> None:
    """Tests that add() rejects a request with no embeddings, documents or uris."""
    reset(client)

    coll = client.create_collection("test")
    # TODO: We need to clean up the api types to support this typing

    # `match` is a regular expression, so the trailing period is escaped.
    with pytest.raises(
        Exception, match=r"You must provide embeddings, documents, or uris\."
    ):
        coll.add(
            ids=["1"],
            embeddings=[],  # type: ignore
            metadatas=[{"a": 1}],  # type: ignore
            documents=[],  # type: ignore
        )
self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set)) return multiple(*normalized_record_set["ids"]) @@ -366,7 +369,7 @@ def wait_for_compaction(self) -> None: @rule( target=embedding_ids, - record_set=strategies.recordsets(collection_st), + record_set=strategies.recordsets(collection_st, min_size=0), ) def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]: res = super().add_embeddings(record_set) diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index b252b620cb5..e031ce4b810 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -171,7 +171,8 @@ def _filter_embedding_set( key="coll", ) recordset_st = st.shared( - strategies.recordsets(collection_st, max_size=1000), key="recordset" + strategies.recordsets(collection_st, max_size=1000), + key="recordset", ) diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index 267692ce485..134a3c633ac 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -74,7 +74,7 @@ def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: @given( collection_strategy=collection_st, - embeddings_strategy=strategies.recordsets(collection_st), + embeddings_strategy=strategies.recordsets(collection_st, min_size=0), ) def test_persist( settings: Settings, @@ -92,7 +92,8 @@ def test_persist( embedding_function=collection_strategy.embedding_function, ) - coll.add(**embeddings_strategy) # type: ignore[arg-type] + result = coll.add(**embeddings_strategy) # type: ignore[arg-type] + embeddings_strategy["ids"] = result["ids"] invariants.count(coll, embeddings_strategy) invariants.metadatas_match(coll, embeddings_strategy) diff --git a/clients/js/src/ChromaClient.ts b/clients/js/src/ChromaClient.ts index 3f2ebc50cde..4537b800ce4 100644 --- a/clients/js/src/ChromaClient.ts +++ 
b/clients/js/src/ChromaClient.ts @@ -2,7 +2,6 @@ import { AdminClient } from "./AdminClient"; import { authOptionsToAuthProvider, ClientAuthProvider } from "./auth"; import { chromaFetch } from "./ChromaFetch"; import { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction"; -import { ChromaConnectionError, ChromaServerError } from "./Errors"; import { Configuration, ApiApi as DefaultApi, @@ -18,7 +17,6 @@ import type { CreateCollectionParams, DeleteCollectionParams, DeleteParams, - Embedding, Embeddings, GetCollectionParams, GetOrCreateCollectionParams, @@ -416,10 +414,10 @@ export class ChromaClient { async addRecords( collection: Collection, params: AddRecordsParams, - ): Promise { + ): Promise { await this.init(); - await this.api.add( + const resp = (await this.api.add( collection.id, // TODO: For some reason the auto generated code requires metadata to be defined here. (await prepareRecordRequest( @@ -427,7 +425,9 @@ export class ChromaClient { collection.embeddingFunction, )) as GeneratedApi.AddEmbedding, this.api.options, - ); + )) as AddResponse; + + return resp; } /** diff --git a/clients/js/src/ChromaFetch.ts b/clients/js/src/ChromaFetch.ts index ebd98cc15a9..854088c0d25 100644 --- a/clients/js/src/ChromaFetch.ts +++ b/clients/js/src/ChromaFetch.ts @@ -59,7 +59,7 @@ export const chromaFetch: FetchAPI = async ( switch (resp.status) { case 400: throw new ChromaClientError( - `Bad request to ${input} with status: ${resp.statusText}`, + `Bad request to ${input} with status: ${respBody?.message}`, ); case 401: throw new ChromaUnauthorizedError(`Unauthorized`); diff --git a/clients/js/src/types.ts b/clients/js/src/types.ts index c2494b11825..250931fe185 100644 --- a/clients/js/src/types.ts +++ b/clients/js/src/types.ts @@ -91,7 +91,9 @@ export type MultiQueryResponse = { export type QueryResponse = SingleQueryResponse | MultiQueryResponse; -export type AddResponse = {}; +export type AddResponse = { + ids: IDs; +}; export interface Collection { 
name: string; @@ -164,6 +166,13 @@ export type BaseRecordOperationParams = { documents?: Document | Documents; }; +export type BaseRecordOperationParamsWithIDsOptional = { + ids?: ID | IDs; + embeddings?: Embedding | Embeddings; + metadatas?: Metadata | Metadatas; + documents?: Document | Documents; +}; + export type SingleRecordOperationParams = BaseRecordOperationParams & { ids: ID; embeddings?: Embedding; @@ -171,6 +180,14 @@ export type SingleRecordOperationParams = BaseRecordOperationParams & { documents?: Document; }; +export type SingleRecordOperationParamsWithIDsOptional = + BaseRecordOperationParamsWithIDsOptional & { + ids?: ID; + embeddings?: Embedding; + metadatas?: Metadata; + documents?: Document; + }; + type SingleEmbeddingRecordOperationParams = SingleRecordOperationParams & { embeddings: Embedding; }; @@ -179,9 +196,15 @@ type SingleContentRecordOperationParams = SingleRecordOperationParams & { documents: Document; }; -export type SingleAddRecordOperationParams = - | SingleEmbeddingRecordOperationParams - | SingleContentRecordOperationParams; +type SingleEmbeddingRecordOperationParamsWithOptionalIDs = + BaseRecordOperationParamsWithIDsOptional & { + embeddings: Embedding; + }; + +type SingleContentRecordOperationParamsWithOptionalIDs = + BaseRecordOperationParamsWithIDsOptional & { + documents: Document; + }; export type MultiRecordOperationParams = BaseRecordOperationParams & { ids: IDs; @@ -190,6 +213,14 @@ export type MultiRecordOperationParams = BaseRecordOperationParams & { documents?: Documents; }; +export type MultiRecordOperationParamsWithIDsOptional = + BaseRecordOperationParamsWithIDsOptional & { + ids?: IDs; + embeddings?: Embeddings; + metadatas?: Metadatas; + documents?: Documents; + }; + type MultiEmbeddingRecordOperationParams = MultiRecordOperationParams & { embeddings: Embeddings; }; @@ -198,15 +229,40 @@ type MultiContentRecordOperationParams = MultiRecordOperationParams & { documents: Documents; }; +type 
MultiEmbeddingRecordOperationParamsWithOptionalIDs = + MultiRecordOperationParamsWithIDsOptional & { + embeddings: Embeddings; + }; + +type MultiContentRecordOperationParamsWithOptionalIDs = + MultiRecordOperationParamsWithIDsOptional & { + documents: Documents; + }; + +export type SingleAddRecordOperationParams = + | SingleEmbeddingRecordOperationParams + | SingleContentRecordOperationParams; + +export type SingleAddRecordOperationParamsWithOptionalIDs = + | SingleEmbeddingRecordOperationParamsWithOptionalIDs + | SingleContentRecordOperationParamsWithOptionalIDs; + +export type MultiAddRecordsOperationParamsWithOptionalIDs = + | MultiEmbeddingRecordOperationParamsWithOptionalIDs + | MultiContentRecordOperationParamsWithOptionalIDs; + export type MultiAddRecordsOperationParams = | MultiEmbeddingRecordOperationParams | MultiContentRecordOperationParams; export type AddRecordsParams = + | SingleAddRecordOperationParamsWithOptionalIDs + | MultiAddRecordsOperationParamsWithOptionalIDs; + +export type UpsertRecordsParams = | SingleAddRecordOperationParams | MultiAddRecordsOperationParams; -export type UpsertRecordsParams = AddRecordsParams; export type UpdateRecordsParams = | MultiRecordOperationParams | SingleRecordOperationParams; diff --git a/clients/js/src/utils.ts b/clients/js/src/utils.ts index f8711d11e89..a811260d6ab 100644 --- a/clients/js/src/utils.ts +++ b/clients/js/src/utils.ts @@ -4,6 +4,7 @@ import { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction"; import { AddRecordsParams, BaseRecordOperationParams, + BaseRecordOperationParamsWithIDsOptional, Collection, Metadata, MultiRecordOperationParams, @@ -82,10 +83,10 @@ export function isBrowser() { } function arrayifyParams( - params: BaseRecordOperationParams, + params: BaseRecordOperationParamsWithIDsOptional, ): MultiRecordOperationParams { return { - ids: toArray(params.ids), + ids: params.ids !== undefined ? toArray(params.ids) : [], embeddings: params.embeddings ? 
toArrayOfArrays(params.embeddings) : undefined, @@ -125,16 +126,6 @@ export async function prepareRecordRequest( } } - if ( - (embeddingsArray !== undefined && ids.length !== embeddingsArray.length) || - (metadatas !== undefined && ids.length !== metadatas.length) || - (documents !== undefined && ids.length !== documents.length) - ) { - throw new Error( - "ids, embeddings, metadatas, and documents must all be the same length", - ); - } - const uniqueIds = new Set(ids); if (uniqueIds.size !== ids.length) { const duplicateIds = ids.filter( diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index f1ee66fbd4f..0a49cb01ce7 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -137,6 +137,14 @@ describe("add collections", () => { } }); + test("It should generate IDs if not provided", async () => { + const collection = await client.createCollection({ name: "test" }); + const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]); + const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]); + const resp = await client.addRecords(collection, { embeddings, metadatas }); + expect(resp.ids.length).toEqual(4); + }); + test("should error on empty embedding", async () => { const collection = await client.createCollection({ name: "test" }); const ids = ["id1"];