Skip to content

Commit

Permalink
create a new func for ensuring record set consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 3, 2024
1 parent 7d3b5aa commit 9dd5ac7
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 19 deletions.
66 changes: 48 additions & 18 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,53 @@ def validate_batch_size(
)


def validate_record_set_consistency(record_set: RecordSet) -> None:
"""
Validate the consistency of the record set, ensuring all values are non-empty lists and have the same length.
"""
error_messages = []
field_record_counts = []
count = 0
consistentcy_error = False

for field, value in record_set.items():
if value is None:
continue

if not isinstance(value, list):
error_messages.append(
f"Expected field {field} to be a list, got {type(value).__name__}"
)
continue

if len(value) == 0:
error_messages.append(
f"Expected field {field} to be a non-empty list, got an empty list"
)
continue

n_items = len(value)
field_record_counts.append(f"{field}: ({n_items})")
if count == 0:
count = n_items
elif count != n_items:
consistentcy_error = True

if consistentcy_error:
error_messages.append(
f"Inconsistent number of records: {', '.join(field_record_counts)}"
)

if len(error_messages) > 0:
raise ValueError(", ".join(error_messages))


def get_n_items_from_record_set(record_set: RecordSet) -> Tuple[str, int]:
"""
Get the number of items in the record set.
"""
validate_record_set_consistency(record_set)

for field, value in record_set.items():
if isinstance(value, list) and len(value) > 0:
return field, len(value)
Expand All @@ -548,6 +591,10 @@ def get_n_items_from_record_set(record_set: RecordSet) -> Tuple[str, int]:


def validate_record_set(record_set: RecordSet) -> None:
"""
Validate the record set, ensuring all values within a record set are non-empty lists, have the same length and are valid.
"""

embeddings = record_set["embeddings"]
ids = record_set["ids"]
metadatas = record_set["metadatas"]
Expand All @@ -556,21 +603,4 @@ def validate_record_set(record_set: RecordSet) -> None:
validate_embeddings(embeddings) if embeddings is not None else None
validate_metadatas(metadatas) if metadatas is not None else None

_, n_items = get_n_items_from_record_set(record_set)
if n_items == 0:
raise ValueError("No items in record set")

should_error = False
field_record_counts = []

for field, value in record_set.items():
if isinstance(value, list):
n = len(value)
field_record_counts.append(f"{field}: ({n})")
if n != n_items:
should_error = True

if should_error:
raise ValueError(
"Inconsistent number of records: " + ", ".join(field_record_counts)
)
validate_record_set_consistency(record_set)
65 changes: 64 additions & 1 deletion chromadb/test/api/test_validations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import pytest
from typing import cast
import chromadb.errors as errors
from chromadb.api.types import validate_embeddings, Embeddings, IDs, validate_ids
from chromadb.api.types import (
validate_embeddings,
Embeddings,
IDs,
RecordSet,
validate_ids,
validate_record_set_consistency,
)


def test_embeddings_validation() -> None:
Expand Down Expand Up @@ -67,3 +74,59 @@ def test_ids_validation() -> None:
] * 2
with pytest.raises(errors.DuplicateIDError, match="found 15 duplicated IDs: "):
validate_ids(ids)


def test_validate_record_set_consistency() -> None:
# Test record set with inconsistent lengths
inconsistent_record_set: RecordSet = {
"ids": ["1", "2"],
"embeddings": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
"metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}],
"documents": ["doc1", "doc2"],
"images": None,
"uris": None,
}
with pytest.raises(ValueError, match="Inconsistent number of records:"):
validate_record_set_consistency(inconsistent_record_set)

# Test record set with empty list
empty_list_record_set: RecordSet = {
"ids": ["1", "2", "3"],
"embeddings": [],
"metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}],
"documents": ["doc1", "doc2", "doc3"],
"images": None,
"uris": None,
}
with pytest.raises(
ValueError, match="Expected field embeddings to be a non-empty list"
):
validate_record_set_consistency(empty_list_record_set)

# Test record set with non-list value
non_list_record_set: RecordSet = {
"ids": ["1", "2", "3"],
"embeddings": "not a list", # type: ignore[typeddict-item]
"metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}],
"documents": ["doc1", "doc2", "doc3"],
"images": None,
"uris": None,
}
with pytest.raises(ValueError, match="Expected field embeddings to be a list"):
validate_record_set_consistency(non_list_record_set)

# Test record set with multiple errors
multiple_error_record_set: RecordSet = {
"ids": [],
"embeddings": "not a list", # type: ignore[typeddict-item]
"metadatas": [{"key": "value1"}, {"key": "value2"}],
"documents": ["doc1"],
"images": None,
"uris": None,
}
with pytest.raises(ValueError) as exc_info:
validate_record_set_consistency(multiple_error_record_set)

assert "Expected field ids to be a non-empty list" in str(exc_info.value)
assert "Expected field embeddings to be a list" in str(exc_info.value)
assert "Inconsistent number of records:" in str(exc_info.value)

0 comments on commit 9dd5ac7

Please sign in to comment.