Skip to content

Commit

Permalink
fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 27, 2024
1 parent fe2f2b7 commit 93e10db
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion chromadb/test/api/test_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chromadb.api.types import IDs, validate_ids


def test_ids_validation():
def test_ids_validation() -> None:
ids = ["id1", "id2", "id3"]
assert validate_ids(ids) == ids

Expand Down
17 changes: 11 additions & 6 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
Expand Down Expand Up @@ -195,7 +195,7 @@ def create_embeddings_ndarray(
dim: int,
count: int,
dtype: npt.DTypeLike,
) -> np.typing.NDArray[Any]:
) -> npt.NDArray[Any]:
return np.random.uniform(
low=-1.0,
high=1.0,
Expand Down Expand Up @@ -377,7 +377,10 @@ def collections(

@st.composite
def metadata(
draw: st.DrawFn, collection: Collection, min_size=0, max_size=None
draw: st.DrawFn,
collection: Collection,
min_size: int = 0,
max_size: Optional[int] = None,
) -> Optional[types.Metadata]:
"""Strategy for generating metadata that could be a part of the given collection"""
# First draw a random dictionary.
Expand Down Expand Up @@ -409,6 +412,8 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document:
# For cluster tests, we want to avoid generating documents of length < 3.
# We also don't want them to contain certan special
# characters like _ and % that implicitly involve searching for a regex in sqlite.

blacklist_categories: Tuple[str, ...] = ()
if not NOT_CLUSTER_ONLY:
# Blacklist certain unicode characters that affect sqlite processing.
# For example, the null (/x00) character makes sqlite stop processing a string.
Expand Down Expand Up @@ -563,7 +568,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
if not NOT_CLUSTER_ONLY:
legal_ops: List[Optional[str]] = [None, "$eq"]
else:
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
legal_ops = [None, "$eq", "$ne", "$in", "$nin"]

if not isinstance(value, str) and not isinstance(value, bool):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
Expand Down Expand Up @@ -615,10 +620,10 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu
else:
op = draw(st.sampled_from(["$contains", "$not_contains"]))

if op == "$contains":
if op == "$contains": # type: ignore[comparison-overlap]
return {"$contains": word}
else:
assert op == "$not_contains"
assert op == "$not_contains" # type: ignore[comparison-overlap]
return {"$not_contains": word}


Expand Down
10 changes: 5 additions & 5 deletions chromadb/test/property/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _test_add(
metadata=collection.metadata, # type: ignore
embedding_function=collection.embedding_function,
)
initial_version = coll.get_model()["version"]
initial_version = cast(int, coll.get_model()["version"])

normalized_record_set = invariants.wrap_all(record_set)

Expand Down Expand Up @@ -182,7 +182,7 @@ def test_add_large(
embedding_function=collection.embedding_function,
)
normalized_record_set = invariants.wrap_all(record_set)
initial_version = coll.get_model()["version"]
initial_version = cast(int, coll.get_model()["version"])

for batch in create_batches(
api=client,
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_add_with_no_data(client: ClientAPI) -> None:
):
coll.add(
ids=["1"],
embeddings=[], # type: ignore
metadatas=[{"a": 1}], # type: ignore
documents=[], # type: ignore
embeddings=[],
metadatas=[{"a": 1}],
documents=[],
)
4 changes: 2 additions & 2 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_filterable_metadata_get(
embedding_function=collection.embedding_function,
)

initial_version = coll.get_model()["version"]
initial_version = cast(int, coll.get_model()["version"])

coll.add(**record_set)

Expand Down Expand Up @@ -317,7 +317,7 @@ def test_filterable_metadata_query(
metadata=collection.metadata, # type: ignore
embedding_function=collection.embedding_function,
)
initial_version = coll.get_model()["version"]
initial_version = cast(int, coll.get_model()["version"])
normalized_record_set = invariants.wrap_all(record_set)

coll.add(**record_set) # type: ignore[arg-type]
Expand Down

0 comments on commit 93e10db

Please sign in to comment.