|  | 
|  | 1 | +from typing import TYPE_CHECKING | 
|  | 2 | + | 
|  | 3 | +from requests import ConnectionError, get | 
|  | 4 | + | 
|  | 5 | +from testcontainers.core.container import DockerContainer | 
|  | 6 | +from testcontainers.core.utils import raise_for_deprecated_parameter | 
|  | 7 | +from testcontainers.core.waiting_utils import wait_container_is_ready | 
|  | 8 | + | 
|  | 9 | +if TYPE_CHECKING: | 
|  | 10 | +    from requests import Response | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +class ChromaContainer(DockerContainer): | 
|  | 14 | +    """ | 
|  | 15 | +    The example below spins up a ChromaDB container, performs a healthcheck and creates a collection. | 
|  | 16 | +    The method :code:`get_client` can be used to create a client for the Chroma Python Client. | 
|  | 17 | +
 | 
|  | 18 | +    Example: | 
|  | 19 | +
 | 
|  | 20 | +        .. doctest:: | 
|  | 21 | +
 | 
|  | 22 | +            >>> import chromadb | 
|  | 23 | +            >>> from testcontainers.chroma import ChromaContainer | 
|  | 24 | +
 | 
|  | 25 | +            >>> with ChromaContainer() as chroma: | 
|  | 26 | +            ...   config = chroma.get_config() | 
|  | 27 | +            ...   client = chromadb.HttpClient(host=config["host"], port=config["port"]) | 
|  | 28 | +            ...   col = client.get_or_create_collection("test") | 
|  | 29 | +            ...   col.name | 
|  | 30 | +            'test' | 
|  | 31 | +    """ | 
|  | 32 | + | 
|  | 33 | +    def __init__( | 
|  | 34 | +        self, | 
|  | 35 | +        image: str = "chromadb/chroma:latest", | 
|  | 36 | +        port: int = 8000, | 
|  | 37 | +        **kwargs, | 
|  | 38 | +    ) -> None: | 
|  | 39 | +        """ | 
|  | 40 | +        Args: | 
|  | 41 | +            image: Docker image to use for the MinIO container. | 
|  | 42 | +            port: Port to expose on the container. | 
|  | 43 | +            access_key: Access key for client connections. | 
|  | 44 | +            secret_key: Secret key for client connections. | 
|  | 45 | +        """ | 
|  | 46 | +        raise_for_deprecated_parameter(kwargs, "port_to_expose", "port") | 
|  | 47 | +        super().__init__(image, **kwargs) | 
|  | 48 | +        self.port = port | 
|  | 49 | + | 
|  | 50 | +        self.with_exposed_ports(self.port) | 
|  | 51 | +        # self.with_command(f"server /data --address :{self.port}") | 
|  | 52 | + | 
|  | 53 | +    def get_config(self) -> dict: | 
|  | 54 | +        """This method returns the configuration of the Chroma container, | 
|  | 55 | +        including the endpoint. | 
|  | 56 | +
 | 
|  | 57 | +        Returns: | 
|  | 58 | +            dict: {`endpoint`: str} | 
|  | 59 | +        """ | 
|  | 60 | +        host_ip = self.get_container_host_ip() | 
|  | 61 | +        exposed_port = self.get_exposed_port(self.port) | 
|  | 62 | +        return { | 
|  | 63 | +            "endpoint": f"{host_ip}:{exposed_port}", | 
|  | 64 | +            "host": host_ip, | 
|  | 65 | +            "port": exposed_port, | 
|  | 66 | +        } | 
|  | 67 | + | 
|  | 68 | +    @wait_container_is_ready(ConnectionError) | 
|  | 69 | +    def _healthcheck(self) -> None: | 
|  | 70 | +        """This is an internal method used to check if the Chroma container | 
|  | 71 | +        is healthy and ready to receive requests.""" | 
|  | 72 | +        url = f"http://{self.get_config()['endpoint']}/api/v1/heartbeat" | 
|  | 73 | +        response: Response = get(url) | 
|  | 74 | +        response.raise_for_status() | 
|  | 75 | + | 
|  | 76 | +    def start(self) -> "ChromaContainer": | 
|  | 77 | +        """This method starts the Chroma container and runs the healthcheck | 
|  | 78 | +        to verify that the container is ready to use.""" | 
|  | 79 | +        super().start() | 
|  | 80 | +        self._healthcheck() | 
|  | 81 | +        return self | 
0 commit comments