Skip to content

Commit

Permalink
Improved component model (#629)
Browse files Browse the repository at this point in the history
## Description of changes

All component instances now inherit from a base `Component` class
defining `start` and `stop` methods, and to track their dependencies
throughout the system.

The `System` class has been updated to start and stop all components in
dependency order.

All these material changes are in `chromadb/config.py`

This was necessary to properly implement testing in the new
architecture. The new version of the system has more stateful components
with (e.g.) active subscriptions. It is necessary to provide an explicit
mechanism to shut down the whole stack or else consumers could continue
operating unexpectedly in the background and not be garbage collected.

Most of the changes in this PR are to fix existing type signature
errors. All components of the system are now subclasses of
`EnforceOverrides` which ensures that all signatures match at runtime,
and which operates independently from the type checker (so the `#type:
ignore` is not sufficient to fix it.)

This PR also disables the `type-abstract` MyPy error code. In my opinion
this is an incorrect type rule although there is
[https://github.com/python/mypy/issues/4717](active discussion) on the
topic.

## Test plan

Unit + integration tests updated and passing.

## Documentation Changes

No changes to user-facing APIs.

---------

Co-authored-by: hammadb <hammad@trychroma.com>
  • Loading branch information
levand and HammadB authored Jun 2, 2023
1 parent 1c7dba3 commit 75f5a81
Show file tree
Hide file tree
Showing 22 changed files with 572 additions and 154 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ repos:
rev: 'v1.2.0'
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --follow-imports=silent]
args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract]
additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest"]
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"--ignore-missing-imports",
"--show-column-numbers",
"--no-pretty",
"--strict"
"--strict",
"--disable-error-code=type-abstract"
]
}
11 changes: 7 additions & 4 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import chromadb.config
import logging
from chromadb.telemetry.events import ClientStartEvent
from chromadb.telemetry import Telemetry
from chromadb.config import Settings, System
from chromadb.api import API

Expand All @@ -22,14 +23,16 @@ def get_settings() -> Settings:


def Client(settings: Settings = __settings) -> API:
"""Return a chroma.API instance based on the provided or environmental
settings, optionally overriding the DB instance."""
"""Return a running chroma.API instance"""

system = System(settings)

telemetry_client = system.get_telemetry()
telemetry_client = system.instance(Telemetry)
api = system.instance(API)

system.start()

# Submit event for client start
telemetry_client.capture(ClientStartEvent())

return system.get_api()
return api
8 changes: 2 additions & 6 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@
GetResult,
WhereDocument,
)
from chromadb.config import Component
import chromadb.utils.embedding_functions as ef
from chromadb.telemetry import Telemetry


class API(ABC):
@abstractmethod
def __init__(self, telemetry_client: Telemetry):
pass

class API(Component, ABC):
@abstractmethod
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive
Expand Down
29 changes: 26 additions & 3 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.telemetry import Telemetry
from overrides import override


class FastAPI(API):
def __init__(self, system: System):
super().__init__(system)
url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http"
system.settings.require("chroma_server_host")
system.settings.require("chroma_server_http_port")
self._api_url = f"{url_prefix}://{system.settings.chroma_server_host}:{system.settings.chroma_server_http_port}/api/v1"
self._telemetry_client = system.get_telemetry()
self._telemetry_client = self.require(Telemetry)

@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = requests.get(self._api_url)
raise_chroma_error(resp)
return int(resp.json()["nanosecond heartbeat"])

@override
def list_collections(self) -> Sequence[Collection]:
"""Returns a list of all collections"""
resp = requests.get(self._api_url + "/collections")
Expand All @@ -49,6 +54,7 @@ def list_collections(self) -> Sequence[Collection]:

return collections

@override
def create_collection(
self,
name: str,
Expand All @@ -73,6 +79,7 @@ def create_collection(
metadata=resp_json["metadata"],
)

@override
def get_collection(
self,
name: str,
Expand All @@ -90,6 +97,7 @@ def get_collection(
metadata=resp_json["metadata"],
)

@override
def get_or_create_collection(
self,
name: str,
Expand All @@ -102,6 +110,7 @@ def get_or_create_collection(
name, metadata, embedding_function, get_or_create=True
)

@override
def _modify(
self,
id: UUID,
Expand All @@ -115,11 +124,13 @@ def _modify(
)
raise_chroma_error(resp)

@override
def delete_collection(self, name: str) -> None:
"""Deletes a collection"""
resp = requests.delete(self._api_url + "/collections/" + name)
raise_chroma_error(resp)

@override
def _count(self, collection_id: UUID) -> int:
"""Returns the number of embeddings in the database"""
resp = requests.get(
Expand All @@ -128,13 +139,15 @@ def _count(self, collection_id: UUID) -> int:
raise_chroma_error(resp)
return cast(int, resp.json())

def _peek(self, collection_id: UUID, limit: int = 10) -> GetResult:
@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return self._get(
collection_id,
limit=limit,
limit=n,
include=["embeddings", "documents", "metadatas"],
)

@override
def _get(
self,
collection_id: UUID,
Expand Down Expand Up @@ -177,6 +190,7 @@ def _get(
documents=body.get("documents", None),
)

@override
def _delete(
self,
collection_id: UUID,
Expand All @@ -196,6 +210,7 @@ def _delete(
raise_chroma_error(resp)
return cast(IDs, resp.json())

@override
def _add(
self,
ids: IDs,
Expand Down Expand Up @@ -227,6 +242,7 @@ def _add(
raise_chroma_error(resp)
return True

@override
def _update(
self,
collection_id: UUID,
Expand Down Expand Up @@ -255,6 +271,7 @@ def _update(
resp.raise_for_status()
return True

@override
def _upsert(
self,
collection_id: UUID,
Expand Down Expand Up @@ -285,6 +302,7 @@ def _upsert(
resp.raise_for_status()
return True

@override
def _query(
self,
collection_id: UUID,
Expand Down Expand Up @@ -320,18 +338,21 @@ def _query(
documents=body.get("documents", None),
)

@override
def reset(self) -> bool:
"""Resets the database"""
resp = requests.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, resp.json())

@override
def persist(self) -> bool:
"""Persists the database"""
resp = requests.post(self._api_url + "/persist")
raise_chroma_error(resp)
return cast(bool, resp.json())

@override
def raw_sql(self, sql: str) -> pd.DataFrame:
"""Runs a raw SQL query against the database"""
resp = requests.post(
Expand All @@ -340,6 +361,7 @@ def raw_sql(self, sql: str) -> pd.DataFrame:
raise_chroma_error(resp)
return pd.DataFrame.from_dict(resp.json())

@override
def create_index(self, collection_name: str) -> bool:
"""Creates an index for the given space key"""
resp = requests.post(
Expand All @@ -348,6 +370,7 @@ def create_index(self, collection_name: str) -> bool:
raise_chroma_error(resp)
return cast(bool, resp.json())

@override
def get_version(self) -> str:
"""Returns the version of the server"""
resp = requests.get(self._api_url + "/version")
Expand Down
Loading

0 comments on commit 75f5a81

Please sign in to comment.