Skip to content

Commit

Permalink
update doc strings, error messages and ignore tags
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 10, 2024
1 parent f9d9c03 commit cccf640
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 26 deletions.
3 changes: 3 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,9 @@ def generate_ids_when_not_present(
return ids

(_, n) = get_n_items_from_record_set(record_set)
if n is None:
raise ValueError("Expected record set to have at least one item")

generated_ids: List[str] = [str(uuid4()) for _ in range(n)]

return generated_ids
Expand Down
6 changes: 4 additions & 2 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def maybe_cast_one_to_many_embedding(
if isinstance(embeddings_list, List):
# One Embedding
if len(embeddings_list) == 0:
raise ValueError("Expected embeddings to be a list with at least one item")
raise ValueError(
"Expected embeddings to be a list or a numpy array with at least one item"
)

if isinstance(embeddings_list[0], (int, float)):
return cast(Embeddings, [embeddings_list])
Expand Down Expand Up @@ -602,7 +604,7 @@ def validate_record_set_consistency(record_set: RecordSet) -> None:

def get_n_items_from_record_set(
record_set: RecordSet,
) -> Tuple[Union[str, None], Union[int, None]]:
) -> Tuple[Optional[str], Optional[int]]:
"""
Get the number of items in the record set.
"""
Expand Down
33 changes: 17 additions & 16 deletions chromadb/test/api/test_api_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ def test_add_with_no_ids(client: ClientAPI) -> None:

coll = client.create_collection("test")
coll.add(
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore
documents=["a", "b", None], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item]
documents=["a", "b", None], # type: ignore[list-item]
)

results = coll.get()
assert len(results["ids"]) == 3

coll.add(
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore
documents=["a", "b", None], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item]
documents=["a", "b", None], # type: ignore[list-item]
)

results = coll.get()
Expand All @@ -36,7 +36,7 @@ def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None:
with pytest.raises(ValueError, match="Inconsistent number of records"):
coll.add(
ids=["1", "2"],
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, {"a": 2}, {"a": 3}],
documents=["a", "b", "c"],
)
Expand All @@ -45,7 +45,7 @@ def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None:
with pytest.raises(ValueError, match="Inconsistent number of records"):
coll.add(
ids=["1", "2", "3"],
embeddings=[[1, 2, 3], [1, 2, 3]], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, {"a": 2}, {"a": 3}],
documents=["a", "b", "c"],
)
Expand All @@ -54,7 +54,7 @@ def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None:
with pytest.raises(ValueError, match="Inconsistent number of records"):
coll.add(
ids=["1", "2", "3"],
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, {"a": 2}],
documents=["a", "b", "c"],
)
Expand All @@ -63,7 +63,7 @@ def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None:
with pytest.raises(ValueError, match="Inconsistent number of records"):
coll.add(
ids=["1", "2", "3"],
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, {"a": 2}, {"a": 3}],
documents=["a", "b"],
)
Expand All @@ -72,7 +72,7 @@ def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None:
with pytest.raises(ValueError, match="Inconsistent number of records"):
coll.add(
ids=["1", "2"],
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}],
documents=["a", "b", "c", "d"],
)
Expand All @@ -85,10 +85,10 @@ def test_add_with_partial_ids(client: ClientAPI) -> None:

with pytest.raises(ValueError, match="Expected ID to be a str"):
coll.add(
ids=["1", None], # type: ignore
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore
documents=["a", "b", None], # type: ignore
ids=["1", None], # type: ignore[list-item]
embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type]
metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item]
documents=["a", "b", None], # type: ignore[list-item]
)


Expand All @@ -98,7 +98,8 @@ def test_add_with_no_data(client: ClientAPI) -> None:
coll = client.create_collection("test")

with pytest.raises(
Exception, match="Expected embeddings to be a list with at least one item"
Exception,
match="Expected embeddings to be a list or a numpy array with at least one item",
):
coll.add(
ids=["1"],
Expand Down
14 changes: 7 additions & 7 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _field_matches(
n: int,
) -> None:
"""
The actual embedding field is equal to the expected field
The actual record field is equal to the expected field
field_name: one of [documents, metadatas]
"""
# If there are no ids, then there are no data to test
Expand Down Expand Up @@ -190,7 +190,7 @@ def ids_match(collection: Collection, record_set: RecordSet) -> None:


def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
"""The actual embedding metadata is equal to the expected metadata"""
"""The actual record set metadatas are equal to the expected metadatas"""
normalized_record_set = wrap_all(record_set)

_field_matches(
Expand All @@ -204,7 +204,7 @@ def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
def metadatas_match_state_record_set(
collection: Collection, record_set: StateMachineRecordSet
) -> None:
"""The actual embedding metadata is equal to the expected metadata"""
"""The actual metadatas within the state record set are equal to the expected metadata"""
normalized_record_set = wrap_all(cast(RecordSet, record_set))

_field_matches(
Expand All @@ -216,7 +216,7 @@ def metadatas_match_state_record_set(


def documents_match(collection: Collection, record_set: RecordSet) -> None:
"""The actual embedding documents is equal to the expected documents"""
"""The actual record set documents are equal to the expected documents"""
normalized_record_set = wrap_all(record_set)
_field_matches(
collection,
Expand All @@ -229,7 +229,7 @@ def documents_match(collection: Collection, record_set: RecordSet) -> None:
def documents_match_state_record_set(
collection: Collection, record_set: StateMachineRecordSet
) -> None:
"""The actual embedding documents is equal to the expected metadata"""
"""The actual documents within the state record set are equal to the expected documents"""
normalized_record_set = wrap_all(cast(RecordSet, record_set))

_field_matches(
Expand All @@ -241,7 +241,7 @@ def documents_match_state_record_set(


def embeddings_match(collection: Collection, record_set: RecordSet) -> None:
"""The actual embedding is equal to the expected documents"""
"""The actual record set embeddings are equal to the expected embeddings"""
normalized_record_set = wrap_all(record_set)
_field_matches(
collection,
Expand All @@ -254,7 +254,7 @@ def embeddings_match(collection: Collection, record_set: RecordSet) -> None:
def embeddings_match_state_record_set(
collection: Collection, record_set: StateMachineRecordSet
) -> None:
"""The actual embedding is equal to the expected metadata"""
"""The actual embeddings within the state record set are equal to the expected embeddings"""
normalized_record_set = wrap_all(cast(RecordSet, record_set))

_field_matches(
Expand Down
2 changes: 1 addition & 1 deletion chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
if len(intersection) > 0:
# Partially apply the non-duplicative records to the state
new_ids = list(set(ids).difference(intersection)) # type: ignore[arg-type]
indices = [ids.index(id) for id in new_ids] # type: ignore
indices = [ids.index(id) for id in new_ids] # type: ignore[union-attr]
filtered_record_set: strategies.NormalizedRecordSet = {
"ids": [ids[i] for i in indices], # type: ignore
"metadatas": [normalized_record_set["metadatas"][i] for i in indices]
Expand Down

0 comments on commit cccf640

Please sign in to comment.