ollama provider
bboynton97 committed Aug 13, 2024
1 parent 67ddeba commit 61c14ee
Showing 3 changed files with 147 additions and 121 deletions.
90 changes: 44 additions & 46 deletions agentops/llms/__init__.py
@@ -12,12 +12,7 @@
 from .cohere import CohereProvider
 from .groq import GroqProvider
 from .litellm import override_litellm_completion, override_litellm_async_completion
-from .ollama import (
-    override_ollama_chat,
-    override_ollama_chat_client,
-    override_ollama_chat_async_client,
-    undo_override_ollama,
-)
+from .ollama import OllamaProvider
 from .openai import OpenAiProvider
 
 original_func = {}
@@ -48,42 +43,42 @@ def __init__(self, client):
         self.client = client
         self.completion = ""
 
-    def _override_openai_v0_method(self, api, method_path, module):
-        def handle_response(result, kwargs, init_timestamp):
-            if api == "openai":
-                return handle_response_v0_openai(self, result, kwargs, init_timestamp)
-            return result
-
-        def wrap_method(original_method):
-            if inspect.iscoroutinefunction(original_method):
-
-                @functools.wraps(original_method)
-                async def async_method(*args, **kwargs):
-                    init_timestamp = get_ISO_time()
-                    response = await original_method(*args, **kwargs)
-                    return handle_response(response, kwargs, init_timestamp)
-
-                return async_method
-
-            else:
-
-                @functools.wraps(original_method)
-                def sync_method(*args, **kwargs):
-                    init_timestamp = get_ISO_time()
-                    response = original_method(*args, **kwargs)
-                    return handle_response(response, kwargs, init_timestamp)
-
-                return sync_method
-
-        method_parts = method_path.split(".")
-        original_method = functools.reduce(getattr, method_parts, module)
-        new_method = wrap_method(original_method)
-
-        if len(method_parts) == 1:
-            setattr(module, method_parts[0], new_method)
-        else:
-            parent = functools.reduce(getattr, method_parts[:-1], module)
-            setattr(parent, method_parts[-1], new_method)
+    # def _override_openai_v0_method(self, api, method_path, module):
+    #     def handle_response(result, kwargs, init_timestamp):
+    #         if api == "openai":
+    #             return handle_response_v0_openai(self, result, kwargs, init_timestamp)
+    #         return result
+    #
+    #     def wrap_method(original_method):
+    #         if inspect.iscoroutinefunction(original_method):
+    #
+    #             @functools.wraps(original_method)
+    #             async def async_method(*args, **kwargs):
+    #                 init_timestamp = get_ISO_time()
+    #                 response = await original_method(*args, **kwargs)
+    #                 return handle_response(response, kwargs, init_timestamp)
+    #
+    #             return async_method
+    #
+    #         else:
+    #
+    #             @functools.wraps(original_method)
+    #             def sync_method(*args, **kwargs):
+    #                 init_timestamp = get_ISO_time()
+    #                 response = original_method(*args, **kwargs)
+    #                 return handle_response(response, kwargs, init_timestamp)
+    #
+    #             return sync_method
+    #
+    #     method_parts = method_path.split(".")
+    #     original_method = functools.reduce(getattr, method_parts, module)
+    #     new_method = wrap_method(original_method)
+    #
+    #     if len(method_parts) == 1:
+    #         setattr(module, method_parts[0], new_method)
+    #     else:
+    #         parent = functools.reduce(getattr, method_parts[:-1], module)
+    #         setattr(parent, method_parts[-1], new_method)
 
     def override_api(self):
         """
@@ -143,9 +138,8 @@ def override_api(self):
                 module_version = version(api)
 
                 if Version(module_version) >= parse("0.0.1"):
-                    override_ollama_chat(self)
-                    override_ollama_chat_client(self)
-                    override_ollama_chat_async_client(self)
+                    provider = OllamaProvider(self.client)
+                    provider.override()
                 else:
                     logger.warning(
                         f"Only Ollama>=0.0.1 supported. v{module_version} found."
@@ -166,4 +160,8 @@ def stop_instrumenting(self):
         openai_provider = OpenAiProvider(self.client)
         openai_provider.undo_override()
 
-        undo_override_ollama(self)
+        groq_provider = GroqProvider(self.client)
+        groq_provider.undo_override()
+
+        cohere_provider = CohereProvider(self.client)
+        cohere_provider.undo_override()
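
Note: the new OllamaProvider relies on a shared base class, InstrumentedProvider, which this commit imports but does not define. The sketch below is a guess at the minimal interface the new code assumes (a stored client, a _safe_record helper, and the three methods each provider overrides); the real definition lives in agentops/llms/instrumented_provider.py and may differ.

from abc import ABC, abstractmethod


class InstrumentedProvider(ABC):
    # Hypothetical sketch of the base class; not part of this commit.

    def __init__(self, client):
        self.client = client  # AgentOps client used to record LLMEvents

    @abstractmethod
    def handle_response(self, response, kwargs, init_timestamp, session=None):
        ...

    @abstractmethod
    def override(self):
        ...

    @abstractmethod
    def undo_override(self):
        ...

    def _safe_record(self, session, event):
        # Record to an explicit session when one is supplied; otherwise
        # fall back to the client-level recorder.
        if session is not None:
            session.record(event)
        else:
            self.client.record(event)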
154 changes: 79 additions & 75 deletions agentops/llms/ollama.py
@@ -2,111 +2,115 @@
 import sys
 from typing import Optional
 
-from ..event import ActionEvent, ErrorEvent, LLMEvent
+from ..event import LLMEvent
 from ..session import Session
 from ..log_config import logger
 from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
+from .instrumented_provider import InstrumentedProvider
 
 original_func = {}
 original_create = None
 original_create_async = None
 
 
-def override_ollama_chat(tracker):
-    import ollama
-
-    original_func["ollama.chat"] = ollama.chat
-
-    def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = original_func["ollama.chat"](*args, **kwargs)
-        return tracker._handle_response_ollama(
-            result, kwargs, init_timestamp, session=kwargs.get("session", None)
-        )
-
-    # Override the original method with the patched one
-    ollama.chat = patched_function
-
-
-def override_ollama_chat_client(tracker):
-    from ollama import Client
-
-    original_func["ollama.Client.chat"] = Client.chat
-
-    def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = original_func["ollama.Client.chat"](*args, **kwargs)
-        return _handle_response_ollama(result, kwargs, init_timestamp)
-
-    # Override the original method with the patched one
-    Client.chat = patched_function
-
-
-def override_ollama_chat_async_client(tracker):
-    from ollama import AsyncClient
-
-    original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
-
-    async def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
-        init_timestamp = get_ISO_time()
-        result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
-        return _handle_response_ollama(result, kwargs, init_timestamp)
-
-    # Override the original method with the patched one
-    AsyncClient.chat = patched_function
-
-
-def _handle_response_ollama(
-    tracker, response, kwargs, init_timestamp, session: Optional[Session] = None
-) -> None:
-    tracker.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
-
-    def handle_stream_chunk(chunk: dict):
-        message = chunk.get("message", {"role": None, "content": ""})
-
-        if chunk.get("done"):
-            tracker.llm_event.completion["content"] += message.get("content")
-            tracker.llm_event.end_timestamp = get_ISO_time()
-            tracker.llm_event.model = f'ollama/{chunk.get("model")}'
-            tracker.llm_event.returns = chunk
-            tracker.llm_event.returns["message"] = tracker.llm_event.completion
-            tracker.llm_event.prompt = kwargs["messages"]
-            tracker.llm_event.agent_id = check_call_stack_for_agent_id()
-            tracker.client.record(tracker.llm_event)
-
-        if tracker.llm_event.completion is None:
-            tracker.llm_event.completion = message
-        else:
-            tracker.llm_event.completion["content"] += message.get("content")
-
-    if inspect.isgenerator(response):
-
-        def generator():
-            for chunk in response:
-                handle_stream_chunk(chunk)
-                yield chunk
-
-        return generator()
-
-    tracker.llm_event.end_timestamp = get_ISO_time()
-
-    tracker.llm_event.model = f'ollama/{response["model"]}'
-    tracker.llm_event.returns = response
-    tracker.llm_event.agent_id = check_call_stack_for_agent_id()
-    tracker.llm_event.prompt = kwargs["messages"]
-    tracker.llm_event.completion = response["message"]
-
-    safe_record(session, tracker.llm_event)
-    return response
-
-
-def undo_override_ollama(tracker):
-    if "ollama" in sys.modules:
-        import ollama
-
-        ollama.chat = original_func["ollama.chat"]
-        ollama.Client.chat = original_func["ollama.Client.chat"]
-        ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]
+class OllamaProvider(InstrumentedProvider):
+    def handle_response(
+        self, response, kwargs, init_timestamp, session: Optional[Session] = None
+    ) -> dict:
+        self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
+
+        def handle_stream_chunk(chunk: dict):
+            message = chunk.get("message", {"role": None, "content": ""})
+
+            if chunk.get("done"):
+                self.llm_event.completion["content"] += message.get("content")
+                self.llm_event.end_timestamp = get_ISO_time()
+                self.llm_event.model = f'ollama/{chunk.get("model")}'
+                self.llm_event.returns = chunk
+                self.llm_event.returns["message"] = self.llm_event.completion
+                self.llm_event.prompt = kwargs["messages"]
+                self.llm_event.agent_id = check_call_stack_for_agent_id()
+                self.client.record(self.llm_event)
+
+            if self.llm_event.completion is None:
+                self.llm_event.completion = message
+            else:
+                self.llm_event.completion["content"] += message.get("content")
+
+        if inspect.isgenerator(response):
+
+            def generator():
+                for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return generator()
+
+        self.llm_event.end_timestamp = get_ISO_time()
+
+        self.llm_event.model = f'ollama/{response["model"]}'
+        self.llm_event.returns = response
+        self.llm_event.agent_id = check_call_stack_for_agent_id()
+        self.llm_event.prompt = kwargs["messages"]
+        self.llm_event.completion = response["message"]
+
+        self._safe_record(session, self.llm_event)
+        return response
+
+    def override(self):
+        self._override_chat_client()
+        self._override_chat()
+
+    def undo_override(self):
+        if "ollama" in sys.modules:
+            import ollama
+
+            ollama.chat = original_func["ollama.chat"]
+            ollama.Client.chat = original_func["ollama.Client.chat"]
+            ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]
+
+    def __init__(self, client):
+        super().__init__(client)
+
+    def _override_chat(self):
+        import ollama
+
+        original_func["ollama.chat"] = ollama.chat
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = original_func["ollama.chat"](*args, **kwargs)
+            return self.handle_response(
+                result, kwargs, init_timestamp, session=kwargs.get("session", None)
+            )
+
+        # Override the original method with the patched one
+        ollama.chat = patched_function
+
+    def _override_chat_client(self):
+        from ollama import Client
+
+        original_func["ollama.Client.chat"] = Client.chat
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = original_func["ollama.Client.chat"](*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp)
+
+        # Override the original method with the patched one
+        Client.chat = patched_function
+
+    def _override_chat_async_client(self):
+        from ollama import AsyncClient
+
+        original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
+
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp)
+
+        # Override the original method with the patched one
+        AsyncClient.chat = patched_function
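
The class consolidates the old module-level override functions into one object whose lifecycle matches the calls added in agentops/llms/__init__.py. Below is a minimal usage sketch, assuming an initialized AgentOps client (LlmTracker passes its own self.client; agentops.Client() here is a stand-in). Note that override() as committed patches only ollama.chat and ollama.Client.chat, so the async client is patched explicitly in the sketch so that undo_override() can restore all three originals.

import agentops
import ollama
from agentops.llms.ollama import OllamaProvider

agentops.init()

provider = OllamaProvider(agentops.Client())  # stand-in for LlmTracker's client
provider.override()  # patches ollama.chat and ollama.Client.chat
provider._override_chat_async_client()  # not called by override(); patched here so undo_override() finds all three originals

response = ollama.chat(
    model="llama3.1",
    messages=[{"role": "user", "content": "Hello"}],
)

provider.undo_override()  # restores the unpatched chat functions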
24 changes: 24 additions & 0 deletions tests/core_manual_tests/providers/ollama_canary.py
@@ -0,0 +1,24 @@
+import agentops
+from dotenv import load_dotenv
+import ollama
+
+load_dotenv()
+agentops.init(default_tags=["ollama-provider-test"])
+
+response = ollama.chat(
+    model="llama3.1",
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is the sky blue?",
+        },
+    ],
+)
+print(response)
+print(response["message"]["content"])
+
+agentops.end_session(end_state="Success")
+
+###
+# Used to verify that one session is created with one LLM event
+###
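
The canary covers only the non-streaming path. A streaming variant along these lines would exercise the generator branch in handle_response (a sketch: stream=True is the ollama library's flag for returning a generator of chunks).

import agentops
from dotenv import load_dotenv
import ollama

load_dotenv()
agentops.init(default_tags=["ollama-provider-stream-test"])

# With stream=True, ollama.chat returns a generator; the provider wraps it so
# the assembled completion is recorded when the final ("done") chunk arrives.
stream = ollama.chat(
    model="llama3.1",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
)

for chunk in stream:
    print(chunk["message"]["content"], end="", flush=True)

agentops.end_session(end_state="Success")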
