Commit
Merge pull request #19 from cyrannano/ollama-llm-integration
Ollama llm integration
Showing 3 changed files with 190 additions and 1 deletion.
@@ -0,0 +1,105 @@
from ollama import Client, Options

from .illm import ILlm
from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
from .._utils import logger

log = logger.init_loggers("Ollama Client")

class Ollama(ILlm):
    def __init__(self, model_config: dict, client_config=None, client: Client = None):
        """
        Initialize the class with an optional config parameter.

        Parameters:
            model_config (dict): The model configuration parameter.
            client_config (dict): The client configuration parameter.
            client (Client): An optional pre-built client.

        Returns:
            None
        """
        self.client = client
        self.client_config = client_config
        self.model_config = model_config

        # A ready-made client takes precedence over client_config.
        if self.client is not None:
            if self.client_config is not None:
                log.warning("Client object provided. Ignoring client_config parameter.")
            return

        if client_config is None:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client"))

        if model_config is None:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model"))

        if 'model' not in model_config:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name"))

        self.client = Client(**client_config)

    def system_message(self, message: str) -> any:
        """
        Create a system message.

        Parameters:
            message (str): The message parameter.

        Returns:
            any
        """
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """
        Create a user message.

        Parameters:
            message (str): The message parameter.

        Returns:
            any
        """
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """
        Create an assistant message.

        Parameters:
            message (str): The message parameter.

        Returns:
            any
        """
        return {"role": "assistant", "content": message}

    def invoke(self, prompt, **kwargs) -> str:
        """
        Submit a prompt to the model for generating a response.

        Parameters:
            prompt (str): The prompt parameter.
            **kwargs: Additional keyword arguments (optional).
                - temperature (float): The temperature parameter for controlling randomness in generation.

        Returns:
            str
        """
        if not prompt:
            raise ValueError(PROMPT_EMPTY_EXCEPTION)

        model = self.model_config.get('model')
        temperature = kwargs.get('temperature', 0.1)

        response = self.client.chat(
            model=model,
            messages=[self.user_message(prompt)],
            options=Options(
                temperature=temperature
            )
        )

        return response['message']['content']
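
For orientation only (not part of the diff): a minimal usage sketch of the new wrapper, assuming a local Ollama server at the default host and the same model name used in the test fixtures below. The import path mirrors the test file's import of Ollama from mindsql.llms.

from mindsql.llms import Ollama

# Hypothetical configuration values, taken from the test fixtures below.
model_config = {'model': 'sqlcoder'}
client_config = {'host': 'http://localhost:11434/'}

# With no client passed, the wrapper builds its own ollama.Client from client_config.
llm = Ollama(model_config=model_config, client_config=client_config)

# invoke() wraps the prompt in a single user message and returns the response text.
answer = llm.invoke("List the tables referenced in this schema.", temperature=0.1)
print(answer)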
@@ -0,0 +1,83 @@
import unittest
from unittest.mock import MagicMock, patch
from ollama import Client, Options

from mindsql.llms import ILlm
from mindsql.llms import Ollama
from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED


class TestOllama(unittest.TestCase):

    def setUp(self):
        # Common setup for each test case
        self.model_config = {'model': 'sqlcoder'}
        self.client_config = {'host': 'http://localhost:11434/'}
        self.client_mock = MagicMock(spec=Client)

    def test_initialization_with_client(self):
        ollama = Ollama(model_config=self.model_config, client=self.client_mock)
        self.assertEqual(ollama.client, self.client_mock)
        self.assertIsNone(ollama.client_config)
        self.assertEqual(ollama.model_config, self.model_config)

    def test_initialization_with_client_config(self):
        ollama = Ollama(model_config=self.model_config, client_config=self.client_config)
        self.assertIsNotNone(ollama.client)
        self.assertEqual(ollama.client_config, self.client_config)
        self.assertEqual(ollama.model_config, self.model_config)

    def test_initialization_missing_client_and_client_config(self):
        with self.assertRaises(ValueError) as context:
            Ollama(model_config=self.model_config)
        self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client"))

    def test_initialization_missing_model_config(self):
        with self.assertRaises(ValueError) as context:
            Ollama(model_config=None, client_config=self.client_config)
        self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model"))

    def test_initialization_missing_model_name(self):
        with self.assertRaises(ValueError) as context:
            Ollama(model_config={}, client_config=self.client_config)
        self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name"))

    def test_system_message(self):
        ollama = Ollama(model_config=self.model_config, client=self.client_mock)
        message = ollama.system_message("Test system message")
        self.assertEqual(message, {"role": "system", "content": "Test system message"})

    def test_user_message(self):
        ollama = Ollama(model_config=self.model_config, client=self.client_mock)
        message = ollama.user_message("Test user message")
        self.assertEqual(message, {"role": "user", "content": "Test user message"})

    def test_assistant_message(self):
        ollama = Ollama(model_config=self.model_config, client=self.client_mock)
        message = ollama.assistant_message("Test assistant message")
        self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"})

    @patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}})
    def test_invoke_success(self, mock_chat):
        ollama = Ollama(model_config=self.model_config, client=Client())
        response = ollama.invoke("Test prompt")

        # Check if the response is as expected
        self.assertEqual(response, 'Test response')

        # Verify that the chat method was called with the correct arguments
        mock_chat.assert_called_once_with(
            model=self.model_config['model'],
            messages=[{"role": "user", "content": "Test prompt"}],
            options=Options(temperature=0.1)
        )

    def test_invoke_empty_prompt(self):
        ollama = Ollama(model_config=self.model_config, client=self.client_mock)
        with self.assertRaises(ValueError) as context:
            ollama.invoke("")
        self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION)


if __name__ == '__main__':
    unittest.main()
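
To run the new suite, the __main__ guard above lets the file be executed directly; alternatively, unittest discovery picks it up with the default test_*.py pattern (the file's location in the repository is not stated in the diff):

python -m unittest discover -v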