From 805ea9eb0c816945873bc09c559dc5470c8d63e0 Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Wed, 11 Sep 2024 16:51:21 -0700 Subject: [PATCH] update tests --- chromadb/test/property/invariants.py | 23 ++----- chromadb/test/property/strategies.py | 10 +-- .../property/test_cross_version_persist.py | 61 ++++++++++--------- chromadb/test/property/test_filtering.py | 19 +++--- 4 files changed, 53 insertions(+), 60 deletions(-) diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index bbd55520f52..4f4382e0513 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -72,10 +72,11 @@ def wrap_all(record_set: RecordSet) -> NormalizedRecordSet: def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int: - normalized_record_set = wrap_all(cast(RecordSet, state_record_set)) - # we need to replace empty lists with None within the record set state to use get_n_items_from_record_set # get_n_items_from_record_set would throw an error if it encounters an empty list + if all(len(value) == 0 for value in state_record_set.values()): # type: ignore[arg-type] + return 0 + record_set_with_empty_lists_replaced: types.RecordSet = { "ids": None, "documents": None, @@ -85,22 +86,8 @@ def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) - "uris": None, } - all_fields_are_empty = True - for key, value in normalized_record_set.items(): - if value is None: - continue - - if isinstance(value, list): - if len(value) == 0: - record_set_with_empty_lists_replaced[key] = None # type: ignore[literal-required] - continue - - all_fields_are_empty = False - - record_set_with_empty_lists_replaced[key] = value # type: ignore[literal-required] - - if all_fields_are_empty: - return 0 + for key, value in state_record_set.items(): + record_set_with_empty_lists_replaced[key] = None if len(value) == 0 else value # type: ignore[literal-required, arg-type] return types.get_n_items_from_record_set(record_set_with_empty_lists_replaced) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index caa31832b05..b4f0699df6c 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -695,18 +695,18 @@ def filters( st.one_of(st.none(), recursive_where_doc_clause(collection)) ) - ids: Optional[Union[List[types.ID], types.ID]] + if recordset["ids"] is None: + raise ValueError("Record set IDs cannot be None") + + ids: Union[List[types.ID], types.ID] # Record sets can be a value instead of a list of values if there is only one record if isinstance(recordset["ids"], str): ids = [recordset["ids"]] else: ids = recordset["ids"] - if ids is None: - raise ValueError("IDs cannot be None") - if not include_all_ids: - ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) + ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) # type: ignore[assignment] if ids is not None: # Remove duplicates since hypothesis samples with replacement ids = list(set(ids)) diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index a7235513f08..729daf87e46 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -97,7 +97,7 @@ def _patch_telemetry_client( def patch_for_version( version: str, collection: strategies.Collection, - embeddings: strategies.RecordSet, + record_set: strategies.RecordSet, settings: Settings, ) -> None: """Override aspects of the collection and embeddings, before testing, to account for @@ -107,7 +107,7 @@ def patch_for_version( if packaging_version.Version(version) <= packaging_version.Version( patch_version ): - patch(collection, embeddings, settings) + patch(collection, record_set, settings) def api_import_for_version(module: Any, version: str) -> Type: # type: ignore @@ -234,7 +234,7 @@ def persist_generated_data_with_old_version( version: str, settings: Settings, collection_strategy: strategies.Collection, - embeddings_strategy: strategies.RecordSet, + record_set: strategies.RecordSet, conn: Connection, ) -> None: try: @@ -256,21 +256,24 @@ def persist_generated_data_with_old_version( # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) - result = coll.add(**embeddings_strategy) + result = coll.add(**record_set) - if embeddings_strategy["ids"] is None: + if ( + packaging_version.Version(version) >= packaging_version.Version("0.5.5") + and record_set["ids"] is None + ): if result is None: raise ValueError("IDs from embeddings strategy should not be None") if result["ids"] is None: raise ValueError("IDs from result should not be None") - embeddings_strategy["ids"] = result["ids"] + record_set["ids"] = result["ids"] # Just use some basic checks for sanity and manual testing where you break the new # version - check_embeddings = invariants.wrap_all(embeddings_strategy) + check_embeddings = invariants.wrap_all(record_set) # Check count assert coll.count() == len(check_embeddings["embeddings"]) # type: ignore[arg-type] @@ -303,14 +306,14 @@ def persist_generated_data_with_old_version( @given( collection_strategy=collection_st, - embeddings_strategy=strategies.recordsets(collection_strategy=collection_st), + record_set=strategies.recordsets(collection_strategy=collection_st), should_stomp_ids=st.booleans(), ) @settings(deadline=None) def test_cycle_versions( version_settings: Tuple[str, Settings], collection_strategy: strategies.Collection, - embeddings_strategy: strategies.RecordSet, + record_set: strategies.RecordSet, should_stomp_ids: bool, ) -> None: # Test backwards compatibility @@ -320,26 +323,26 @@ def test_cycle_versions( # TODO: This condition is subject to change as we decide on whether we want to # release auto ID generation feature after 0.5.5 + if ( packaging_version.Version(version) > packaging_version.Version("0.5.5") and should_stomp_ids ): - embeddings_strategy["ids"] = None + record_set["ids"] = None # The strategies can generate metadatas of malformed inputs. Other tests # will error check and cover these cases to make sure they error. Here we # just convert them to valid values since the error cases are already tested - if embeddings_strategy["metadatas"] == {}: - embeddings_strategy["metadatas"] = None - if embeddings_strategy["metadatas"] is not None and isinstance( - embeddings_strategy["metadatas"], list + if record_set["metadatas"] == {}: + record_set["metadatas"] = None + if record_set["metadatas"] is not None and isinstance( + record_set["metadatas"], list ): - embeddings_strategy["metadatas"] = [ - m if m is None or len(m) > 0 else None - for m in embeddings_strategy["metadatas"] + record_set["metadatas"] = [ + m if m is None or len(m) > 0 else None for m in record_set["metadatas"] ] - patch_for_version(version, collection_strategy, embeddings_strategy, settings) + patch_for_version(version, collection_strategy, record_set, settings) # Can't pickle a function, and we won't need them collection_strategy.embedding_function = None @@ -352,7 +355,7 @@ def test_cycle_versions( conn1, conn2 = multiprocessing.Pipe() p = ctx.Process( target=persist_generated_data_with_old_version, - args=(version, settings, collection_strategy, embeddings_strategy, conn2), + args=(version, settings, collection_strategy, record_set, conn2), ) p.start() p.join() @@ -397,18 +400,18 @@ def test_cycle_versions( invariants.log_size_below_max(system, [coll], True) # Should be able to add embeddings - result = coll.add(**embeddings_strategy) # type: ignore[arg-type] - if embeddings_strategy["ids"] is None: - embeddings_strategy["ids"] = result["ids"] - - invariants.count(coll, embeddings_strategy) - invariants.metadatas_match(coll, embeddings_strategy) - invariants.documents_match(coll, embeddings_strategy) - invariants.ids_match(coll, embeddings_strategy) + result = coll.add(**record_set) # type: ignore[arg-type] + if record_set["ids"] is None: + record_set["ids"] = result["ids"] + + invariants.count(coll, record_set) + invariants.metadatas_match(coll, record_set) + invariants.documents_match(coll, record_set) + invariants.ids_match(coll, record_set) invariants.ann_accuracy( coll, - embeddings_strategy, - n_records=invariants.get_n_items_from_record_set(embeddings_strategy), + record_set, + n_records=invariants.get_n_items_from_record_set(record_set), ) invariants.log_size_below_max(system, [coll], True) diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index c2f2e9ad63f..1171cd3cdbf 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -134,7 +134,7 @@ def _filter_embedding_set( """Return IDs from the embedding set that match the given filter object""" normalized_record_set = invariants.wrap_all(record_set) - ids = set(normalized_record_set["ids"]) + ids = set(normalized_record_set["ids"]) # type: ignore[arg-type] filter_ids = filter["ids"] @@ -145,23 +145,23 @@ def _filter_embedding_set( if len(filter_ids) != 0: ids = ids.intersection(filter_ids) - for i in range(len(normalized_record_set["ids"])): + for i in range(len(normalized_record_set["ids"])): # type: ignore[arg-type] if filter["where"]: metadatas: Metadatas if isinstance(normalized_record_set["metadatas"], list): metadatas = normalized_record_set["metadatas"] # type: ignore[assignment] else: - metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"]) + metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"]) # type: ignore[arg-type] filter_where: Where = filter["where"] if not _filter_where_clause(filter_where, metadatas[i]): - ids.discard(normalized_record_set["ids"][i]) + ids.discard(normalized_record_set["ids"][i]) # type: ignore[index] if filter["where_document"]: documents = normalized_record_set["documents"] or [EMPTY_STRING] * len( - normalized_record_set["ids"] + normalized_record_set["ids"] # type: ignore[arg-type] ) if not _filter_where_doc_clause(filter["where_document"], documents[i]): - ids.discard(normalized_record_set["ids"][i]) + ids.discard(normalized_record_set["ids"][i]) # type: ignore[index] return list(ids) @@ -209,6 +209,9 @@ def test_filterable_metadata_get( initial_version = cast(int, coll.get_model()["version"]) + if record_set["ids"] is None: + raise ValueError("Record set IDs cannot be None") + coll.add(**record_set) if not NOT_CLUSTER_ONLY: @@ -325,11 +328,11 @@ def test_filterable_metadata_query( if not NOT_CLUSTER_ONLY: # Only wait for compaction if the size of the collection is # some minimal size - if should_compact and len(invariants.wrap(record_set["ids"])) > 10: + if should_compact and len(invariants.wrap(record_set["ids"])) > 10: # type: ignore[arg-type] # Wait for the model to be updated wait_for_version_increase(client, collection.name, initial_version) - total_count = len(normalized_record_set["ids"]) + total_count = len(normalized_record_set["ids"]) # type: ignore[arg-type] # Pick a random vector random_query: Embedding if collection.has_embeddings: