diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 657726d00e9..e7d073ec859 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -536,10 +536,53 @@ def validate_batch_size( ) +def validate_record_set_consistency(record_set: RecordSet) -> None: + """ + Validate the consistency of the record set, ensuring all values are non-empty lists and have the same length. + """ + error_messages = [] + field_record_counts = [] + count = 0 + consistentcy_error = False + + for field, value in record_set.items(): + if value is None: + continue + + if not isinstance(value, list): + error_messages.append( + f"Expected field {field} to be a list, got {type(value).__name__}" + ) + continue + + if len(value) == 0: + error_messages.append( + f"Expected field {field} to be a non-empty list, got an empty list" + ) + continue + + n_items = len(value) + field_record_counts.append(f"{field}: ({n_items})") + if count == 0: + count = n_items + elif count != n_items: + consistentcy_error = True + + if consistentcy_error: + error_messages.append( + f"Inconsistent number of records: {', '.join(field_record_counts)}" + ) + + if len(error_messages) > 0: + raise ValueError(", ".join(error_messages)) + + def get_n_items_from_record_set(record_set: RecordSet) -> Tuple[str, int]: """ Get the number of items in the record set. """ + validate_record_set_consistency(record_set) + for field, value in record_set.items(): if isinstance(value, list) and len(value) > 0: return field, len(value) @@ -548,6 +591,10 @@ def get_n_items_from_record_set(record_set: RecordSet) -> Tuple[str, int]: def validate_record_set(record_set: RecordSet) -> None: + """ + Validate the record set, ensuring all values within a record set are non-empty lists, have the same length and are valid. + """ + embeddings = record_set["embeddings"] ids = record_set["ids"] metadatas = record_set["metadatas"] @@ -556,21 +603,4 @@ def validate_record_set(record_set: RecordSet) -> None: validate_embeddings(embeddings) if embeddings is not None else None validate_metadatas(metadatas) if metadatas is not None else None - _, n_items = get_n_items_from_record_set(record_set) - if n_items == 0: - raise ValueError("No items in record set") - - should_error = False - field_record_counts = [] - - for field, value in record_set.items(): - if isinstance(value, list): - n = len(value) - field_record_counts.append(f"{field}: ({n})") - if n != n_items: - should_error = True - - if should_error: - raise ValueError( - "Inconsistent number of records: " + ", ".join(field_record_counts) - ) + validate_record_set_consistency(record_set) diff --git a/chromadb/test/api/test_validations.py b/chromadb/test/api/test_validations.py index c02ad354eb7..7fb4c0063dd 100644 --- a/chromadb/test/api/test_validations.py +++ b/chromadb/test/api/test_validations.py @@ -1,7 +1,14 @@ import pytest from typing import cast import chromadb.errors as errors -from chromadb.api.types import validate_embeddings, Embeddings, IDs, validate_ids +from chromadb.api.types import ( + validate_embeddings, + Embeddings, + IDs, + RecordSet, + validate_ids, + validate_record_set_consistency, +) def test_embeddings_validation() -> None: @@ -67,3 +74,59 @@ def test_ids_validation() -> None: ] * 2 with pytest.raises(errors.DuplicateIDError, match="found 15 duplicated IDs: "): validate_ids(ids) + + +def test_validate_record_set_consistency() -> None: + # Test record set with inconsistent lengths + inconsistent_record_set: RecordSet = { + "ids": ["1", "2"], + "embeddings": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + "metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}], + "documents": ["doc1", "doc2"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError, match="Inconsistent number of records:"): + validate_record_set_consistency(inconsistent_record_set) + + # Test record set with empty list + empty_list_record_set: RecordSet = { + "ids": ["1", "2", "3"], + "embeddings": [], + "metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}], + "documents": ["doc1", "doc2", "doc3"], + "images": None, + "uris": None, + } + with pytest.raises( + ValueError, match="Expected field embeddings to be a non-empty list" + ): + validate_record_set_consistency(empty_list_record_set) + + # Test record set with non-list value + non_list_record_set: RecordSet = { + "ids": ["1", "2", "3"], + "embeddings": "not a list", # type: ignore[typeddict-item] + "metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}], + "documents": ["doc1", "doc2", "doc3"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError, match="Expected field embeddings to be a list"): + validate_record_set_consistency(non_list_record_set) + + # Test record set with multiple errors + multiple_error_record_set: RecordSet = { + "ids": [], + "embeddings": "not a list", # type: ignore[typeddict-item] + "metadatas": [{"key": "value1"}, {"key": "value2"}], + "documents": ["doc1"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError) as exc_info: + validate_record_set_consistency(multiple_error_record_set) + + assert "Expected field ids to be a non-empty list" in str(exc_info.value) + assert "Expected field embeddings to be a list" in str(exc_info.value) + assert "Inconsistent number of records:" in str(exc_info.value)