Skip to content

Commit

Permalink
remove normalize embeddings func
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 3, 2024
1 parent c2315b9 commit df9c88d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 43 deletions.
26 changes: 4 additions & 22 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,7 @@ def _validate_and_prepare_query_request(
)
valid_query_embeddings = (
validate_embeddings(
self._normalize_embeddings(
cast(Embeddings, maybe_cast_one_to_many_embedding(query_embeddings))
)
cast(Embeddings, maybe_cast_one_to_many_embedding(query_embeddings))
)
if query_embeddings is not None
else None
Expand Down Expand Up @@ -389,14 +387,6 @@ def _process_add_request(
uris=uris,
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_record_set["embeddings"])
if unpacked_record_set["embeddings"] is not None
else None
)

unpacked_record_set["embeddings"] = normalized_embeddings

self._validate_record_set(
record_set=unpacked_record_set,
require_data=True,
Expand All @@ -408,8 +398,8 @@ def _process_add_request(
images=unpacked_record_set["images"],
uris=unpacked_record_set["uris"],
)
if normalized_embeddings is None
else normalized_embeddings
if unpacked_record_set["embeddings"] is None
else unpacked_record_set["embeddings"]
)

unpacked_record_set["embeddings"] = prepared_embeddings
Expand Down Expand Up @@ -440,20 +430,12 @@ def _process_upsert_or_update_request(
uris=uris,
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_record_set["embeddings"])
if unpacked_record_set["embeddings"] is not None
else None
)

unpacked_record_set["embeddings"] = normalized_embeddings

self._validate_record_set(
record_set=unpacked_record_set,
require_data=require_data,
)

prepared_embeddings = normalized_embeddings
prepared_embeddings = unpacked_record_set["embeddings"]
try:
prepared_embeddings = (
self._compute_embeddings(
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def maybe_cast_one_to_many_embedding(
if isinstance(target[0], (int, float)):
return cast(Embeddings, [target])
# Already a sequence

if isinstance(target, np.ndarray):
return cast(Embeddings, target.tolist())

return cast(Embeddings, target)


Expand Down
30 changes: 30 additions & 0 deletions chromadb/test/api/test_validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from chromadb.api.types import validate_embeddings, Embeddings


def test_embeddings_validation() -> None:
    """Check that ``validate_embeddings`` rejects malformed input with ``ValueError``.

    Three invalid inputs are exercised:
      * an embedding containing a ``bool`` value,
      * an embedding containing a ``str`` value,
      * a bare ``str`` passed in place of the embeddings list.
    """
    invalid_embeddings = [[0, 0, True], [1.2, 2.24, 3.2]]

    with pytest.raises(ValueError) as e:
        validate_embeddings(invalid_embeddings)  # type: ignore[arg-type]

    # Compare against the raised exception (``e.value``): ``str(e)`` would
    # stringify pytest's ExceptionInfo wrapper, not the error message itself.
    assert "Expected each value in the embedding to be a int or float" in str(
        e.value
    )

    invalid_embeddings = [[0, 0, "invalid"], [1.2, 2.24, 3.2]]

    with pytest.raises(ValueError) as e:
        validate_embeddings(invalid_embeddings)  # type: ignore[arg-type]

    assert "Expected each value in the embedding to be a int or float" in str(
        e.value
    )

    with pytest.raises(ValueError) as e:
        validate_embeddings("invalid")  # type: ignore[arg-type]

    assert "Expected embeddings to be a list, got str" in str(e.value)


def test_0dim_embedding_validation() -> None:
    """Check that ``validate_embeddings`` raises ``ValueError`` for an empty embedding."""
    embds: Embeddings = [[]]
    with pytest.raises(ValueError) as e:
        validate_embeddings(embds)

    # Compare against the raised exception (``e.value``): ``str(e)`` would
    # stringify pytest's ExceptionInfo wrapper, not the error message itself.
    assert "Expected each embedding in the embeddings to be a non-empty list" in str(
        e.value
    )
30 changes: 9 additions & 21 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from hypothesis import given, settings, HealthCheck
from typing import Dict, Set, cast, Union, DefaultDict, Any, List
from dataclasses import dataclass
from chromadb.api.types import ID, Embeddings, Include, IDs, validate_embeddings
from chromadb.api.types import (
ID,
Include,
IDs,
validate_embeddings,
maybe_cast_one_to_many_embedding,
)
from chromadb.config import System
import chromadb.errors as errors
from chromadb.api import ClientAPI
Expand Down Expand Up @@ -796,7 +802,7 @@ def test_autocasting_validate_embeddings_for_compatible_types(
supported_types: List[Any],
) -> None:
embds = strategies.create_embeddings(10, 10, supported_types)
validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds))
validated_embeddings = validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type]
assert all(
[
isinstance(value, list)
Expand All @@ -816,7 +822,7 @@ def test_autocasting_validate_embeddings_with_ndarray(
supported_types: List[Any],
) -> None:
embds = strategies.create_embeddings_ndarray(10, 10, supported_types)
validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds))
validated_embeddings = validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type]
assert all(
[
isinstance(value, list)
Expand All @@ -829,21 +835,3 @@ def test_autocasting_validate_embeddings_with_ndarray(
for value in validated_embeddings
]
)


@given(unsupported_types=st.sampled_from([str, bool]))
def test_autocasting_validate_embeddings_incompatible_types(
unsupported_types: List[Any],
) -> None:
embds = strategies.create_embeddings(10, 10, unsupported_types)
with pytest.raises(ValueError) as e:
validate_embeddings(Collection._normalize_embeddings(embds))

assert "Expected each value in the embedding to be a int or float" in str(e)


def test_0dim_embedding_validation() -> None:
embds: Embeddings = [[]]
with pytest.raises(ValueError) as e:
validate_embeddings(embds)
assert "Expected each embedding in the embeddings to be a non-empty list" in str(e)

0 comments on commit df9c88d

Please sign in to comment.