|  | 
|  | 1 | +# | 
|  | 2 | +#    Licensed under the Apache License, Version 2.0 (the "License"); you may | 
|  | 3 | +#    not use this file except in compliance with the License. You may obtain | 
|  | 4 | +#    a copy of the License at | 
|  | 5 | +# | 
|  | 6 | +#         http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 7 | +# | 
|  | 8 | +#    Unless required by applicable law or agreed to in writing, software | 
|  | 9 | +#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
|  | 10 | +#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
|  | 11 | +#    License for the specific language governing permissions and limitations | 
|  | 12 | +#    under the License. | 
|  | 13 | + | 
|  | 14 | +from os import PathLike | 
|  | 15 | +from typing import Any, Optional, TypedDict, Union | 
|  | 16 | + | 
|  | 17 | +from docker.types.containers import DeviceRequest | 
|  | 18 | +from requests import get | 
|  | 19 | + | 
|  | 20 | +from testcontainers.core.container import DockerContainer | 
|  | 21 | +from testcontainers.core.waiting_utils import wait_for_logs | 
|  | 22 | + | 
|  | 23 | + | 
# One entry of the "models" array returned by the Ollama ``/api/tags``
# endpoint: model name/tag, last-modified timestamp, size in bytes,
# content digest, and a free-form details mapping.
OllamaModel = TypedDict(
    "OllamaModel",
    {
        "name": str,
        "model": str,
        "modified_at": str,
        "size": int,
        "digest": str,
        "details": dict[str, Any],
    },
)
|  | 31 | + | 
|  | 32 | + | 
class OllamaContainer(DockerContainer):
    """
    Ollama Container

    Example:

        .. doctest::

            >>> from testcontainers.ollama import OllamaContainer
            >>> with OllamaContainer() as ollama:
            ...     ollama.list_models()
            []
    """

    OLLAMA_PORT = 11434

    def __init__(
        self,
        image: str = "ollama/ollama:0.1.44",
        ollama_dir: Optional[Union[str, PathLike]] = None,
        **kwargs,
        #
    ):
        """
        Args:
            image: Ollama Docker image to run.
            ollama_dir: Optional host directory mounted at ``/root/.ollama``
                so pulled models persist across container runs.
            **kwargs: Forwarded to :class:`DockerContainer`.
        """
        super().__init__(image=image, **kwargs)
        self.ollama_dir = ollama_dir
        self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
        self._check_and_add_gpu_capabilities()

    def _check_and_add_gpu_capabilities(self) -> None:
        """Request all NVIDIA GPUs when the daemon has the nvidia runtime installed."""
        info = self.get_docker_client().client.info()
        # .get(): not every daemon reports a "Runtimes" key (e.g. non-Linux hosts).
        if "nvidia" in info.get("Runtimes", {}):
            # docker-py requires `device_requests` to be a *list* of DeviceRequest
            # objects; passing a bare DeviceRequest raises a host-config type error.
            self._kwargs = {**self._kwargs, "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])]}

    def start(self) -> "OllamaContainer":
        """
        Start the Ollama server and block until it reports it is listening.

        Returns:
            self, so the call can be chained.
        """
        if self.ollama_dir:
            # str() accepts PathLike values (the declared type of ollama_dir).
            self.with_volume_mapping(str(self.ollama_dir), "/root/.ollama", "rw")
        super().start()
        wait_for_logs(self, "Listening on ", timeout=30)

        return self

    def get_endpoint(self) -> str:
        """
        Return the base URL of the Ollama server, e.g. ``http://localhost:32768``.
        """
        host = self.get_container_host_ip()
        exposed_port = self.get_exposed_port(OllamaContainer.OLLAMA_PORT)
        url = f"http://{host}:{exposed_port}"
        return url

    @property
    def id(self) -> str:
        """
        Return the id of the underlying Docker container.
        """
        return self._container.id

    def pull_model(self, model_name: str) -> None:
        """
        Pull a model onto the running Ollama server.

        Args:
            model_name (str): Name of the model (e.g. ``"llama3:latest"``)
        """
        self.exec(f"ollama pull {model_name}")

    def list_models(self) -> list[OllamaModel]:
        """
        Return the models currently available on the server (``/api/tags``).

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        endpoint = self.get_endpoint()
        # Explicit timeout: requests.get without one can block forever.
        response = get(url=f"{endpoint}/api/tags", timeout=30)
        response.raise_for_status()
        return response.json().get("models", [])

    def commit_to_image(self, image_name: str) -> None:
        """
        Commit the current container to a new image (no-op if it already exists).

        Useful to snapshot a container after pulling models, so later runs
        can start from the baked image instead of re-downloading.

        Args:
            image_name (str): Name of the new image
        """
        docker_client = self.get_docker_client()
        existing_images = docker_client.client.images.list(name=image_name)
        if not existing_images and self.id:
            # Clear the session label so Ryuk does not reap the committed image.
            docker_client.client.containers.get(self.id).commit(
                repository=image_name, conf={"Labels": {"org.testcontainers.session-id": ""}}
            )
0 commit comments