Skip to content

Commit

Permalink
paying for the sins of our fathers
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma authored and atroyn committed Sep 12, 2024
1 parent 261abd5 commit 8253eed
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 226 deletions.
253 changes: 64 additions & 189 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Include,
Loadable,
Metadata,
Metadatas,
Document,
Documents,
Image,
Expand All @@ -47,7 +46,7 @@
validate_n_results,
validate_where,
validate_where_document,
does_record_set_contain_any_data,
record_set_contains_one_of,
)

# TODO: We should rename the types in chromadb.types to be Models where
Expand Down Expand Up @@ -146,105 +145,54 @@ def __repr__(self) -> str:
def get_model(self) -> CollectionModel:
return self._model

@staticmethod
def _unpack_record_set(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
embeddings: Optional[Union[OneOrMany[Embedding], OneOrMany[np.ndarray]]] = None, # type: ignore[type-arg]
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
unpacked_ids = maybe_cast_one_to_many(ids)
unpacked_embeddings = maybe_cast_one_to_many_embedding(embeddings)
unpacked_metadatas = maybe_cast_one_to_many(metadatas)
unpacked_documents = maybe_cast_one_to_many(documents)
unpacked_images = maybe_cast_one_to_many(images)
unpacked_uris = maybe_cast_one_to_many(uris)
return {
"ids": cast(IDs, unpacked_ids),
"embeddings": unpacked_embeddings,
"metadatas": unpacked_metadatas,
"documents": unpacked_documents,
"images": unpacked_images,
"uris": unpacked_uris,
"ids": cast(IDs, maybe_cast_one_to_many(ids)),
"embeddings": maybe_cast_one_to_many_embedding(embeddings),
"metadatas": maybe_cast_one_to_many(metadatas),
"documents": maybe_cast_one_to_many(documents),
"images": maybe_cast_one_to_many(images),
"uris": maybe_cast_one_to_many(uris),
}

@staticmethod
def _validate_record_set(
self,
ids: IDs,
embeddings: Optional[Embeddings],
metadatas: Optional[Metadatas],
documents: Optional[Documents],
images: Optional[Images],
uris: Optional[URIs],
require_embeddings_or_data: bool = True,
record_set: RecordSet,
require_data: bool,
) -> None:
valid_ids = validate_ids(ids)
valid_embeddings = (
validate_embeddings(embeddings) if embeddings is not None else None
)
valid_metadatas = (
validate_metadatas(metadatas) if metadatas is not None else None
)

# No additional validation needed for documents, images, or uris
valid_documents = documents
valid_images = images
valid_uris = uris

# Check that one of embeddings or ducuments or images is provided
if require_embeddings_or_data:
if (
valid_embeddings is None
and valid_documents is None
and valid_images is None
and valid_uris is None
):
raise ValueError(
"You must provide embeddings, documents, images, or uris."
)
else:
# will replace this with does_record_set_contain_any_data in the following PR
if (
valid_embeddings is None
and valid_documents is None
and valid_images is None
and valid_uris is None
and valid_metadatas is None
):
raise ValueError("You must provide either data or metadatas.")
validate_ids(record_set["ids"])
validate_embeddings(record_set["embeddings"]) if record_set[
"embeddings"
] is not None else None
validate_metadatas(record_set["metadatas"]) if record_set[
"metadatas"
] is not None else None

# Only one of documents or images can be provided
if valid_documents is not None and valid_images is not None:
if record_set["documents"] is not None and record_set["images"] is not None:
raise ValueError("You can only provide documents or images, not both.")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
raise ValueError(
f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
)
if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
raise ValueError(
f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
)
if valid_documents is not None and len(valid_documents) != len(valid_ids):
raise ValueError(
f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
)
if valid_images is not None and len(valid_images) != len(valid_ids):
raise ValueError(
f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
)
if valid_uris is not None and len(valid_uris) != len(valid_ids):
raise ValueError(
f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
)
required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item]
if not require_data:
required_fields += ["metadatas"] # type: ignore[list-item]

if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {required_fields}")

valid_ids = record_set["ids"]
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)

def _compute_embeddings(
self,
Expand Down Expand Up @@ -362,14 +310,7 @@ def _validate_and_prepare_query_request(
valid_include = validate_include(include, allow_distances=True)
valid_n_results = validate_n_results(n_results)

embeddings_to_normalize = maybe_cast_one_to_many_embedding(query_embeddings)
normalized_embeddings = (
self._normalize_embeddings(embeddings_to_normalize)
if embeddings_to_normalize is not None
else None
)

valid_query_embeddings = None
normalized_embeddings = maybe_cast_one_to_many_embedding(query_embeddings)
if normalized_embeddings is not None:
valid_query_embeddings = validate_embeddings(normalized_embeddings)
else:
Expand Down Expand Up @@ -436,7 +377,7 @@ def _process_add_request(
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
unpacked_record_set = self._unpack_record_set(
record_set = self._unpack_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
Expand All @@ -445,39 +386,19 @@ def _process_add_request(
uris=uris,
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_record_set["embeddings"])
if unpacked_record_set["embeddings"] is not None
else None
)

self._validate_record_set(
ids=unpacked_record_set["ids"],
embeddings=normalized_embeddings,
metadatas=unpacked_record_set["metadatas"],
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
uris=unpacked_record_set["uris"],
record_set,
require_data=True,
)

prepared_embeddings = (
self._compute_embeddings(
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
uris=unpacked_record_set["uris"],
if record_set["embeddings"] is None:
record_set["embeddings"] = self._compute_embeddings(
documents=record_set["documents"],
images=record_set["images"],
uris=record_set["uris"],
)
if normalized_embeddings is None
else normalized_embeddings
)

return {
"ids": unpacked_record_set["ids"],
"embeddings": prepared_embeddings,
"metadatas": unpacked_record_set["metadatas"],
"documents": unpacked_record_set["documents"],
"images": unpacked_record_set["images"],
"uris": unpacked_record_set["uris"],
}
return record_set

def _process_upsert_request(
self,
Expand All @@ -493,7 +414,7 @@ def _process_upsert_request(
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
unpacked_record_set = self._unpack_record_set(
record_set = self._unpack_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
Expand All @@ -502,37 +423,20 @@ def _process_upsert_request(
uris=uris,
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_record_set["embeddings"])
if unpacked_record_set["embeddings"] is not None
else None
)

self._validate_record_set(
ids=unpacked_record_set["ids"],
embeddings=normalized_embeddings,
metadatas=unpacked_record_set["metadatas"],
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
uris=unpacked_record_set["uris"],
record_set,
require_data=True,
)

prepared_embeddings = normalized_embeddings
if prepared_embeddings is None:
prepared_embeddings = self._compute_embeddings(
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
# TODO: Correctly handle Upsert for URIs
if record_set["embeddings"] is None:
record_set["embeddings"] = self._compute_embeddings(
documents=record_set["documents"],
images=record_set["images"],
uris=None,
)

return {
"ids": unpacked_record_set["ids"],
"embeddings": prepared_embeddings,
"metadatas": unpacked_record_set["metadatas"],
"documents": unpacked_record_set["documents"],
"images": unpacked_record_set["images"],
"uris": unpacked_record_set["uris"],
}
return record_set

def _process_update_request(
self,
Expand All @@ -548,7 +452,7 @@ def _process_update_request(
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
unpacked_record_set = self._unpack_record_set(
record_set = self._unpack_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
Expand All @@ -557,40 +461,22 @@ def _process_update_request(
uris=uris,
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_record_set["embeddings"])
if unpacked_record_set["embeddings"] is not None
else None
)

self._validate_record_set(
ids=unpacked_record_set["ids"],
embeddings=normalized_embeddings,
metadatas=unpacked_record_set["metadatas"],
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
uris=unpacked_record_set["uris"],
require_embeddings_or_data=False,
record_set,
require_data=False,
)

prepared_embeddings = normalized_embeddings
if prepared_embeddings is None and does_record_set_contain_any_data(
unpacked_record_set, include=["documents", "images"]
# TODO: Correctly handle Update for URIs
if record_set["embeddings"] is None and record_set_contains_one_of(
record_set, include=["documents", "images"] # type: ignore[list-item]
):
prepared_embeddings = self._compute_embeddings(
documents=unpacked_record_set["documents"],
images=unpacked_record_set["images"],
record_set["embeddings"] = self._compute_embeddings(
documents=record_set["documents"],
images=record_set["images"],
uris=None,
)

return {
"ids": unpacked_record_set["ids"],
"embeddings": prepared_embeddings,
"metadatas": unpacked_record_set["metadatas"],
"documents": unpacked_record_set["documents"],
"images": unpacked_record_set["images"],
"uris": unpacked_record_set["uris"],
}
return record_set

def _validate_and_prepare_delete_request(
self,
Expand All @@ -606,17 +492,6 @@ def _validate_and_prepare_delete_request(

return (ids, where, where_document)

@staticmethod
def _normalize_embeddings(
embeddings: Union[ # type: ignore[type-arg]
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
) -> Embeddings:
if isinstance(embeddings, np.ndarray):
return embeddings.tolist() # type: ignore
return embeddings # type: ignore

def _embed(self, input: Any) -> Embeddings:
if self._embedding_function is None:
raise ValueError(
Expand Down
Loading

0 comments on commit 8253eed

Please sign in to comment.