response.on_done() mechanism, closes #653
simonw committed Dec 1, 2024
1 parent 335b3e6 commit f9af563
Showing 5 changed files with 118 additions and 15 deletions.
41 changes: 41 additions & 0 deletions docs/python-api.md
@@ -195,6 +195,47 @@ response = conversation.prompt(

Access `conversation.responses` for a list of all responses returned so far during the conversation.

## Running code when a response has completed

For some applications, such as tracking token usage, it can be useful to execute code as soon as a response has finished being generated.

You can do this using the `response.on_done(callback)` method, which causes your callback function to be called as soon as the response has finished (all tokens have been returned).

The callback you provide has the signature `def callback(response)` - it can optionally be an `async def` function when working with asynchronous models.

Example usage:

```python
import llm

model = llm.get_model("gpt-4o-mini")
response = model.prompt("a poem about a hippo")
response.on_done(lambda response: print(response.usage()))
print(response.text())
```
Which outputs:
```
Usage(input=20, output=494, details={})
In a sunlit glade by a bubbling brook,
Lived a hefty hippo, with a curious look.
...
```
Or using an `asyncio` model, where you need to `await response.on_done(done)` to queue up the callback:
```python
import asyncio, llm

async def run():
    model = llm.get_async_model("gpt-4o-mini")
    response = model.prompt("a short poem about a brick")

    async def done(response):
        print(await response.usage())
        print(await response.text())

    await response.on_done(done)
    print(await response.text())

asyncio.run(run())
```
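
One behavior worth noting: per the implementation in this commit, if the response has already completed when `on_done()` is called, the callback is invoked immediately instead of being queued. A minimal sketch of that, reusing the model from the first example:
```python
import llm

model = llm.get_model("gpt-4o-mini")
response = model.prompt("a poem about a hippo")
print(response.text())  # consuming the text marks the response as done
# The response is already complete, so this callback fires right away:
response.on_done(lambda response: print(response.usage()))
```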

## Other functions

The `llm` top level package includes some useful utility functions.
31 changes: 31 additions & 0 deletions llm/models.py
@@ -1,3 +1,4 @@
import asyncio
import base64
from dataclasses import dataclass, field
import datetime
@@ -10,6 +11,7 @@
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Iterable,
    Iterator,
@@ -218,6 +220,7 @@ def __init__(
        self.input_tokens: Optional[int] = None
        self.output_tokens: Optional[int] = None
        self.token_details: Optional[dict] = None
        self.done_callbacks: List[Callable] = []

    def set_usage(
        self,
@@ -336,6 +339,16 @@ class Response(_BaseResponse):
model: "Model"
conversation: Optional["Conversation"] = None

def on_done(self, callback):
if not self._done:
self.done_callbacks.append(callback)
else:
callback(self)

def _on_done(self):
for callback in self.done_callbacks:
callback(self)

def __str__(self) -> str:
return self.text()

@@ -390,6 +403,7 @@ def __iter__(self) -> Iterator[str]:
            self.conversation.responses.append(self)
        self._end = time.monotonic()
        self._done = True
        self._on_done()

    def __repr__(self):
        text = "... not yet done ..."
@@ -402,6 +416,22 @@ class AsyncResponse(_BaseResponse):
model: "AsyncModel"
conversation: Optional["AsyncConversation"] = None

async def on_done(self, callback):
if not self._done:
self.done_callbacks.append(callback)
else:
if callable(callback):
callback = callback(self)
if asyncio.iscoroutine(callback):
await callback

async def _on_done(self):
for callback in self.done_callbacks:
if callable(callback):
callback = callback(self)
if asyncio.iscoroutine(callback):
await callback

def __aiter__(self):
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.utcnow()
@@ -433,6 +463,7 @@ async def __anext__(self) -> str:
            self.conversation.responses.append(self)
            self._end = time.monotonic()
            self._done = True
            await self._on_done()
            raise

    async def _force(self):
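
The `callable()` / `iscoroutine()` dance above is what lets `AsyncResponse` accept both plain functions and `async def` callbacks. A standalone sketch of that dispatch pattern, independent of `llm` itself (the `dispatch` helper here is hypothetical, for illustration only):
```python
import asyncio

async def dispatch(callback, value):
    # Mirrors AsyncResponse._on_done: call the callback, and if the call
    # produced a coroutine (i.e. the callback was an async def), await it.
    if callable(callback):
        result = callback(value)
        if asyncio.iscoroutine(result):
            await result

async def demo():
    await dispatch(lambda v: print("sync callback saw", v), 1)

    async def async_callback(v):
        print("async callback saw", v)

    await dispatch(async_callback, 2)

asyncio.run(demo())
```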
16 changes: 16 additions & 0 deletions tests/test_async.py
@@ -32,3 +32,19 @@ async def test_async_model_conversation(async_mock_model):
    response2 = await conversation.prompt("again")
    text2 = await response2.text()
    assert text2 == "joke 2"


@pytest.mark.asyncio
async def test_async_on_done(async_mock_model):
    async_mock_model.enqueue(["hello world"])
    response = await async_mock_model.prompt(prompt="hello")
    caught = []

    def done(response):
        caught.append(response)

    assert len(caught) == 0
    await response.on_done(done)
    await response.text()
    assert response._done
    assert len(caught) == 1
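
Since `AsyncResponse.on_done` also accepts coroutine functions (see `llm/models.py` above), an `async def` callback variant of this test could look like the following sketch - hypothetical, not part of this commit:
```python
@pytest.mark.asyncio
async def test_async_on_done_async_callback(async_mock_model):
    async_mock_model.enqueue(["hello world"])
    response = await async_mock_model.prompt(prompt="hello")
    caught = []

    async def done(response):
        caught.append(response)

    await response.on_done(done)
    await response.text()
    assert len(caught) == 1
```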
15 changes: 0 additions & 15 deletions tests/test_chat.py
@@ -1,25 +1,10 @@
from click.testing import CliRunner
from llm.models import Usage
import llm.cli
from unittest.mock import ANY
import pytest
import sys


def test_mock_model(mock_model):
    mock_model.enqueue(["hello world"])
    mock_model.enqueue(["second"])
    model = llm.get_model("mock")
    response = model.prompt(prompt="hello")
    assert response.text() == "hello world"
    assert str(response) == "hello world"
    assert model.history[0][0].prompt == "hello"
    assert response.usage() == Usage(input=1, output=1, details=None)
    response2 = model.prompt(prompt="hello again")
    assert response2.text() == "second"
    assert response2.usage() == Usage(input=2, output=1, details=None)


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_basic(mock_model, logs_db):
    runner = CliRunner()
30 changes: 30 additions & 0 deletions tests/test_llm.py
@@ -3,6 +3,7 @@
import llm
from llm.cli import cli
from llm.migrations import migrate
from llm.models import Usage
import json
import os
import pathlib
@@ -610,3 +611,32 @@ def test_get_async_models():
    assert all(isinstance(model, llm.AsyncModel) for model in models)
    model_ids = [model.model_id for model in models]
    assert "gpt-4o-mini" in model_ids


def test_mock_model(mock_model):
    mock_model.enqueue(["hello world"])
    mock_model.enqueue(["second"])
    model = llm.get_model("mock")
    response = model.prompt(prompt="hello")
    assert response.text() == "hello world"
    assert str(response) == "hello world"
    assert model.history[0][0].prompt == "hello"
    assert response.usage() == Usage(input=1, output=1, details=None)
    response2 = model.prompt(prompt="hello again")
    assert response2.text() == "second"
    assert response2.usage() == Usage(input=2, output=1, details=None)


def test_sync_on_done(mock_model):
    mock_model.enqueue(["hello world"])
    model = llm.get_model("mock")
    response = model.prompt(prompt="hello")
    caught = []

    def done(response):
        caught.append(response)

    response.on_done(done)
    assert len(caught) == 0
    str(response)
    assert len(caught) == 1
