From 93e10dba1fbfeeb7d1c9c515d9d0e42ad983a696 Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Mon, 26 Aug 2024 16:33:10 -0700 Subject: [PATCH] fix types --- chromadb/test/api/test_validations.py | 2 +- chromadb/test/property/strategies.py | 17 +++++++++++------ chromadb/test/property/test_add.py | 10 +++++----- chromadb/test/property/test_filtering.py | 4 ++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/chromadb/test/api/test_validations.py b/chromadb/test/api/test_validations.py index b1e41ab06dc..fb47388866c 100644 --- a/chromadb/test/api/test_validations.py +++ b/chromadb/test/api/test_validations.py @@ -6,7 +6,7 @@ from chromadb.api.types import IDs, validate_ids -def test_ids_validation(): +def test_ids_validation() -> None: ids = ["id1", "id2", "id3"] assert validate_ids(ids) == ids diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index dd1ab41a003..342d9409873 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -1,7 +1,7 @@ import hashlib import hypothesis import hypothesis.strategies as st -from typing import Any, Optional, List, Dict, Union, cast +from typing import Any, Optional, List, Dict, Union, cast, Tuple from typing_extensions import TypedDict import uuid import numpy as np @@ -195,7 +195,7 @@ def create_embeddings_ndarray( dim: int, count: int, dtype: npt.DTypeLike, -) -> np.typing.NDArray[Any]: +) -> npt.NDArray[Any]: return np.random.uniform( low=-1.0, high=1.0, @@ -377,7 +377,10 @@ def collections( @st.composite def metadata( - draw: st.DrawFn, collection: Collection, min_size=0, max_size=None + draw: st.DrawFn, + collection: Collection, + min_size: int = 0, + max_size: Optional[int] = None, ) -> Optional[types.Metadata]: """Strategy for generating metadata that could be a part of the given collection""" # First draw a random dictionary. @@ -409,6 +412,8 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document: # For cluster tests, we want to avoid generating documents of length < 3. # We also don't want them to contain certan special # characters like _ and % that implicitly involve searching for a regex in sqlite. + + blacklist_categories: Tuple[str, ...] = () if not NOT_CLUSTER_ONLY: # Blacklist certain unicode characters that affect sqlite processing. # For example, the null (/x00) character makes sqlite stop processing a string. @@ -563,7 +568,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: if not NOT_CLUSTER_ONLY: legal_ops: List[Optional[str]] = [None, "$eq"] else: - legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"] + legal_ops = [None, "$eq", "$ne", "$in", "$nin"] if not isinstance(value, str) and not isinstance(value, bool): legal_ops.extend(["$gt", "$lt", "$lte", "$gte"]) @@ -615,10 +620,10 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu else: op = draw(st.sampled_from(["$contains", "$not_contains"])) - if op == "$contains": + if op == "$contains": # type: ignore[comparison-overlap] return {"$contains": word} else: - assert op == "$not_contains" + assert op == "$not_contains" # type: ignore[comparison-overlap] return {"$not_contains": word} diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 5ca0c9c83f5..17868593b76 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -99,7 +99,7 @@ def _test_add( metadata=collection.metadata, # type: ignore embedding_function=collection.embedding_function, ) - initial_version = coll.get_model()["version"] + initial_version = cast(int, coll.get_model()["version"]) normalized_record_set = invariants.wrap_all(record_set) @@ -182,7 +182,7 @@ def test_add_large( embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) - initial_version = coll.get_model()["version"] + initial_version = cast(int, coll.get_model()["version"]) for batch in create_batches( api=client, @@ -355,7 +355,7 @@ def test_add_with_no_data(client: ClientAPI) -> None: ): coll.add( ids=["1"], - embeddings=[], # type: ignore - metadatas=[{"a": 1}], # type: ignore - documents=[], # type: ignore + embeddings=[], + metadatas=[{"a": 1}], + documents=[], ) diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index e031ce4b810..c2f2e9ad63f 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -207,7 +207,7 @@ def test_filterable_metadata_get( embedding_function=collection.embedding_function, ) - initial_version = coll.get_model()["version"] + initial_version = cast(int, coll.get_model()["version"]) coll.add(**record_set) @@ -317,7 +317,7 @@ def test_filterable_metadata_query( metadata=collection.metadata, # type: ignore embedding_function=collection.embedding_function, ) - initial_version = coll.get_model()["version"] + initial_version = cast(int, coll.get_model()["version"]) normalized_record_set = invariants.wrap_all(record_set) coll.add(**record_set) # type: ignore[arg-type]