Add rate limiter #121
New file (`@@ -0,0 +1,109 @@`) — the token-bucket rate limiter, imported elsewhere in the PR as `langchain_benchmarks.rate_limiting`:

```python
"""Implementation of a rate limiter based on a token bucket."""
import threading
import time
from typing import Any, Optional

from langchain.schema.runnable import Runnable, RunnableLambda
from langchain.schema.runnable.utils import Input, Output


class RateLimiter:
    def __init__(
        self,
        *,
        requests_per_second: float = 1,
        check_every_n_seconds: float = 0.1,
        max_bucket_size: float = 1,
    ) -> None:
        """A rate limiter based on a token bucket.

        These *tokens* have NOTHING to do with LLM tokens. They are just
        a way to keep track of how many requests can be made at a given time.

        This rate limiter is designed to work in a threaded environment.

        It works by filling up a bucket with tokens at a given rate. Each
        request consumes one token. If there are not enough tokens in the
        bucket, the request is blocked until enough tokens accumulate.

        Args:
            requests_per_second: The number of tokens to add per second to the bucket.
                Must be at least 1. The tokens represent "credit" that can be used
                to make requests.
            check_every_n_seconds: Check whether the tokens are available
                every this many seconds. Can be a float to represent
                fractions of a second.
            max_bucket_size: The maximum number of tokens that can be in the bucket.
                This is used to prevent bursts of requests.
        """
        # Number of requests that we can make per second.
        self.requests_per_second = requests_per_second
        # Number of tokens in the bucket.
        self.available_tokens = 0.0
        self.max_bucket_size = max_bucket_size
        # A lock to ensure that tokens can only be consumed by one thread
        # at a given time.
        self._consume_lock = threading.Lock()
        # The last time we tried to consume tokens.
        self.last: Optional[float] = None
        self.check_every_n_seconds = check_every_n_seconds

    def _consume(self) -> bool:
        """Consume a token if one is available.

        Returns:
            True means that a token was consumed, and the caller can proceed to
            make the request. False means that no token was consumed, and
            the caller should try again later.
        """
        with self._consume_lock:
            now = time.time()

            # Initialize on first call to avoid a burst.
            if self.last is None:
                self.last = now

            elapsed = now - self.last

            if elapsed * self.requests_per_second >= 1:
                self.available_tokens += elapsed * self.requests_per_second
                self.last = now

            # Make sure that we don't exceed the bucket size.
            # This is used to prevent bursts of requests.
            self.available_tokens = min(self.available_tokens, self.max_bucket_size)

            # As long as we have at least one token, we can proceed.
            if self.available_tokens >= 1:
                self.available_tokens -= 1
                return True

            return False

    def wait(self) -> None:
        """Blocking call that waits until a token is available."""
        while not self._consume():
            time.sleep(self.check_every_n_seconds)


def with_rate_limit(
    runnable: Runnable[Input, Output],
    rate_limiter: RateLimiter,
) -> Runnable[Input, Output]:
    """Add a rate limiter to the runnable.

    Args:
        runnable: The runnable to throttle.
        rate_limiter: The rate limiter to use.

    Returns:
        A runnable lambda that acts as a throttled passthrough.
    """

    def _wait(input: dict, **kwargs: Any) -> dict:
        """Wait for the rate limiter to allow the request to proceed."""
        rate_limiter.wait()
        return input

    return RunnableLambda(_wait).with_config({"name": "Wait"}) | runnable
```
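
For readers skimming the diff, here is a minimal usage sketch of the limiter on its own. It is not part of the PR; the echo runnable and the 2-requests-per-second settings are made up for illustration:

```python
from langchain.schema.runnable import RunnableLambda

from langchain_benchmarks.rate_limiting import RateLimiter, with_rate_limit

# Allow roughly 2 requests per second, polling for a free token every 100 ms.
limiter = RateLimiter(requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2)

# Any runnable can be throttled; a trivial echo stands in for a model call here.
echo = RunnableLambda(lambda x: x)
throttled_echo = with_rate_limit(echo, limiter)

for i in range(5):
    # Each invoke blocks in RateLimiter.wait() until a token is available.
    print(throttled_echo.invoke({"i": i}))
```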
Changes to the OpenAI agent factory — it now accepts an optional `rate_limiter` and, when one is provided, wraps the model with `with_rate_limit`:

```diff
@@ -9,6 +9,7 @@
 from langchain.schema.runnable import Runnable, RunnableLambda, RunnablePassthrough
 from langchain.tools.render import format_tool_to_openai_function

+from langchain_benchmarks import rate_limiting, with_rate_limit
 from langchain_benchmarks.schema import ToolUsageTask

@@ -24,33 +25,44 @@ def _ensure_output_exists(inputs: dict) -> dict:

 class OpenAIAgentFactory:
     def __init__(
-        self, task: ToolUsageTask, *, model: str = "gpt-3.5-turbo-16k"
+        self,
+        task: ToolUsageTask,
+        *,
+        model: str = "gpt-3.5-turbo-16k",
+        rate_limiter: Optional[rate_limiting.RateLimiter] = None,
     ) -> None:
         """Create an OpenAI agent factory for the given task.

         Args:
             task: The task to create an agent factory for.
             model: The model to use -- this must be an OpenAI model.
+            rate_limiter: The rate limiter to use.
         """
         self.task = task
         self.model = model
+        self.rate_limiter = rate_limiter

     def create(self) -> Runnable:
         """Agent Executor"""
         # For backwards compatibility
         return self()

     def __call__(self) -> Runnable:
-        llm = ChatOpenAI(
+        model = ChatOpenAI(
             model=self.model,
             temperature=0,
         )

         env = self.task.create_environment()

-        llm_with_tools = llm.bind(
+        model = model.bind(
             functions=[format_tool_to_openai_function(t) for t in env.tools]
         )

+        if self.rate_limiter is not None:
+            # Rate limited model
+            model = with_rate_limit(model, self.rate_limiter)

         prompt = ChatPromptTemplate.from_messages(
             [
                 (
@@ -70,7 +82,7 @@ def __call__(self) -> Runnable:
                 ),
             }
             | prompt
-            | llm_with_tools
+            | model
             | OpenAIFunctionsAgentOutputParser()
         )
```

Review thread on the `model: str = "gpt-3.5-turbo-16k",` line:

- "ooc: why do we need a default here?"
- "I can remove in a separate PR -- probably shouldn't be here"
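A sketch of how a caller might wire the new parameter through the factory. The import path for the factory and the placeholder task are assumptions, not something this diff shows:

```python
from langchain_benchmarks import rate_limiting

# Assumed import path for the factory defined in this file.
from langchain_benchmarks.tool_usage.agents import OpenAIAgentFactory

# `my_task` is a placeholder for whichever ToolUsageTask is being benchmarked.
limiter = rate_limiting.RateLimiter(requests_per_second=1, max_bucket_size=1)
factory = OpenAIAgentFactory(my_task, rate_limiter=limiter)

# Every model call made by the agent now passes through the "Wait" passthrough first.
agent_executor = factory()
```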
New test file (`@@ -0,0 +1,65 @@`) covering token consumption with frozen time:

```python
import pytest
from freezegun import freeze_time

from langchain_benchmarks.rate_limiting import RateLimiter


@pytest.mark.parametrize(
    "delta_time, requests_per_second, max_bucket_size, expected_result",
    [
        (1, 1, 1, True),
        (0.5, 1, 1, False),
        (0.5, 2, 1, True),
    ],
)
def test_consume(
    delta_time: float,
    requests_per_second: float,
    max_bucket_size: float,
    expected_result: bool,
) -> None:
    """Test the consumption of tokens over time.

    Args:
        delta_time: The time in seconds to add to the initial time.
        requests_per_second: The rate at which tokens are added per second.
        max_bucket_size: The maximum size of the token bucket.
        expected_result: The expected result of the consume operation.
    """
    rate_limiter = RateLimiter(
        requests_per_second=requests_per_second, max_bucket_size=max_bucket_size
    )

    with freeze_time(auto_tick_seconds=delta_time):
        assert rate_limiter._consume() is False
        assert rate_limiter._consume() is expected_result


def test_consume_count_tokens() -> None:
    """Test to check that the bucket size is used correctly."""
    rate_limiter = RateLimiter(
        requests_per_second=60,
        max_bucket_size=10,
    )

    with freeze_time(auto_tick_seconds=100):
        assert rate_limiter._consume() is False
        assert rate_limiter._consume() is True
        assert rate_limiter.available_tokens == 9  # Max bucket size is 10, so 10 - 1 = 9
```
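
The tests above exercise `_consume` directly. Since the docstring also advertises thread safety, here is a rough sketch of what sharing one limiter across threads looks like; the thread count and expected spacing are illustrative assumptions, not assertions from the PR:

```python
import threading
import time

from langchain_benchmarks.rate_limiting import RateLimiter

# 2 tokens per second with a bucket capped at 1, so concurrent callers get spaced out.
limiter = RateLimiter(requests_per_second=2, max_bucket_size=1)
finish_times: list = []

def worker() -> None:
    limiter.wait()  # blocks until _consume() hands this thread a token
    finish_times.append(time.monotonic())

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The four waits should complete roughly half a second apart rather than all at once.
print([round(ts - min(finish_times), 1) for ts in sorted(finish_times)])
```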