Skip to content

Commit c511cf0

Browse files
committed
feat: Added ChromaDB container
1 parent 5758310 commit c511cf0

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

modules/chroma/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.. autoclass:: testcontainers.chroma.ChromaContainer
2+
.. title:: testcontainers.minio.ChromaContainer
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import TYPE_CHECKING
2+
3+
import chromadb
4+
from requests import ConnectionError, get
5+
6+
from testcontainers.core.container import DockerContainer
7+
from testcontainers.core.utils import raise_for_deprecated_parameter
8+
from testcontainers.core.waiting_utils import wait_container_is_ready
9+
10+
if TYPE_CHECKING:
11+
from requests import Response
12+
13+
14+
class ChromaContainer(DockerContainer):
15+
"""
16+
The example below spins up a ChromaDB container, performs a healthcheck and creates a collection.
17+
The method :code:`get_client` can be used to create a client for the Chroma Python Client.
18+
19+
Example:
20+
21+
.. doctest::
22+
23+
>>> import io
24+
>>> from testcontainers.chroma import ChromaContainer
25+
26+
>>> with ChromaContainer() as chroma:
27+
... client = chroma.get_client()
28+
... client.heartbeat()
29+
... client.get_or_create_collection("test")
30+
"""
31+
32+
def __init__(
33+
self,
34+
image: str = "chromadb/chroma:latest",
35+
port: int = 8000,
36+
**kwargs,
37+
) -> None:
38+
"""
39+
Args:
40+
image: Docker image to use for the MinIO container.
41+
port: Port to expose on the container.
42+
access_key: Access key for client connections.
43+
secret_key: Secret key for client connections.
44+
"""
45+
raise_for_deprecated_parameter(kwargs, "port_to_expose", "port")
46+
super().__init__(image, **kwargs)
47+
self.port = port
48+
49+
self.with_exposed_ports(self.port)
50+
# self.with_command(f"server /data --address :{self.port}")
51+
52+
def get_client(self, **kwargs) -> chromadb.ClientAPI:
53+
"""Returns a Chroma HttpClient to connect to the container.
54+
55+
Returns:
56+
HttpClient: Chroma HttpClient according to
57+
https://docs.trychroma.com/reference/Client#httpclient
58+
"""
59+
host_ip = self.get_container_host_ip()
60+
exposed_port = self.get_exposed_port(self.port)
61+
return chromadb.HttpClient(host=host_ip, port=int(exposed_port))
62+
63+
def get_config(self) -> dict:
64+
"""This method returns the configuration of the Chroma container,
65+
including the endpoint.
66+
67+
Returns:
68+
dict: {`endpoint`: str}
69+
"""
70+
host_ip = self.get_container_host_ip()
71+
exposed_port = self.get_exposed_port(self.port)
72+
return {
73+
"endpoint": f"{host_ip}:{exposed_port}",
74+
}
75+
76+
@wait_container_is_ready(ConnectionError)
77+
def _healthcheck(self) -> None:
78+
"""This is an internal method used to check if the Chroma container
79+
is healthy and ready to receive requests."""
80+
url = f"http://{self.get_config()['endpoint']}/api/v1/heartbeat"
81+
response: Response = get(url)
82+
response.raise_for_status()
83+
84+
def start(self) -> "ChromaContainer":
85+
"""This method starts the Chroma container and runs the healthcheck
86+
to verify that the container is ready to use."""
87+
super().start()
88+
self._healthcheck()
89+
return self
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from testcontainers.chroma import ChromaContainer
2+
3+
4+
def test_docker_run_chroma():
5+
with ChromaContainer(image="chromadb/chroma:0.4.24") as chroma:
6+
client = chroma.get_client()
7+
col = client.get_or_create_collection("test")
8+
assert col.name == "test"

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ packages = [
4848
{ include = "testcontainers", from = "modules/postgres" },
4949
{ include = "testcontainers", from = "modules/rabbitmq" },
5050
{ include = "testcontainers", from = "modules/redis" },
51-
{ include = "testcontainers", from = "modules/selenium" }
51+
{ include = "testcontainers", from = "modules/selenium" },
52+
{ include = "testcontainers", from = "modules/chroma" },
5253
]
5354

5455
[tool.poetry.urls]
@@ -84,6 +85,7 @@ cx_Oracle = { version = "*", optional = true }
8485
pika = { version = "*", optional = true }
8586
redis = { version = "*", optional = true }
8687
selenium = { version = "*", optional = true }
88+
chroma = { version = "*", optional = true }
8789

8890
[tool.poetry.extras]
8991
arangodb = ["python-arango"]
@@ -108,6 +110,7 @@ postgres = []
108110
rabbitmq = ["pika"]
109111
redis = ["redis"]
110112
selenium = ["selenium"]
113+
chroma = ["chromadb-client"]
111114

112115
[tool.poetry.group.dev.dependencies]
113116
mypy = "1.7.1"

0 commit comments

Comments
 (0)