
Commit

update tests
spikechroma committed Sep 12, 2024
1 parent cd30c01 commit 805ea9e
Showing 4 changed files with 53 additions and 60 deletions.
23 changes: 5 additions & 18 deletions chromadb/test/property/invariants.py
@@ -72,10 +72,11 @@ def wrap_all(record_set: RecordSet) -> NormalizedRecordSet:


def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int:
normalized_record_set = wrap_all(cast(RecordSet, state_record_set))

# we need to replace empty lists with None within the record set state to use get_n_items_from_record_set
# get_n_items_from_record_set would throw an error if it encounters an empty list
if all(len(value) == 0 for value in state_record_set.values()): # type: ignore[arg-type]
return 0

record_set_with_empty_lists_replaced: types.RecordSet = {
"ids": None,
"documents": None,
@@ -85,22 +86,8 @@ def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int:
"uris": None,
}

all_fields_are_empty = True
for key, value in normalized_record_set.items():
if value is None:
continue

if isinstance(value, list):
if len(value) == 0:
record_set_with_empty_lists_replaced[key] = None # type: ignore[literal-required]
continue

all_fields_are_empty = False

record_set_with_empty_lists_replaced[key] = value # type: ignore[literal-required]

if all_fields_are_empty:
return 0
for key, value in state_record_set.items():
record_set_with_empty_lists_replaced[key] = None if len(value) == 0 else value # type: ignore[literal-required, arg-type]

return types.get_n_items_from_record_set(record_set_with_empty_lists_replaced)

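Read together, the invariants.py hunks drop the per-field bookkeeping (the all_fields_are_empty flag and the isinstance checks) in favour of an up-front emptiness check plus a single replacement loop. A minimal sketch of the resulting helper, reassembled from the added lines above; the dict keys collapsed in the diff view and the surrounding imports are assumed context, not part of the commit:

def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int:
    # get_n_items_from_record_set would throw on empty lists, so short-circuit
    # while every field in the state machine's record set is still empty.
    if all(len(value) == 0 for value in state_record_set.values()):
        return 0

    record_set_with_empty_lists_replaced: types.RecordSet = {
        "ids": None,
        "documents": None,
        # ... remaining fields (collapsed in the diff view) also default to None ...
        "uris": None,
    }

    # Any field that is still an empty list maps to None; everything else is
    # passed through unchanged.
    for key, value in state_record_set.items():
        record_set_with_empty_lists_replaced[key] = None if len(value) == 0 else value

    return types.get_n_items_from_record_set(record_set_with_empty_lists_replaced)
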
10 changes: 5 additions & 5 deletions chromadb/test/property/strategies.py
@@ -695,18 +695,18 @@ def filters(
st.one_of(st.none(), recursive_where_doc_clause(collection))
)

ids: Optional[Union[List[types.ID], types.ID]]
if recordset["ids"] is None:
raise ValueError("Record set IDs cannot be None")

ids: Union[List[types.ID], types.ID]
# Record sets can be a value instead of a list of values if there is only one record
if isinstance(recordset["ids"], str):
ids = [recordset["ids"]]
else:
ids = recordset["ids"]

if ids is None:
raise ValueError("IDs cannot be None")

if not include_all_ids:
ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids))))
ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) # type: ignore[assignment]
if ids is not None:
# Remove duplicates since hypothesis samples with replacement
ids = list(set(ids))
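
The strategies.py hunk mainly reorders the guards: the None check now runs before the single-value/list normalization, so the later sampled_from call always sees a concrete list. A short sketch of the resulting flow, reassembled from the added lines above (recordset, include_all_ids, draw and st come from the surrounding hypothesis strategy and are assumed here):

if recordset["ids"] is None:
    raise ValueError("Record set IDs cannot be None")

ids: Union[List[types.ID], types.ID]
# A record set may hold a single value instead of a list when it has one record.
if isinstance(recordset["ids"], str):
    ids = [recordset["ids"]]
else:
    ids = recordset["ids"]

if not include_all_ids:
    # Hypothesis samples with replacement, so deduplicate whatever was drawn.
    ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids))))  # type: ignore[assignment]
    if ids is not None:
        ids = list(set(ids))
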
61 changes: 32 additions & 29 deletions chromadb/test/property/test_cross_version_persist.py
@@ -97,7 +97,7 @@ def _patch_telemetry_client(
def patch_for_version(
version: str,
collection: strategies.Collection,
embeddings: strategies.RecordSet,
record_set: strategies.RecordSet,
settings: Settings,
) -> None:
"""Override aspects of the collection and embeddings, before testing, to account for
@@ -107,7 +107,7 @@ def patch_for_version(
if packaging_version.Version(version) <= packaging_version.Version(
patch_version
):
patch(collection, embeddings, settings)
patch(collection, record_set, settings)


def api_import_for_version(module: Any, version: str) -> Type: # type: ignore
@@ -234,7 +234,7 @@ def persist_generated_data_with_old_version(
version: str,
settings: Settings,
collection_strategy: strategies.Collection,
embeddings_strategy: strategies.RecordSet,
record_set: strategies.RecordSet,
conn: Connection,
) -> None:
try:
@@ -256,21 +256,24 @@ def persist_generated_data_with_old_version(
# In order to test old versions, we can't rely on the not_implemented function
embedding_function=not_implemented_ef(),
)
result = coll.add(**embeddings_strategy)
result = coll.add(**record_set)

if embeddings_strategy["ids"] is None:
if (
packaging_version.Version(version) >= packaging_version.Version("0.5.5")
and record_set["ids"] is None
):
if result is None:
raise ValueError("IDs from embeddings strategy should not be None")

if result["ids"] is None:
raise ValueError("IDs from result should not be None")

embeddings_strategy["ids"] = result["ids"]
record_set["ids"] = result["ids"]

# Just use some basic checks for sanity and manual testing where you break the new
# version

check_embeddings = invariants.wrap_all(embeddings_strategy)
check_embeddings = invariants.wrap_all(record_set)
# Check count
assert coll.count() == len(check_embeddings["embeddings"]) # type: ignore[arg-type]
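
Net effect of the hunk above: backfilling generated IDs is now gated on the old version under test, since only 0.5.5 and later return IDs from add(). A minimal sketch of that guard, reassembled from the added lines (coll, record_set, version and packaging_version come from the diff; the surrounding function body is assumed):

result = coll.add(**record_set)

# Only versions with auto ID generation (0.5.5+) hand IDs back from add(),
# so only then can a record set with ids=None be backfilled from the result.
if (
    packaging_version.Version(version) >= packaging_version.Version("0.5.5")
    and record_set["ids"] is None
):
    if result is None:
        raise ValueError("IDs from embeddings strategy should not be None")
    if result["ids"] is None:
        raise ValueError("IDs from result should not be None")
    record_set["ids"] = result["ids"]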

@@ -303,14 +306,14 @@ def persist_generated_data_with_old_version(

@given(
collection_strategy=collection_st,
embeddings_strategy=strategies.recordsets(collection_strategy=collection_st),
record_set=strategies.recordsets(collection_strategy=collection_st),
should_stomp_ids=st.booleans(),
)
@settings(deadline=None)
def test_cycle_versions(
version_settings: Tuple[str, Settings],
collection_strategy: strategies.Collection,
embeddings_strategy: strategies.RecordSet,
record_set: strategies.RecordSet,
should_stomp_ids: bool,
) -> None:
# Test backwards compatibility
@@ -320,26 +323,26 @@ def test_cycle_versions(

# TODO: This condition is subject to change as we decide on whether we want to
# release auto ID generation feature after 0.5.5

if (
packaging_version.Version(version) > packaging_version.Version("0.5.5")
and should_stomp_ids
):
embeddings_strategy["ids"] = None
record_set["ids"] = None

# The strategies can generate metadatas of malformed inputs. Other tests
# will error check and cover these cases to make sure they error. Here we
# just convert them to valid values since the error cases are already tested
if embeddings_strategy["metadatas"] == {}:
embeddings_strategy["metadatas"] = None
if embeddings_strategy["metadatas"] is not None and isinstance(
embeddings_strategy["metadatas"], list
if record_set["metadatas"] == {}:
record_set["metadatas"] = None
if record_set["metadatas"] is not None and isinstance(
record_set["metadatas"], list
):
embeddings_strategy["metadatas"] = [
m if m is None or len(m) > 0 else None
for m in embeddings_strategy["metadatas"]
record_set["metadatas"] = [
m if m is None or len(m) > 0 else None for m in record_set["metadatas"]
]

patch_for_version(version, collection_strategy, embeddings_strategy, settings)
patch_for_version(version, collection_strategy, record_set, settings)

# Can't pickle a function, and we won't need them
collection_strategy.embedding_function = None
@@ -352,7 +355,7 @@ def test_cycle_versions(
conn1, conn2 = multiprocessing.Pipe()
p = ctx.Process(
target=persist_generated_data_with_old_version,
args=(version, settings, collection_strategy, embeddings_strategy, conn2),
args=(version, settings, collection_strategy, record_set, conn2),
)
p.start()
p.join()
@@ -397,18 +400,18 @@ def test_cycle_versions(
invariants.log_size_below_max(system, [coll], True)

# Should be able to add embeddings
result = coll.add(**embeddings_strategy) # type: ignore[arg-type]
if embeddings_strategy["ids"] is None:
embeddings_strategy["ids"] = result["ids"]

invariants.count(coll, embeddings_strategy)
invariants.metadatas_match(coll, embeddings_strategy)
invariants.documents_match(coll, embeddings_strategy)
invariants.ids_match(coll, embeddings_strategy)
result = coll.add(**record_set) # type: ignore[arg-type]
if record_set["ids"] is None:
record_set["ids"] = result["ids"]

invariants.count(coll, record_set)
invariants.metadatas_match(coll, record_set)
invariants.documents_match(coll, record_set)
invariants.ids_match(coll, record_set)
invariants.ann_accuracy(
coll,
embeddings_strategy,
n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
record_set,
n_records=invariants.get_n_items_from_record_set(record_set),
)
invariants.log_size_below_max(system, [coll], True)

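In test_cycle_versions the same 0.5.5 boundary controls whether IDs are stomped before the record set is handed to the old version, and malformed metadata emitted by the strategies is normalized rather than asserted on (the error paths are covered elsewhere). A sketch of that setup step, reassembled from the added lines above (record_set and should_stomp_ids are the renamed test parameters; everything else is assumed):

# Stomp IDs only when the version under test already supports auto ID
# generation, i.e. anything newer than 0.5.5.
if (
    packaging_version.Version(version) > packaging_version.Version("0.5.5")
    and should_stomp_ids
):
    record_set["ids"] = None

# The strategies can generate malformed metadata; other tests cover the error
# cases, so here empty containers are simply converted to valid None values.
if record_set["metadatas"] == {}:
    record_set["metadatas"] = None
if record_set["metadatas"] is not None and isinstance(record_set["metadatas"], list):
    record_set["metadatas"] = [
        m if m is None or len(m) > 0 else None for m in record_set["metadatas"]
    ]
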
19 changes: 11 additions & 8 deletions chromadb/test/property/test_filtering.py
@@ -134,7 +134,7 @@ def _filter_embedding_set(
"""Return IDs from the embedding set that match the given filter object"""

normalized_record_set = invariants.wrap_all(record_set)
ids = set(normalized_record_set["ids"])
ids = set(normalized_record_set["ids"]) # type: ignore[arg-type]

filter_ids = filter["ids"]

@@ -145,23 +145,23 @@ def _filter_embedding_set(
if len(filter_ids) != 0:
ids = ids.intersection(filter_ids)

for i in range(len(normalized_record_set["ids"])):
for i in range(len(normalized_record_set["ids"])): # type: ignore[arg-type]
if filter["where"]:
metadatas: Metadatas
if isinstance(normalized_record_set["metadatas"], list):
metadatas = normalized_record_set["metadatas"] # type: ignore[assignment]
else:
metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"])
metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"]) # type: ignore[arg-type]
filter_where: Where = filter["where"]
if not _filter_where_clause(filter_where, metadatas[i]):
ids.discard(normalized_record_set["ids"][i])
ids.discard(normalized_record_set["ids"][i]) # type: ignore[index]

if filter["where_document"]:
documents = normalized_record_set["documents"] or [EMPTY_STRING] * len(
normalized_record_set["ids"]
normalized_record_set["ids"] # type: ignore[arg-type]
)
if not _filter_where_doc_clause(filter["where_document"], documents[i]):
ids.discard(normalized_record_set["ids"][i])
ids.discard(normalized_record_set["ids"][i]) # type: ignore[index]

return list(ids)

@@ -209,6 +209,9 @@ def test_filterable_metadata_get(

initial_version = cast(int, coll.get_model()["version"])

if record_set["ids"] is None:
raise ValueError("Record set IDs cannot be None")

coll.add(**record_set)

if not NOT_CLUSTER_ONLY:
@@ -325,11 +328,11 @@ def test_filterable_metadata_query(
if not NOT_CLUSTER_ONLY:
# Only wait for compaction if the size of the collection is
# some minimal size
if should_compact and len(invariants.wrap(record_set["ids"])) > 10:
if should_compact and len(invariants.wrap(record_set["ids"])) > 10: # type: ignore[arg-type]
# Wait for the model to be updated
wait_for_version_increase(client, collection.name, initial_version)

total_count = len(normalized_record_set["ids"])
total_count = len(normalized_record_set["ids"]) # type: ignore[arg-type]
# Pick a random vector
random_query: Embedding
if collection.has_embeddings:
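
The test_filtering.py changes are mostly type: ignore annotations for the now-Optional record set fields; the one behavioural addition is an explicit guard before coll.add in test_filterable_metadata_get, sketched below (coll and record_set come from the diff; the rest of the test body is assumed):

# This test relies on explicit IDs, so fail fast if the strategy produced a
# record set with ids=None instead of silently exercising auto generation.
if record_set["ids"] is None:
    raise ValueError("Record set IDs cannot be None")

coll.add(**record_set)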
