Skip to content

feat: Implement global event handlers #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ repos:

# Enforces a coding standard, looks for code smells, and can make suggestions about how the code could be refactored.
- repo: https://github.com/pycqa/pylint
rev: v3.0.1
rev: v3.1.0
hooks:
- id: pylint
exclude: (/test_|tests/|docs/)
Expand Down
2 changes: 1 addition & 1 deletion docs/how-to/visualize_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Define collection with implemented views
```python
llm = LiteLLM(model_name="gpt-3.5-turbo")
await country_similarity.update()
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))
collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine()))
```
Expand Down
5 changes: 4 additions & 1 deletion docs/quickstart/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,15 @@ Replace `...` with your OpenAI API key. Alternatively, you can set the `OPENAI_A
## Collection Definition

Next, create a db-ally collection. A [collection](../concepts/collections.md) is an object where you register views and execute queries. It also requires an AI model to use for generating [IQL queries](../concepts/iql.md) (in this case, the GPT model defined above).
The collection could have its own event handlers which override the globally defined handlers.

```python
import dbally
from dbally.audit import CLIEventHandler


async def main():
collection = dbally.create_collection("recruitment", llm)
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler])
collection.add(CandidateView, lambda: CandidateView(engine))
```

Expand Down
3 changes: 2 additions & 1 deletion docs/quickstart/quickstart2_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem


async def main():
dbally.event_handlers = [CLIEventHandler()]
await country_similarity.update()

llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Find someone from the United States with more than 2 years of experience.")
Expand Down
4 changes: 3 additions & 1 deletion docs/quickstart/quickstart3_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
import dbally
import os
import asyncio
from typing_extensions import Annotated
Expand All @@ -9,7 +8,9 @@
from sqlalchemy.ext.automap import automap_base
import pandas as pd

import dbally
from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
from dbally.audit import CLIEventHandler
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.litellm import LiteLLM
Expand Down Expand Up @@ -125,6 +126,7 @@ def display_results(result: ExecutionResult):


async def main():
dbally.event_handlers = [CLIEventHandler()]
await country_similarity.update()

llm = LiteLLM(model_name="gpt-3.5-turbo")
Expand Down
4 changes: 3 additions & 1 deletion docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms.litellm import LiteLLM
Expand Down Expand Up @@ -57,8 +58,9 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:

async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
dbally.event_handlers = [CLIEventHandler()]

collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
Expand Down
4 changes: 3 additions & 1 deletion examples/visualize_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

import dbally
from dbally.audit import CLIEventHandler
from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler
from dbally.gradio import create_gradio_interface
from dbally.llms.litellm import LiteLLM


async def main():
await country_similarity.update()
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()])
dbally.event_handlers = [CLIEventHandler(), BufferEventHandler()]
collection = dbally.create_collection("candidates", llm)
collection.add(CandidateView, lambda: CandidateView(engine))
collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine()))
gradio_interface = await create_gradio_interface(user_collection=collection)
Expand Down
20 changes: 13 additions & 7 deletions src/dbally/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" dbally """

from dbally.collection.collection import Collection
from typing import TYPE_CHECKING, List

from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult
from dbally.views import decorators
Expand All @@ -21,29 +22,34 @@
from .exceptions import DbAllyError
from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError

if TYPE_CHECKING:
from .audit import EventHandler

event_handlers: List["EventHandler"] = []

__all__ = [
"__version__",
"create_collection",
"decorators",
"MethodsBaseView",
"SqlAlchemyBaseView",
"Collection",
"event_handlers",
"BaseStructuredView",
"DataFrameBaseView",
"ExecutionResult",
"DbAllyError",
"ExecutionResult",
"EmbeddingError",
"EmbeddingConnectionError",
"EmbeddingResponseError",
"EmbeddingStatusError",
"IndexUpdateError",
"LLMError",
"LLMConnectionError",
"LLMResponseError",
"LLMStatusError",
"NoViewFoundError",
"IndexUpdateError",
"MethodsBaseView",
"NotGiven",
"NOT_GIVEN",
"NoViewFoundError",
"SqlAlchemyBaseView",
]

# Update the __module__ attribute for exported symbols so that
Expand Down
28 changes: 16 additions & 12 deletions src/dbally/_main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from .audit.event_handlers.base import EventHandler
from .collection import Collection
from .llms import LLM
from .nl_responder.nl_responder import NLResponder
from .view_selection.base import ViewSelector
from .view_selection.llm_view_selector import LLMViewSelector
from dbally.audit import EventHandler
from dbally.llms import LLM
from dbally.nl_responder.nl_responder import NLResponder
from dbally.view_selection import LLMViewSelector
from dbally.view_selection.base import ViewSelector

if TYPE_CHECKING:
from dbally.collection import Collection


def create_collection(
Expand All @@ -14,7 +16,7 @@ def create_collection(
event_handlers: Optional[List[EventHandler]] = None,
view_selector: Optional[ViewSelector] = None,
nl_responder: Optional[NLResponder] = None,
) -> Collection:
) -> "Collection":
"""
Create a new [Collection](collection.md) that is a container for registering views and the\
main entrypoint to db-ally features.
Expand All @@ -38,22 +40,24 @@ def create_collection(
llm: LLM used by the collection to generate responses for natural language queries.
event_handlers: Event handlers used by the collection during query executions. Can be used to\
log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\
[LangSmithEventHandler](event_handlers/langsmith_handler.md).
[LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the
global dbally.event_handlers.
view_selector: View selector used by the collection to select the best view for the given query.\
If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\
will be used.
nl_responder: NL responder used by the collection to respond to natural language queries. If None,\
a new instance of [NLResponder][dbally.nl_responder.nl_responder.NLResponder] will be used.

Returns:
a new instance of db-ally Collection
New instance of db-ally Collection.

Raises:
ValueError: if default LLM client is not configured
ValueError: If default LLM client is not configured.
"""
from dbally.collection import Collection # pylint: disable=import-outside-toplevel

view_selector = view_selector or LLMViewSelector(llm=llm)
nl_responder = nl_responder or NLResponder(llm=llm)
event_handlers = event_handlers or []

return Collection(
name,
Expand Down
28 changes: 28 additions & 0 deletions src/dbally/audit/event_handlers/buffer_event_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from io import StringIO

from rich.console import Console

from dbally.audit import CLIEventHandler


class BufferEventHandler(CLIEventHandler):
"""
This handler stores in buffer all interactions between LLM and user happening during `Collection.ask`\
execution.

### Usage

```python
import dbally
from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler

dbally.event_handlers = [BufferEventHandler()]
my_collection = dbally.create_collection("my_collection", llm)
```
"""

def __init__(self) -> None:
super().__init__()

self.buffer = StringIO()
self._console = Console(file=self.buffer, record=True)
17 changes: 8 additions & 9 deletions src/dbally/audit/event_handlers/cli_event_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import re
from io import StringIO
from sys import stdout
from typing import Optional

try:
Expand All @@ -24,28 +22,26 @@
class CLIEventHandler(EventHandler):
"""
This handler displays all interactions between LLM and user happening during `Collection.ask`\
execution inside the terminal or store them in the given buffer.
execution inside the terminal.

### Usage

```python
import dbally
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler

my_collection = dbally.create_collection("my_collection", llm, event_handlers=[CLIEventHandler()])
dbally.event_handlers = [CLIEventHandler()]
my_collection = dbally.create_collection("my_collection", llm)
```

After using `CLIEventHandler`, during every `Collection.ask` execution you will see output similar to the one below:

![Example output from CLIEventHandler](../../assets/event_handler_example.png)
"""

def __init__(self, buffer: Optional[StringIO] = None) -> None:
def __init__(self) -> None:
super().__init__()

self.buffer = buffer
out = self.buffer if buffer else stdout
self._console = Console(file=out, record=True) if RICH_OUTPUT else None
self._console = Console(record=True) if RICH_OUTPUT else None

def _print_syntax(self, content: str, lexer: Optional[str] = None) -> None:
if self._console:
Expand All @@ -69,6 +65,7 @@ async def request_start(self, user_request: RequestStart) -> None:
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")

# pylint: disable=unused-argument
async def event_start(self, event: Event, request_context: None) -> None:
"""
Displays information that event has started, then all messages inside the prompt
Expand Down Expand Up @@ -98,6 +95,7 @@ async def event_start(self, event: Event, request_context: None) -> None:
f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n"
)

# pylint: disable=unused-argument
async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None:
"""
Displays the response from the LLM.
Expand All @@ -116,6 +114,7 @@ async def event_end(self, event: Optional[Event], request_context: None, event_c
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")

# pylint: disable=unused-argument
async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None:
"""
Displays the output of the request, namely the `results` and the `context`
Expand Down
1 change: 1 addition & 0 deletions src/dbally/audit/event_handlers/langsmith_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ async def event_start(self, event: Event, request_context: RunTree) -> RunTree:

raise ValueError("Unsupported event")

# pylint: disable=unused-argument
async def event_end(self, event: Optional[Event], request_context: RunTree, event_context: RunTree) -> None:
"""
Log the end of the event.
Expand Down
5 changes: 5 additions & 0 deletions src/dbally/audit/event_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ def initialize_with_handlers(cls, event_handlers: List[EventHandler]) -> "EventT

Returns:
The initialized event store.

Raises:
ValueError: if invalid event handler object is passed as argument.
"""

instance = cls()

for handler in event_handlers:
if not isinstance(handler, EventHandler):
raise ValueError(f"Could not register {handler}. Handler must be instance of EvenHandler type")
instance.subscribe(handler)

return instance
Expand Down
Loading