Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 23, 2024
1 parent beb307f commit 28ef4e0
Show file tree
Hide file tree
Showing 25 changed files with 351 additions and 79 deletions.
3 changes: 2 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
URIs,
Where,
QueryResult,
AddResult,
GetResult,
WhereDocument,
)
Expand Down Expand Up @@ -121,7 +122,7 @@ def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.
Expand Down
3 changes: 2 additions & 1 deletion chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Where,
QueryResult,
GetResult,
AddResult,
WhereDocument,
)
from chromadb.config import Component, Settings
Expand Down Expand Up @@ -112,7 +113,7 @@ async def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.
Expand Down
3 changes: 2 additions & 1 deletion chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EmbeddingFunction,
Embeddings,
GetResult,
AddResult,
IDs,
Include,
Loadable,
Expand Down Expand Up @@ -266,7 +267,7 @@ async def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
return await self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
21 changes: 18 additions & 3 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Where,
WhereDocument,
GetResult,
AddResult,
QueryResult,
CollectionMetadata,
validate_batch,
Expand Down Expand Up @@ -448,11 +449,25 @@ async def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
return True

resp_json = await self._make_request(
"post",
"/collections/" + str(collection_id) + "/add",
json={
"ids": batch[0],
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
"uris": batch[4],
},
)

return AddResult(
ids=resp_json["ids"],
)

@trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
@override
Expand Down
3 changes: 2 additions & 1 deletion chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Loadable,
Metadatas,
QueryResult,
AddResult,
URIs,
)
from chromadb.config import Settings, System
Expand Down Expand Up @@ -214,7 +215,7 @@ def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
return self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
21 changes: 18 additions & 3 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Where,
WhereDocument,
GetResult,
AddResult,
QueryResult,
CollectionMetadata,
validate_batch,
Expand Down Expand Up @@ -413,15 +414,29 @@ def _add(
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
return True

resp_json = self._make_request(
"post",
"/collections/" + str(collection_id) + "/add",
json={
"ids": batch[0],
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
"uris": batch[4],
},
)

return AddResult(
ids=resp_json["ids"],
)

@trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
@override
Expand Down
9 changes: 5 additions & 4 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Where,
IDs,
GetResult,
AddResult,
QueryResult,
ID,
OneOrMany,
Expand All @@ -44,7 +45,7 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -74,7 +75,7 @@ async def add(
uris,
)

await self._client._add(
result = await self._client._add(
record_set["ids"],
self.id,
cast(Embeddings, record_set["embeddings"]),
Expand All @@ -83,6 +84,8 @@ async def add(
record_set["uris"],
)

return result

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand Down Expand Up @@ -266,7 +269,6 @@ async def update(
documents,
images,
uris,
require_embeddings_or_data=False,
)

await self._client._update(
Expand Down Expand Up @@ -310,7 +312,6 @@ async def upsert(
documents,
images,
uris,
require_embeddings_or_data=True,
)

await self._client._upsert(
Expand Down
11 changes: 6 additions & 5 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Where,
IDs,
GetResult,
AddResult,
QueryResult,
ID,
OneOrMany,
Expand All @@ -40,7 +41,7 @@ def count(self) -> int:

def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -51,7 +52,7 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -81,7 +82,7 @@ def add(
uris,
)

self._client._add(
result = self._client._add(
record_set["ids"],
self.id,
cast(Embeddings, record_set["embeddings"]),
Expand All @@ -90,6 +91,8 @@ def add(
record_set["uris"],
)

return result

def get(
self,
ids: Optional[OneOrMany[ID]] = None,
Expand Down Expand Up @@ -264,7 +267,6 @@ def update(
documents,
images,
uris,
require_embeddings_or_data=False,
)

self._client._update(
Expand Down Expand Up @@ -308,7 +310,6 @@ def upsert(
documents,
images,
uris,
require_embeddings_or_data=True,
)

self._client._upsert(
Expand Down
5 changes: 2 additions & 3 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _update_model_after_modify_success(

def _process_add_request(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -384,7 +384,7 @@ def _process_add_request(
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
unpacked_embedding_set = self._unpack_embedding_set(
ids,
ids if ids is not None else [],
embeddings,
metadatas,
documents,
Expand Down Expand Up @@ -446,7 +446,6 @@ def _process_upsert_or_update_request(
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
require_embeddings_or_data: bool = True,
) -> RecordSet:
unpacked_embedding_set = self._unpack_embedding_set(
ids, embeddings, metadatas, documents, images, uris
Expand Down
Loading

0 comments on commit 28ef4e0

Please sign in to comment.