3 changes: 3 additions & 0 deletions langchain_benchmarks/__init__.py
@@ -1,4 +1,5 @@
from langchain_benchmarks.model_registration import model_registry
from langchain_benchmarks.rate_limiting import RateLimiter, with_rate_limit
from langchain_benchmarks.registration import registry
from langchain_benchmarks.utils._langsmith import (
clone_public_dataset,
@@ -10,5 +11,7 @@
"clone_public_dataset",
"download_public_dataset",
"model_registry",
"RateLimiter",
"registry",
"with_rate_limit",
]
109 changes: 109 additions & 0 deletions langchain_benchmarks/rate_limiting.py
@@ -0,0 +1,109 @@
"""Implementation of a rate limiter based on a token bucket."""
import threading
import time
from typing import Any, Optional

from langchain.schema.runnable import Runnable, RunnableLambda
from langchain.schema.runnable.utils import Input, Output


class RateLimiter:
def __init__(
self,
*,
requests_per_second: float = 1,
check_every_n_seconds: float = 0.1,
max_bucket_size: float = 1,
) -> None:
"""A rate limiter based on a token bucket.

These *tokens* have NOTHING to do with LLM tokens. They are just
a way to keep track of how many requests can be made at a given time.

This rate limiter is designed to work in a threaded environment.

It works by filling up a bucket with tokens at a given rate. Each
request consumes a given number of tokens. If there are not enough
tokens in the bucket, the request is blocked until there are enough
tokens.

Args:
requests_per_second: The number of tokens to add per second to the bucket.
Must be at least 1. The tokens represent "credit" that can be used
to make requests.
check_every_n_seconds: check whether the tokens are available
every this many seconds. Can be a float to represent
fractions of a second.
max_bucket_size: The maximum number of tokens that can be in the bucket.
This is used to prevent bursts of requests.
"""
# Number of requests that we can make per second.
self.requests_per_second = requests_per_second
# Number of tokens in the bucket.
self.available_tokens = 0.0
self.max_bucket_size = max_bucket_size
# A lock to ensure that tokens can only be consumed by one thread
# at a given time.
self._consume_lock = threading.Lock()
# The last time we tried to consume tokens.
self.last: Optional[float] = None
self.check_every_n_seconds = check_every_n_seconds

def _consume(self) -> bool:
"""Consume the given amount of tokens if possible.

Returns:
True means that the tokens were consumed, and the caller can proceed to
make the request. A False means that the tokens were not consumed, and
the caller should try again later.
"""
with self._consume_lock:
now = time.time()

# initialize on first call to avoid a burst
if self.last is None:
self.last = now

elapsed = now - self.last

if elapsed * self.requests_per_second >= 1:
self.available_tokens += elapsed * self.requests_per_second
self.last = now

# Make sure that we don't exceed the bucket size.
# This is used to prevent bursts of requests.
self.available_tokens = min(self.available_tokens, self.max_bucket_size)

# As long as we have at least one token, we can proceed.
if self.available_tokens >= 1:
self.available_tokens -= 1
return True

return False

def wait(self) -> None:
"""Blocking call to wait until the given number of tokens are available."""
while not self._consume():
time.sleep(self.check_every_n_seconds)


def with_rate_limit(
runnable: Runnable[Input, Output],
rate_limiter: RateLimiter,
) -> Runnable[Input, Output]:
"""Add a rate limiter to the runnable.

Args:
runnable: The runnable to throttle.
rate_limiter: The throttle to use.

Returns:
A runnable lambda that acts as a throttled passthrough.
"""

def _wait(input: dict, **kwargs: Any) -> dict:
"""Wait for the rate limiter to allow the request to proceed."""
rate_limiter.wait()
return input

return RunnableLambda(_wait).with_config({"name": "Wait"}) | runnable
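Taken together, `RateLimiter` and `with_rate_limit` can throttle any runnable. A minimal usage sketch, not part of this diff -- the two-requests-per-second setting and the echo runnable below are illustrative assumptions:

from langchain.schema.runnable import RunnableLambda

from langchain_benchmarks import RateLimiter, with_rate_limit

# Allow roughly two requests per second, with no burst capacity.
rate_limiter = RateLimiter(requests_per_second=2, max_bucket_size=1)

# Any runnable can be wrapped; a trivial echo runnable stands in for an LLM here.
echo = RunnableLambda(lambda x: x)
throttled_echo = with_rate_limit(echo, rate_limiter)

# Each invocation first blocks until a token is available, then passes the
# input through unchanged to the wrapped runnable.
for i in range(5):
    print(throttled_echo.invoke({"question": f"call {i}"}))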
20 changes: 16 additions & 4 deletions langchain_benchmarks/tool_usage/agents.py
@@ -9,6 +9,7 @@
from langchain.schema.runnable import Runnable, RunnableLambda, RunnablePassthrough
from langchain.tools.render import format_tool_to_openai_function

from langchain_benchmarks import rate_limiting, with_rate_limit
from langchain_benchmarks.schema import ToolUsageTask


@@ -24,33 +25,44 @@ def _ensure_output_exists(inputs: dict) -> dict:

class OpenAIAgentFactory:
def __init__(
self, task: ToolUsageTask, *, model: str = "gpt-3.5-turbo-16k"
self,
task: ToolUsageTask,
*,
model: str = "gpt-3.5-turbo-16k",
[Review comment] @hinthornw (Collaborator), Dec 12, 2023: ooc: why do we need a default here?

[Author reply]: I can remove in a separate PR -- probably shouldn't be here
rate_limiter: Optional[rate_limiting.RateLimiter] = None,
) -> None:
"""Create an OpenAI agent factory for the given task.

Args:
task: The task to create an agent factory for.
model: The model to use -- this must be an OpenAI model.
rate_limiter: The rate limiter to use.
"""
self.task = task
self.model = model
self.rate_limiter = rate_limiter

def create(self) -> Runnable:
"""Agent Executor"""
# For backwards compatibility
return self()

def __call__(self) -> Runnable:
llm = ChatOpenAI(
model = ChatOpenAI(
model=self.model,
temperature=0,
)

env = self.task.create_environment()

llm_with_tools = llm.bind(
model = model.bind(
functions=[format_tool_to_openai_function(t) for t in env.tools]
)

if self.rate_limiter is not None:
# Rate limited model
model = with_rate_limit(model, self.rate_limiter)

prompt = ChatPromptTemplate.from_messages(
[
(
@@ -70,7 +82,7 @@ def __call__(self) -> Runnable:
),
}
| prompt
| llm_with_tools
| model
| OpenAIFunctionsAgentOutputParser()
)

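With this change an agent factory can be built with a shared rate limiter so that every model call made by the agent is throttled. A hedged sketch of how it might be wired up -- the task name below is an assumption for illustration; any tool-usage task from the registry works:

from langchain_benchmarks import RateLimiter, registry
from langchain_benchmarks.tool_usage.agents import OpenAIAgentFactory

# Hypothetical task name -- substitute any ToolUsageTask from the registry.
task = registry["Tool Usage - Typewriter (1 tool)"]

# Throttle the underlying ChatOpenAI model to about one request per second.
factory = OpenAIAgentFactory(task, rate_limiter=RateLimiter(requests_per_second=1))
agent = factory()  # the returned runnable now waits for a token before each model call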
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ pytest-mock = "^3.11.1"
pytest-socket = "^0.6.0"
pytest-watch = "^4.2.0"
pytest-timeout = "^2.2.0"
freezegun = "^1.3.1"


[tool.ruff]
7 changes: 5 additions & 2 deletions tests/unit_tests/test_public_api.py
@@ -9,7 +9,10 @@ def test_public_api() -> None:
[
"clone_public_dataset",
"download_public_dataset",
"registry",
"model_registry",
]
"RateLimiter",
"registry",
"with_rate_limit",
],
key=lambda x: x.lower(),
)
65 changes: 65 additions & 0 deletions tests/unit_tests/test_rate_limiting.py
@@ -0,0 +1,65 @@
import pytest
from freezegun import freeze_time

from langchain_benchmarks.rate_limiting import RateLimiter


@pytest.mark.parametrize(
"delta_time, requests_per_second, max_bucket_size, expected_result",
[
(
1,
1,
1,
True,
),
(
0.5,
1,
1,
False,
),
(
0.5,
2,
1,
True,
),
],
)
def test_consume(
delta_time: float,
requests_per_second: float,
max_bucket_size: float,
expected_result: bool,
) -> None:
"""Test the consumption of tokens over time.

Args:
delta_time: The time in seconds to add to the initial time.
requests_per_second: The rate at which tokens are added per second.
max_bucket_size: The maximum size of the token bucket.
expected_result: The expected result of the consume operation.
"""
rate_limiter = RateLimiter(
requests_per_second=requests_per_second, max_bucket_size=max_bucket_size
)

with freeze_time(auto_tick_seconds=delta_time):
assert rate_limiter._consume() is False
assert rate_limiter._consume() is expected_result


def test_consume_count_tokens() -> None:
"""Test to check that the bucket size is used correctly."""
rate_limiter = RateLimiter(
requests_per_second=60,
max_bucket_size=10,
)

with freeze_time(auto_tick_seconds=100):
assert rate_limiter._consume() is False
assert rate_limiter._consume() is True
assert (
rate_limiter.available_tokens == 9
) # Max bucket size is 10, so 10 - 1 = 9
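A note on the freezegun mechanics used above: `auto_tick_seconds` advances the frozen clock each time it is read, so the first `_consume()` call in every test only initializes `last` and returns False, while the second call observes the elapsed interval. A small illustrative sketch of that behavior:

import time

from freezegun import freeze_time

with freeze_time(auto_tick_seconds=100):
    t0 = time.time()
    t1 = time.time()  # each read of the clock moves it forward by 100 seconds
    assert t1 - t0 == 100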