Skip to content

Commit

Permalink
move validation logic to server
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 23, 2024
1 parent 44cdda2 commit 88a12a9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 92 deletions.
63 changes: 0 additions & 63 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Include,
Loadable,
Metadata,
Metadatas,
Document,
Documents,
Image,
Expand All @@ -45,7 +44,6 @@
validate_ids,
validate_include,
validate_metadata,
validate_metadatas,
validate_embeddings,
validate_embedding_function,
validate_n_results,
Expand Down Expand Up @@ -180,65 +178,13 @@ def _unpack_embedding_set(

def _validate_embedding_set(
self,
ids: IDs,
embeddings: Optional[Embeddings],
metadatas: Optional[Metadatas],
documents: Optional[Documents],
images: Optional[Images],
uris: Optional[URIs],
require_embeddings_or_data: bool = True,
) -> None:
valid_ids = validate_ids(ids)
valid_embeddings = (
validate_embeddings(embeddings) if embeddings is not None else None
)
valid_metadatas = (
validate_metadatas(metadatas) if metadatas is not None else None
)

# Already validated from being unpacked from OneOrMany data types
valid_documents = documents
valid_images = images
valid_uris = uris

        # Check that one of embeddings or documents or images is provided
if require_embeddings_or_data:
if (
valid_embeddings is None
and valid_documents is None
and valid_images is None
and valid_uris is None
):
raise ValueError(
"You must provide embeddings, documents, images, or uris."
)

# Only one of documents or images can be provided
if documents is not None and images is not None:
raise ValueError("You can only provide documents or images, not both.")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
raise ValueError(
f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
)
if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
raise ValueError(
f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
)
if documents is not None and len(documents) != len(valid_ids):
raise ValueError(
f"Number of documents {len(documents)} must match number of ids {len(valid_ids)}"
)
if images is not None and len(images) != len(valid_ids):
raise ValueError(
f"Number of images {len(images)} must match number of ids {len(valid_ids)}"
)
if uris is not None and len(uris) != len(valid_ids):
raise ValueError(
f"Number of uris {len(uris)} must match number of ids {len(valid_ids)}"
)

def _compute_embeddings(
self,
embeddings: Optional[Embeddings],
Expand Down Expand Up @@ -453,12 +399,8 @@ def _process_add_request(
)

self._validate_embedding_set(
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
)

prepared_embeddings = self._compute_embeddings(
Expand Down Expand Up @@ -517,13 +459,8 @@ def _process_upsert_or_update_request(
)

self._validate_embedding_set(
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
require_embeddings_or_data,
)

prepared_embeddings = self._compute_embeddings_upsert_or_update_request(
Expand Down
109 changes: 89 additions & 20 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
validate_where,
validate_where_document,
validate_batch,
validate_ids,
validate_embeddings,
validate_metadatas,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
Expand All @@ -48,7 +51,7 @@
)

import chromadb.types as t
from typing import Optional, Sequence, Generator, List, cast, Set, Dict
from typing import Optional, Sequence, Generator, List, cast, Set, Dict, Tuple, Any
from overrides import override
from uuid import UUID, uuid4
import time
Expand Down Expand Up @@ -340,10 +343,17 @@ def _add(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=True,
)

records_to_submit = list(
_records(
t.Operation.ADD,
Expand All @@ -354,7 +364,6 @@ def _add(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -382,10 +391,17 @@ def _update(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=False,
)

records_to_submit = list(
_records(
t.Operation.UPDATE,
Expand All @@ -396,7 +412,6 @@ def _update(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -426,10 +441,17 @@ def _upsert(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=True,
)

records_to_submit = list(
_records(
t.Operation.UPSERT,
Expand All @@ -440,7 +462,6 @@ def _upsert(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

return True
Expand Down Expand Up @@ -606,10 +627,12 @@ def _delete(
if len(ids_to_delete) == 0:
return []

self._validate_embedding_record_set(
collection=coll, ids=ids_to_delete, require_embeddings_or_data=False
)
records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -799,16 +822,62 @@ def get_max_batch_size(self) -> int:
"SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
)
def _validate_embedding_record_set(
self, collection: t.Collection, records: List[t.OperationRecord]
self,
ids: IDs,
collection: t.Collection,
require_embeddings_or_data: bool,
embeddings: Optional[Embeddings] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
metadatas: Optional[Metadatas] = None,
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},
)

add_attributes_to_current_span({"collection_id": str(collection["id"])})
for record in records:
if record["embedding"]:
self._validate_dimension(
collection, len(record["embedding"]), update=True

validate_ids(ids)
validate_embeddings(embeddings) if embeddings is not None else None
validate_metadatas(metadatas) if metadatas is not None else None

if (
require_embeddings_or_data
and embeddings is None
and documents is None
and uris is None
):
raise ValueError("You must provide embeddings, documents, or uris.")

entities: List[Tuple[Any, str]] = [
(embeddings, "embeddings"),
(metadatas, "metadatas"),
(documents, "documents"),
(uris, "uris"),
]

n_ids = len(ids)
for entity in entities:
if entity[0] is None:
continue

name = entity[1]
n = len(entity[0])

if n != len(ids):
raise ValueError(
f"Number of {name} ({n}) does not match number of ids ({n_ids})"
)

if embeddings is None:
return

"""Validate the dimension of an embedding record before submitting it to the system."""
for embedding in embeddings:
if embedding:
self._validate_dimension(collection, len(embedding), update=True)

# This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings.
def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
Expand Down
18 changes: 9 additions & 9 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,15 +706,15 @@ def test_metadata_update_get_int_float(client):
def test_metadata_validation_add(client):
client.reset()
collection = client.create_collection("test_metadata_validation")
with pytest.raises(ValueError, match="metadata"):
with pytest.raises(Exception, match="metadata"):
collection.add(**bad_metadata_records)


def test_metadata_validation_update(client):
client.reset()
collection = client.create_collection("test_metadata_validation")
collection.add(**metadata_records)
with pytest.raises(ValueError, match="metadata"):
with pytest.raises(Exception, match="metadata"):
collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}})


Expand Down Expand Up @@ -1203,17 +1203,17 @@ def test_invalid_id(client):
client.reset()
collection = client.create_collection("test_invalid_id")
# Add with non-string id
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}])
assert "ID" in str(e.value)

# Get with non-list id
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.get(ids=1)
assert "ID" in str(e.value)

# Delete with malformed ids
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.delete(ids=["valid", 0])
assert "ID" in str(e.value)

Expand Down Expand Up @@ -1548,12 +1548,12 @@ def test_invalid_embeddings(client):
"embeddings": [["0", "0", "0"], ["1.2", "2.24", "3.2"]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.add(**invalid_records)
assert "embedding" in str(e.value)

# Query with invalid embeddings
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.query(
query_embeddings=[["1.1", "2.3", "3.2"]],
n_results=1,
Expand All @@ -1565,7 +1565,7 @@ def test_invalid_embeddings(client):
"embeddings": [[[0], [0], [0]], [[1.2], [2.24], [3.2]]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.update(**invalid_records)
assert "embedding" in str(e.value)

Expand All @@ -1574,7 +1574,7 @@ def test_invalid_embeddings(client):
"embeddings": [[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
with pytest.raises(Exception) as e:
collection.upsert(**invalid_records)
assert "embedding" in str(e.value)

Expand Down

0 comments on commit 88a12a9

Please sign in to comment.