From 61c14ee890154127c9a5396ee809f25083779df7 Mon Sep 17 00:00:00 2001
From: Braelyn Boynton
Date: Tue, 13 Aug 2024 15:10:55 -0700
Subject: [PATCH] ollama provider

---
 agentops/llms/__init__.py      |  93 +++++-----
 agentops/llms/ollama.py        | 155 +++++++++---------
 .../providers/ollama_canary.py |  24 +++
 3 files changed, 151 insertions(+), 121 deletions(-)
 create mode 100644 tests/core_manual_tests/providers/ollama_canary.py

diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py
index c846536f..229cd51d 100644
--- a/agentops/llms/__init__.py
+++ b/agentops/llms/__init__.py
@@ -12,12 +12,7 @@
 from .cohere import CohereProvider
 from .groq import GroqProvider
 from .litellm import override_litellm_completion, override_litellm_async_completion
-from .ollama import (
-    override_ollama_chat,
-    override_ollama_chat_client,
-    override_ollama_chat_async_client,
-    undo_override_ollama,
-)
+from .ollama import OllamaProvider
 from .openai import OpenAiProvider
 
 original_func = {}
@@ -48,42 +43,42 @@ def __init__(self, client):
         self.client = client
         self.completion = ""
 
-    def _override_openai_v0_method(self, api, method_path, module):
-        def handle_response(result, kwargs, init_timestamp):
-            if api == "openai":
-                return handle_response_v0_openai(self, result, kwargs, init_timestamp)
-            return result
-
-        def wrap_method(original_method):
-            if inspect.iscoroutinefunction(original_method):
-
-                @functools.wraps(original_method)
-                async def async_method(*args, **kwargs):
-                    init_timestamp = get_ISO_time()
-                    response = await original_method(*args, **kwargs)
-                    return handle_response(response, kwargs, init_timestamp)
-
-                return async_method
-
-            else:
-
-                @functools.wraps(original_method)
-                def sync_method(*args, **kwargs):
-                    init_timestamp = get_ISO_time()
-                    response = original_method(*args, **kwargs)
-                    return handle_response(response, kwargs, init_timestamp)
-
-                return sync_method
-
-        method_parts = method_path.split(".")
-        original_method = functools.reduce(getattr, method_parts, module)
-        new_method = wrap_method(original_method)
-
-        if len(method_parts) == 1:
-            setattr(module, method_parts[0], new_method)
-        else:
-            parent = functools.reduce(getattr, method_parts[:-1], module)
-            setattr(parent, method_parts[-1], new_method)
+    # def _override_openai_v0_method(self, api, method_path, module):
+    #     def handle_response(result, kwargs, init_timestamp):
+    #         if api == "openai":
+    #             return handle_response_v0_openai(self, result, kwargs, init_timestamp)
+    #         return result
+    #
+    #     def wrap_method(original_method):
+    #         if inspect.iscoroutinefunction(original_method):
+    #
+    #             @functools.wraps(original_method)
+    #             async def async_method(*args, **kwargs):
+    #                 init_timestamp = get_ISO_time()
+    #                 response = await original_method(*args, **kwargs)
+    #                 return handle_response(response, kwargs, init_timestamp)
+    #
+    #             return async_method
+    #
+    #         else:
+    #
+    #             @functools.wraps(original_method)
+    #             def sync_method(*args, **kwargs):
+    #                 init_timestamp = get_ISO_time()
+    #                 response = original_method(*args, **kwargs)
+    #                 return handle_response(response, kwargs, init_timestamp)
+    #
+    #             return sync_method
+    #
+    #     method_parts = method_path.split(".")
+    #     original_method = functools.reduce(getattr, method_parts, module)
+    #     new_method = wrap_method(original_method)
+    #
+    #     if len(method_parts) == 1:
+    #         setattr(module, method_parts[0], new_method)
+    #     else:
+    #         parent = functools.reduce(getattr, method_parts[:-1], module)
+    #         setattr(parent, method_parts[-1], new_method)
 
     def override_api(self):
         """
@@ -143,9 +138,8 @@ def override_api(self):
                 module_version = version(api)
 
                 if Version(module_version) >= parse("0.0.1"):
-                    override_ollama_chat(self)
-                    override_ollama_chat_client(self)
-                    override_ollama_chat_async_client(self)
+                    provider = OllamaProvider(self.client)
+                    provider.override()
                 else:
                     logger.warning(
                         f"Only Ollama>=0.0.1 supported. v{module_version} found."
@@ -166,4 +160,11 @@ def stop_instrumenting(self):
         openai_provider = OpenAiProvider(self.client)
         openai_provider.undo_override()
 
-        undo_override_ollama(self)
+        groq_provider = GroqProvider(self.client)
+        groq_provider.undo_override()
+
+        cohere_provider = CohereProvider(self.client)
+        cohere_provider.undo_override()
+
+        ollama_provider = OllamaProvider(self.client)
+        ollama_provider.undo_override()
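The hunks above replace the function-based Ollama instrumentation with the provider pattern already used by the OpenAI, Groq, and Cohere integrations. The patch does not include agentops/llms/instrumented_provider.py, so the base class below is only a sketch inferred from how OllamaProvider uses it: the method names come from the diff, while the body of _safe_record is an assumption.

# Sketch of the InstrumentedProvider contract assumed by this patch. The real
# base class lives in agentops/llms/instrumented_provider.py, which is not
# part of this diff; details here are inferred, not authoritative.
from abc import ABC, abstractmethod


class InstrumentedProvider(ABC):
    def __init__(self, client):
        self.client = client  # AgentOps client used to record LLMEvents

    @abstractmethod
    def handle_response(self, response, kwargs, init_timestamp, session=None):
        """Turn a provider response into an LLMEvent and record it."""

    @abstractmethod
    def override(self):
        """Monkey-patch the provider's chat entry points."""

    @abstractmethod
    def undo_override(self):
        """Restore the original, unpatched entry points."""

    def _safe_record(self, session, event):
        # Assumption: prefer the session's recorder when a session is passed,
        # otherwise fall back to the global client.
        if session is not None:
            session.record(event)
        else:
            self.client.record(event)

Because every provider exposes the same override()/undo_override() surface, stop_instrumenting can tear each integration down uniformly, as the final hunk above now does for OpenAI, Groq, Cohere, and Ollama.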
diff --git a/agentops/llms/ollama.py b/agentops/llms/ollama.py
index 89508c84..bcf753aa 100644
--- a/agentops/llms/ollama.py
+++ b/agentops/llms/ollama.py
@@ -2,111 +2,116 @@
 import sys
 from typing import Optional
 
-from ..event import ActionEvent, ErrorEvent, LLMEvent
+from ..event import LLMEvent
 from ..session import Session
-from ..log_config import logger
 from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
+from .instrumented_provider import InstrumentedProvider
 
 original_func = {}
 original_create = None
 original_create_async = None
 
-def override_ollama_chat(tracker):
-    import ollama
+class OllamaProvider(InstrumentedProvider):
+    def handle_response(
+        self, response, kwargs, init_timestamp, session: Optional[Session] = None
+    ) -> dict:
+        self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
 
-    original_func["ollama.chat"] = ollama.chat
+        def handle_stream_chunk(chunk: dict):
+            message = chunk.get("message", {"role": None, "content": ""})
 
-    def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = original_func["ollama.chat"](*args, **kwargs)
-        return tracker._handle_response_ollama(
-            result, kwargs, init_timestamp, session=kwargs.get("session", None)
-        )
+            if chunk.get("done"):
+                self.llm_event.completion["content"] += message.get("content")
+                self.llm_event.end_timestamp = get_ISO_time()
+                self.llm_event.model = f'ollama/{chunk.get("model")}'
+                self.llm_event.returns = chunk
+                self.llm_event.returns["message"] = self.llm_event.completion
+                self.llm_event.prompt = kwargs["messages"]
+                self.llm_event.agent_id = check_call_stack_for_agent_id()
+                self.client.record(self.llm_event)
 
-    # Override the original method with the patched one
-    ollama.chat = patched_function
+            if self.llm_event.completion is None:
+                self.llm_event.completion = message
+            else:
+                self.llm_event.completion["content"] += message.get("content")
 
+        if inspect.isgenerator(response):
 
-def override_ollama_chat_client(tracker):
-    from ollama import Client
+            def generator():
+                for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
 
-    original_func["ollama.Client.chat"] = Client.chat
+            return generator()
 
-    def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = original_func["ollama.Client.chat"](*args, **kwargs)
-        return _handle_response_ollama(result, kwargs, init_timestamp)
+        self.llm_event.end_timestamp = get_ISO_time()
 
-    # Override the original method with the patched one
-    Client.chat = patched_function
+        self.llm_event.model = f'ollama/{response["model"]}'
+        self.llm_event.returns = response
+        self.llm_event.agent_id = check_call_stack_for_agent_id()
+        self.llm_event.prompt = kwargs["messages"]
+        self.llm_event.completion = response["message"]
 
+        self._safe_record(session, self.llm_event)
+        return response
 
-def override_ollama_chat_async_client(tracker):
-    from ollama import AsyncClient
+    def override(self):
+        self._override_chat_client()
+        self._override_chat()
+        self._override_chat_async_client()
 
-    original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
+    def undo_override(self):
+        if "ollama" in sys.modules:
+            import ollama
 
-    async def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
-        return _handle_response_ollama(result, kwargs, init_timestamp)
+            ollama.chat = original_func["ollama.chat"]
+            ollama.Client.chat = original_func["ollama.Client.chat"]
+            ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]
 
-    # Override the original method with the patched one
-    AsyncClient.chat = patched_function
+    def __init__(self, client):
+        super().__init__(client)
 
+    def _override_chat(self):
+        import ollama
 
-def _handle_response_ollama(
-    tracker, response, kwargs, init_timestamp, session: Optional[Session] = None
-) -> None:
-    tracker.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
-
-    def handle_stream_chunk(chunk: dict):
-        message = chunk.get("message", {"role": None, "content": ""})
-
-        if chunk.get("done"):
-            tracker.llm_event.completion["content"] += message.get("content")
-            tracker.llm_event.end_timestamp = get_ISO_time()
-            tracker.llm_event.model = f'ollama/{chunk.get("model")}'
-            tracker.llm_event.returns = chunk
-            tracker.llm_event.returns["message"] = tracker.llm_event.completion
-            tracker.llm_event.prompt = kwargs["messages"]
-            tracker.llm_event.agent_id = check_call_stack_for_agent_id()
-            tracker.client.record(tracker.llm_event)
+        original_func["ollama.chat"] = ollama.chat
 
-        if tracker.llm_event.completion is None:
-            tracker.llm_event.completion = message
-        else:
-            tracker.llm_event.completion["content"] += message.get("content")
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = original_func["ollama.chat"](*args, **kwargs)
+            return self.handle_response(
+                result, kwargs, init_timestamp, session=kwargs.get("session", None)
+            )
 
-    if inspect.isgenerator(response):
+        # Override the original method with the patched one
+        ollama.chat = patched_function
 
-        def generator():
-            for chunk in response:
-                handle_stream_chunk(chunk)
-                yield chunk
+    def _override_chat_client(self):
+        from ollama import Client
 
-        return generator()
+        original_func["ollama.Client.chat"] = Client.chat
 
-    tracker.llm_event.end_timestamp = get_ISO_time()
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = original_func["ollama.Client.chat"](*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp)
 
-    tracker.llm_event.model = f'ollama/{response["model"]}'
-    tracker.llm_event.returns = response
-    tracker.llm_event.agent_id = check_call_stack_for_agent_id()
-    tracker.llm_event.prompt = kwargs["messages"]
-    tracker.llm_event.completion = response["message"]
+        # Override the original method with the patched one
+        Client.chat = patched_function
 
-    safe_record(session, tracker.llm_event)
-    return response
+    def _override_chat_async_client(self):
+        from ollama import AsyncClient
 
-def undo_override_ollama(tracker):
-    if "ollama" in sys.modules:
-        import ollama
+        original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
 
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp)
 
-        ollama.chat = original_func["ollama.chat"]
-        ollama.Client.chat = original_func["ollama.Client.chat"]
-        ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]
+        # Override the original method with the patched one
+        AsyncClient.chat = patched_function
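Inside handle_response above, handle_stream_chunk rebuilds the full completion from Ollama's streaming chunks: the first chunk seeds the message dict, later chunks append to its content, and the chunk carrying done=True stamps the model and timestamps and records the LLMEvent. The standalone sketch below illustrates that accumulation with hand-written chunk dicts (fabricated for illustration; real chunks come from ollama.chat(..., stream=True), and the terminal chunk typically carries empty content plus timing stats):

# Standalone sketch of the accumulation performed by handle_stream_chunk.
# The chunks below are hand-written stand-ins for an Ollama chat stream.
chunks = [
    {"model": "llama3.1", "message": {"role": "assistant", "content": "Blue light "}, "done": False},
    {"model": "llama3.1", "message": {"role": "assistant", "content": "scatters most."}, "done": False},
    {"model": "llama3.1", "message": {"role": "assistant", "content": ""}, "done": True},
]

completion = None
for chunk in chunks:
    message = chunk.get("message", {"role": None, "content": ""})
    if completion is None:
        completion = message  # first chunk seeds the role and the content
    else:
        completion["content"] += message.get("content")
    if chunk.get("done"):
        # Here the provider stamps model/prompt/timestamps and records the LLMEvent.
        print(f'ollama/{chunk.get("model")}:', completion["content"])

Note that for the done chunk the provider appends its content once in the done branch and again in the general branch; because Ollama's terminal message is empty, both orderings produce the same text.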
diff --git a/tests/core_manual_tests/providers/ollama_canary.py b/tests/core_manual_tests/providers/ollama_canary.py
new file mode 100644
index 00000000..91b08493
--- /dev/null
+++ b/tests/core_manual_tests/providers/ollama_canary.py
@@ -0,0 +1,24 @@
+import agentops
+from dotenv import load_dotenv
+import ollama
+
+load_dotenv()
+agentops.init(default_tags=["ollama-provider-test"])
+
+response = ollama.chat(
+    model="llama3.1",
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is the sky blue?",
+        },
+    ],
+)
+print(response)
+print(response["message"]["content"])
+
+agentops.end_session(end_state="Success")
+
+###
+# Used to verify that one session is created with one LLM event
+###
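The canary above exercises only the synchronous, non-streaming path, which should map to exactly one LLMEvent. A streaming variant (hypothetical, not part of this patch) would also exercise the generator branch of handle_response:

# Hypothetical streaming canary, not included in this patch: with stream=True,
# ollama.chat returns a generator of chunk dicts, which routes through the
# generator branch of OllamaProvider.handle_response. The tag name below is
# made up for illustration.
import agentops
import ollama
from dotenv import load_dotenv

load_dotenv()
agentops.init(default_tags=["ollama-provider-stream-test"])

stream = ollama.chat(
    model="llama3.1",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
)
for chunk in stream:
    print(chunk["message"]["content"], end="", flush=True)

agentops.end_session(end_state="Success")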