Skip to content

Commit

Permalink
move validation logic to server
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 22, 2024
1 parent c45e798 commit 5fa105d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 88 deletions.
4 changes: 2 additions & 2 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: Optional[OneOrMany[ID]] = None,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -45,7 +45,7 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> AddResult:
) -> None:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down
3 changes: 1 addition & 2 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Include,
Metadata,
Document,
AddResult,
Image,
Where,
IDs,
Expand Down Expand Up @@ -52,7 +51,7 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> AddResult:
) -> None:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down
63 changes: 0 additions & 63 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Include,
Loadable,
Metadata,
Metadatas,
Document,
Documents,
Image,
Expand All @@ -45,7 +44,6 @@
validate_ids,
validate_include,
validate_metadata,
validate_metadatas,
validate_embeddings,
validate_embedding_function,
validate_n_results,
Expand Down Expand Up @@ -180,65 +178,14 @@ def _unpack_embedding_set(

def _validate_embedding_set(
    self,
    ids: IDs,
    embeddings: Optional[Embeddings],
    metadatas: Optional[Metadatas],
    documents: Optional[Documents],
    images: Optional[Images],
    uris: Optional[URIs],
    require_embeddings_or_data: bool = True,
) -> None:
    """Validate an embedding set before it is handed to the backend.

    Runs the field-level validators over ids/embeddings/metadatas, optionally
    enforces that at least one data-bearing field is present, rejects the
    documents+images combination, and checks that every provided field has
    exactly one entry per id.

    Raises:
        ValueError: if any validator rejects its input, if no data field is
            supplied while ``require_embeddings_or_data`` is True, if both
            documents and images are given, or if any field's length differs
            from the number of ids.
    """
    checked_ids = validate_ids(ids)
    checked_embeddings = (
        None if embeddings is None else validate_embeddings(embeddings)
    )
    checked_metadatas = (
        None if metadatas is None else validate_metadatas(metadatas)
    )

    # documents / images / uris were already validated while being unpacked
    # from their OneOrMany wrappers, so they pass through untouched.

    # At least one of embeddings, documents, images, or uris must be present
    # when the caller demands actual data (e.g. for an add).
    if require_embeddings_or_data and all(
        field is None
        for field in (checked_embeddings, documents, images, uris)
    ):
        raise ValueError(
            "You must provide embeddings, documents, images, or uris."
        )

    # Documents and images are mutually exclusive inputs.
    if documents is not None and images is not None:
        raise ValueError("You can only provide documents or images, not both.")

    # Every provided field must line up one-to-one with the ids.
    expected = len(checked_ids)
    for name, values in (
        ("embeddings", checked_embeddings),
        ("metadatas", checked_metadatas),
        ("documents", documents),
        ("images", images),
        ("uris", uris),
    ):
        if values is not None and len(values) != expected:
            raise ValueError(
                f"Number of {name} {len(values)} must match number of ids {expected}"
            )

def _compute_embeddings(
self,
embeddings: Optional[Embeddings],
Expand Down Expand Up @@ -455,13 +402,8 @@ def _process_add_request(
)

self._validate_embedding_set(
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
require_embeddings_or_data=False,
)

prepared_embeddings = self._compute_embeddings(
Expand Down Expand Up @@ -519,13 +461,8 @@ def _process_upsert_or_update_request(
)

self._validate_embedding_set(
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
require_embeddings_or_data=False,
)

prepared_embeddings = self._compute_embeddings_upsert_or_update_request(
Expand Down
106 changes: 85 additions & 21 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
validate_where,
validate_where_document,
validate_batch,
validate_ids,
validate_embeddings,
validate_metadatas,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
Expand All @@ -48,7 +51,7 @@
)

import chromadb.types as t
from typing import Optional, Sequence, Generator, List, cast, Set, Dict
from typing import Optional, Sequence, Generator, List, cast, Set, Dict, Tuple, Any
from overrides import override
from uuid import UUID, uuid4
import time
Expand Down Expand Up @@ -340,10 +343,17 @@ def _add(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=True,
)

records_to_submit = list(
_records(
t.Operation.ADD,
Expand All @@ -354,7 +364,6 @@ def _add(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -382,10 +391,17 @@ def _update(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=False,
)

records_to_submit = list(
_records(
t.Operation.UPDATE,
Expand All @@ -396,7 +412,6 @@ def _update(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -426,10 +441,17 @@ def _upsert(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_embedding_record_set(
collection=coll,
ids=ids,
embeddings=embeddings,
documents=documents,
uris=uris,
metadatas=metadatas,
require_embeddings_or_data=True,
)

records_to_submit = list(
_records(
t.Operation.UPSERT,
Expand All @@ -440,7 +462,6 @@ def _upsert(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

return True
Expand Down Expand Up @@ -591,7 +612,6 @@ def _delete(
"""
)

coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
Expand All @@ -609,7 +629,6 @@ def _delete(
records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -799,16 +818,61 @@ def get_max_batch_size(self) -> int:
"SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
)
def _validate_embedding_record_set(
self, collection: t.Collection, records: List[t.OperationRecord]
self,
ids: IDs,
collection: t.Collection,
require_embeddings_or_data: bool,
embeddings: Optional[Embeddings] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
metadatas: Optional[Metadatas] = None,
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},
)

add_attributes_to_current_span({"collection_id": str(collection["id"])})
for record in records:
if record["embedding"]:
self._validate_dimension(
collection, len(record["embedding"]), update=True

validate_ids(ids)
validate_embeddings(embeddings) if embeddings is not None else None
validate_metadatas(metadatas) if metadatas is not None else None

if (
require_embeddings_or_data
and embeddings is None
and documents is None
and uris is None
):
raise ValueError("You must provide embeddings, documents, or uris.")

entities: List[Tuple[Any, str]] = [
(embeddings, "embeddings"),
(metadatas, "metadatas"),
(documents, "documents"),
(uris, "uris"),
]

n_ids = len(ids)
for entity in entities:
if entity[0] is None:
continue

name = entity[1]
n = len(entity[0])

if n != len(ids):
raise ValueError(
f"Number of {name} ({n}) does not match number of ids ({n_ids})"
)

if embeddings is None:
return

"""Validate the dimension of an embedding record before submitting it to the system."""
for embedding in embeddings:
self._validate_dimension(collection, len(embedding), update=True)

# This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings.
def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
Expand Down

0 comments on commit 5fa105d

Please sign in to comment.