@@ -34,27 +34,67 @@ class OllamaContainer(DockerContainer):
     """
     Ollama Container

-    Example:
+    :param image: the Ollama image to use (default: :code:`ollama/ollama:0.1.44`)
+    :param ollama_home: the directory to mount for model data (default: None)
+
+        You may pass :code:`pathlib.Path.home() / ".ollama"` to reuse models
+        that have already been pulled with Ollama running on this host outside the container.
+
+    Examples:

     .. doctest::

         >>> from testcontainers.ollama import OllamaContainer
         >>> with OllamaContainer() as ollama:
         ...     ollama.list_models()
         []
+
+    .. code-block:: python
+
+        >>> from json import loads
+        >>> from pathlib import Path
+        >>> from requests import post
+        >>> from testcontainers.ollama import OllamaContainer
+        >>> def split_by_line(generator):
+        ...     # Buffer streamed bytes and yield only complete lines;
+        ...     # Ollama streams one JSON object per line.
+        ...     data = b''
+        ...     for each_item in generator:
+        ...         for line in each_item.splitlines(True):
+        ...             data += line
+        ...             if data.endswith((b'\\r\\r', b'\\n\\n', b'\\r\\n\\r\\n', b'\\n')):
+        ...                 yield from data.splitlines()
+        ...                 data = b''
+        ...     if data:
+        ...         yield from data.splitlines()
+
+        >>> with OllamaContainer(ollama_home=Path.home() / ".ollama") as ollama:
+        ...     if "llama3:latest" not in [e["name"] for e in ollama.list_models()]:
+        ...         print("did not find 'llama3:latest', pulling")
+        ...         ollama.pull_model("llama3:latest")
+        ...     endpoint = ollama.get_endpoint()
+        ...     # Stream the chat response and print each token as it arrives.
+        ...     for chunk in split_by_line(
+        ...         post(url=f"{endpoint}/api/chat", stream=True, json={
+        ...             "model": "llama3:latest",
+        ...             "messages": [{
+        ...                 "role": "user",
+        ...                 "content": "what color is the sky? MAX ONE WORD"
+        ...             }]
+        ...         })
+        ...     ):
+        ...         print(loads(chunk)["message"]["content"], end="")
+        Blue.
     """

     OLLAMA_PORT = 11434

     def __init__(
         self,
         image: str = "ollama/ollama:0.1.44",
-        ollama_dir: Optional[Union[str, PathLike]] = None,
+        ollama_home: Optional[Union[str, PathLike]] = None,
         **kwargs,
         #
     ):
         super().__init__(image=image, **kwargs)
-        self.ollama_dir = ollama_dir
+        self.ollama_home = ollama_home
         self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
         self._check_and_add_gpu_capabilities()

@@ -67,8 +107,8 @@ def start(self) -> "OllamaContainer":
         """
         Start the Ollama server
         """
-        if self.ollama_dir:
-            self.with_volume_mapping(self.ollama_dir, "/root/.ollama", "rw")
+        if self.ollama_home:
+            self.with_volume_mapping(self.ollama_home, "/root/.ollama", "rw")
         super().start()
         wait_for_logs(self, "Listening on ", timeout=30)

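Taken together, the rename means the host's model cache is mounted via ollama_home, so models pulled once survive container restarts. Below is a minimal usage sketch under the same assumptions as the docstring example (model llama3:latest available from the registry, requests installed); the test name and the non-streaming "stream": False request shape are illustrative, not part of this change:

    from pathlib import Path

    from requests import post

    from testcontainers.ollama import OllamaContainer


    def test_chat_with_reused_models():  # hypothetical test name
        # Mount the host's ~/.ollama so models pulled earlier are reused and
        # new pulls persist after the container is removed.
        with OllamaContainer(ollama_home=Path.home() / ".ollama") as ollama:
            if "llama3:latest" not in [m["name"] for m in ollama.list_models()]:
                ollama.pull_model("llama3:latest")  # cached under ollama_home for later runs

            # Single (non-streaming) chat completion against the mapped endpoint.
            response = post(
                url=f"{ollama.get_endpoint()}/api/chat",
                json={
                    "model": "llama3:latest",
                    "stream": False,
                    "messages": [{"role": "user", "content": "what color is the sky? MAX ONE WORD"}],
                },
            )
            assert response.json()["message"]["content"]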