Merge pull request #19 from cyrannano/ollama-llm-integration
Ollama llm integration
Sammindinventory authored Aug 23, 2024
2 parents 8c6a90a + 1d9197e commit b95ca73
Showing 3 changed files with 190 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mindsql/_utils/constants.py
@@ -32,4 +32,5 @@
OPENAI_VALUE_ERROR = "OpenAI API key is required"
PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
OLLAMA_CONFIG_REQUIRED = "{type} configuration is required."
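
For reference, the new OLLAMA_CONFIG_REQUIRED constant is a plain format template; a minimal sketch of how ollama.py (below) fills it in:

from mindsql._utils.constants import OLLAMA_CONFIG_REQUIRED

# Each missing piece of configuration reuses the same template:
print(OLLAMA_CONFIG_REQUIRED.format(type="Client"))  # -> "Client configuration is required."
print(OLLAMA_CONFIG_REQUIRED.format(type="Model"))   # -> "Model configuration is required."
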
105 changes: 105 additions & 0 deletions mindsql/llms/ollama.py
@@ -0,0 +1,105 @@
from ollama import Client, Options

from .illm import ILlm
from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
from .._utils import logger

log = logger.init_loggers("Ollama Client")


class Ollama(ILlm):
def __init__(self, model_config: dict, client_config=None, client: Client = None):
"""
Initialize the Ollama LLM wrapper with a model configuration and either an existing Client or a client configuration.
Parameters:
model_config (dict): The model configuration; must contain the 'model' key.
client_config (dict): The Ollama client configuration (for example, the host URL), used to build a Client when none is provided.
client (Client): An optional pre-configured Ollama client; if given, client_config is ignored.
Returns:
None
"""
self.client = client
self.client_config = client_config
self.model_config = model_config

if self.client is not None:
if self.client_config is not None:
log.warning("Client object provided. Ignoring client_config parameter.")
return

if client_config is None:
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client"))

if model_config is None:
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model"))

if 'model' not in model_config:
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name"))

self.client = Client(**client_config)

def system_message(self, message: str) -> any:
"""
Create a system message.
Parameters:
message (str): The message content.
Returns:
dict: The message as a role/content dictionary with role "system".
"""
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
"""
Create a user message.
Parameters:
message (str): The message content.
Returns:
dict: The message as a role/content dictionary with role "user".
"""
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
"""
Create an assistant message.
Parameters:
message (str): The message content.
Returns:
dict: The message as a role/content dictionary with role "assistant".
"""
return {"role": "assistant", "content": message}

def invoke(self, prompt, **kwargs) -> str:
"""
Submit a prompt to the model for generating a response.
Parameters:
prompt (str): The prompt to send to the model.
**kwargs: Additional keyword arguments (optional).
- temperature (float): The temperature parameter for controlling randomness in generation.
Returns:
str: The generated response text.
"""
if not prompt:
raise ValueError(PROMPT_EMPTY_EXCEPTION)

model = self.model_config.get('model')
temperature = kwargs.get('temperature', 0.1)

response = self.client.chat(
model=model,
messages=[self.user_message(prompt)],
options=Options(
temperature=temperature
)
)

return response['message']['content']
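
For context, here is a minimal usage sketch of the new wrapper; the host URL and model name are the illustrative values used in the tests below (and assume an Ollama server is reachable at that host), not requirements of this commit:

from mindsql.llms import Ollama

# Option 1: let the wrapper construct the ollama.Client from client_config
llm = Ollama(model_config={"model": "sqlcoder"},
             client_config={"host": "http://localhost:11434/"})

# Option 2: pass a pre-configured ollama.Client directly (client_config is then ignored)
# from ollama import Client
# llm = Ollama(model_config={"model": "sqlcoder"}, client=Client(host="http://localhost:11434/"))

# invoke() wraps the prompt in a single user message and returns the response text
print(llm.invoke("Write a SQL query that lists all tables.", temperature=0.1))
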
83 changes: 83 additions & 0 deletions tests/ollama_test.py
@@ -0,0 +1,83 @@
import unittest
from unittest.mock import MagicMock, patch
from ollama import Client, Options

from mindsql.llms import ILlm
from mindsql.llms import Ollama
from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED


class TestOllama(unittest.TestCase):

def setUp(self):
# Common setup for each test case
self.model_config = {'model': 'sqlcoder'}
self.client_config = {'host': 'http://localhost:11434/'}
self.client_mock = MagicMock(spec=Client)

def test_initialization_with_client(self):
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
self.assertEqual(ollama.client, self.client_mock)
self.assertIsNone(ollama.client_config)
self.assertEqual(ollama.model_config, self.model_config)

def test_initialization_with_client_config(self):
ollama = Ollama(model_config=self.model_config, client_config=self.client_config)
self.assertIsNotNone(ollama.client)
self.assertEqual(ollama.client_config, self.client_config)
self.assertEqual(ollama.model_config, self.model_config)

def test_initialization_missing_client_and_client_config(self):
with self.assertRaises(ValueError) as context:
Ollama(model_config=self.model_config)
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client"))

def test_initialization_missing_model_config(self):
with self.assertRaises(ValueError) as context:
Ollama(model_config=None, client_config=self.client_config)
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model"))

def test_initialization_missing_model_name(self):
with self.assertRaises(ValueError) as context:
Ollama(model_config={}, client_config=self.client_config)
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name"))

def test_system_message(self):
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
message = ollama.system_message("Test system message")
self.assertEqual(message, {"role": "system", "content": "Test system message"})

def test_user_message(self):
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
message = ollama.user_message("Test user message")
self.assertEqual(message, {"role": "user", "content": "Test user message"})

def test_assistant_message(self):
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
message = ollama.assistant_message("Test assistant message")
self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"})

@patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}})
def test_invoke_success(self, mock_chat):
ollama = Ollama(model_config=self.model_config, client=Client())
response = ollama.invoke("Test prompt")

# Check if the response is as expected
self.assertEqual(response, 'Test response')

# Verify that the chat method was called with the correct arguments
mock_chat.assert_called_once_with(
model=self.model_config['model'],
messages=[{"role": "user", "content": "Test prompt"}],
options=Options(temperature=0.1)
)

def test_invoke_empty_prompt(self):
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
with self.assertRaises(ValueError) as context:
ollama.invoke("")
self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION)


if __name__ == '__main__':
unittest.main()
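
Since Client.chat is mocked in test_invoke_success and the remaining tests use MagicMock, the suite should run without a live Ollama server; from the repository root, something like:

python -m unittest tests/ollama_test.py -v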
