@@ -34,27 +34,67 @@ class OllamaContainer(DockerContainer):
3434    """ 
3535    Ollama Container 
3636
37-     Example: 
37+     :param: image - the ollama image to use (default: :code:`ollama/ollama:0.1.44`) 
38+     :param: ollama_home - the directory to mount for model data (default: None) 
39+ 
40+     you may pass :code:`pathlib.Path.home() / ".ollama"` to re-use models 
41+     that have already been pulled with ollama running on this host outside the container. 
42+ 
43+     Examples: 
3844
3945        .. doctest:: 
4046
4147            >>> from testcontainers.ollama import OllamaContainer 
4248            >>> with OllamaContainer() as ollama: 
4349            ...     ollama.list_models() 
4450            [] 
51+ 
52+         .. code-block:: python 
53+ 
54+             >>> from json import loads 
55+             >>> from pathlib import Path 
56+             >>> from requests import post 
57+             >>> from testcontainers.ollama import OllamaContainer 
58+             >>> def split_by_line(generator): 
59+             ...     data = b'' 
60+             ...     for each_item in generator: 
61+             ...         for line in each_item.splitlines(True): 
62+             ...             data += line 
63+             ...             if data.endswith((b'\\ r\\ r', b'\\ n\\ n', b'\\ r\\ n\\ r\\ n', b'\\ n')): 
64+             ...                 yield from data.splitlines() 
65+             ...                 data = b'' 
66+             ...     if data: 
67+             ...         yield from data.splitlines() 
68+ 
69+             >>> with OllamaContainer(ollama_home=Path.home() / ".ollama") as ollama: 
70+             ...     if "llama3:latest" not in [e["name"] for e in ollama.list_models()]: 
71+             ...         print("did not find 'llama3:latest', pulling") 
72+             ...         ollama.pull_model("llama3:latest") 
73+             ...     endpoint = ollama.get_endpoint() 
74+             ...     for chunk in split_by_line( 
75+             ...         post(url=f"{endpoint}/api/chat", stream=True, json={ 
76+             ...             "model": "llama3:latest", 
77+             ...             "messages": [{ 
78+             ...                 "role": "user", 
79+             ...                 "content": "what color is the sky? MAX ONE WORD" 
80+             ...             }] 
81+             ...         }) 
82+             ...      ): 
83+             ...          print(loads(chunk)["message"]["content"], end="") 
84+             Blue. 
4585    """ 
4686
4787    OLLAMA_PORT  =  11434 
4888
4989    def  __init__ (
5090        self ,
5191        image : str  =  "ollama/ollama:0.1.44" ,
52-         ollama_dir : Optional [Union [str , PathLike ]] =  None ,
92+         ollama_home : Optional [Union [str , PathLike ]] =  None ,
5393        ** kwargs ,
5494        # 
5595    ):
5696        super ().__init__ (image = image , ** kwargs )
57-         self .ollama_dir  =  ollama_dir 
97+         self .ollama_home  =  ollama_home 
5898        self .with_exposed_ports (OllamaContainer .OLLAMA_PORT )
5999        self ._check_and_add_gpu_capabilities ()
60100
@@ -67,8 +107,8 @@ def start(self) -> "OllamaContainer":
67107        """ 
68108        Start the Ollama server 
69109        """ 
70-         if  self .ollama_dir :
71-             self .with_volume_mapping (self .ollama_dir , "/root/.ollama" , "rw" )
110+         if  self .ollama_home :
111+             self .with_volume_mapping (self .ollama_home , "/root/.ollama" , "rw" )
72112        super ().start ()
73113        wait_for_logs (self , "Listening on " , timeout = 30 )
74114
0 commit comments