Skip to content

feat: rate limit improvements #1321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 86 additions & 40 deletions interactions/api/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,61 +95,105 @@ async def wait(self) -> None:


class BucketLock:
"""Manages the ratelimit for each bucket"""

def __init__(self) -> None:
self._lock: asyncio.Lock = asyncio.Lock()
"""Manages the rate limit for each bucket."""

DEFAULT_LIMIT = 1
DEFAULT_REMAINING = 1
DEFAULT_DELTA = 0.0

def __init__(self, header: CIMultiDictProxy | None = None) -> None:
self._semaphore: asyncio.Semaphore | None = None
if header is None:
self.bucket_hash: str | None = None
self.limit: int = self.DEFAULT_LIMIT
self.remaining: int = self.DEFAULT_REMAINING
self.delta: float = self.DEFAULT_DELTA
else:
self.ingest_ratelimit_header(header)

self.unlock_on_exit: bool = True
self.logger = constants.get_logger()

self.bucket_hash: str | None = None
self.limit: int = -1
self.remaining: int = -1
self.delta: float = 0.0
self._lock: asyncio.Lock = asyncio.Lock()

def __repr__(self) -> str:
return f"<BucketLock: {self.bucket_hash or 'Generic'}>"
return f"<BucketLock: {self.bucket_hash or 'Generic'}, limit: {self.limit}, remaining: {self.remaining}, delta: {self.delta}>"

@property
def locked(self) -> bool:
"""Return True if lock is acquired."""
return self._lock.locked()

def unlock(self) -> None:
"""Unlock this bucket."""
self._lock.release()
"""Returns whether the bucket is locked."""
if self._lock.locked():
return True
return self._semaphore is not None and self._semaphore.locked()

def ingest_ratelimit_header(self, header: CIMultiDictProxy) -> None:
"""
Ingests a discord rate limit header to configure this bucket lock.
Ingests the rate limit header.

Args:
header: A header from a http response
header: The header to ingest, containing rate limit information.

Updates the bucket_hash, limit, remaining, and delta attributes with the information from the header.
"""
self.bucket_hash = header.get("x-ratelimit-bucket")
self.limit = int(header.get("x-ratelimit-limit") or -1)
self.remaining = int(header.get("x-ratelimit-remaining") or -1)
self.delta = float(header.get("x-ratelimit-reset-after", 0.0))

async def blind_defer_unlock(self) -> None:
"""Unlocks the BucketLock but doesn't wait for completion."""
self.unlock_on_exit = False
loop = asyncio.get_running_loop()
loop.call_later(self.delta, self.unlock)

async def defer_unlock(self, reset_after: float | None = None) -> None:
"""Unlocks the BucketLock after a specified delay."""
self.unlock_on_exit = False
await asyncio.sleep(reset_after or self.delta)
self.unlock()
self.limit = int(header.get("x-ratelimit-limit", self.DEFAULT_LIMIT))
self.remaining = int(header.get("x-ratelimit-remaining", self.DEFAULT_REMAINING))
self.delta = float(header.get("x-ratelimit-reset-after", self.DEFAULT_DELTA))

if self._semaphore is None or self._semaphore._value != self.limit:
self._semaphore = asyncio.Semaphore(self.limit)

async def acquire(self) -> None:
"""Acquires the semaphore."""
if self._semaphore is None:
return

if self._lock.locked():
self.logger.debug(f"Waiting for bucket {self.bucket_hash} to unlock.")
async with self._lock:
pass

await self._semaphore.acquire()

def release(self) -> None:
"""
Releases the semaphore.

Note: If the bucket has been locked with lock_for_duration, this will not release the lock.
"""
if self._semaphore is None:
return
self._semaphore.release()

async def lock_for_duration(self, duration: float, block: bool = False) -> None:
"""
Locks the bucket for a given duration.

Args:
duration: The duration to lock the bucket for.
block: Whether to block until the bucket is unlocked.

Raises:
RuntimeError: If the bucket is already locked.
"""
if self._lock.locked():
raise RuntimeError("Attempted to lock a bucket that is already locked.")

async def _release() -> None:
await asyncio.sleep(duration)
self._lock.release()

if block:
await self._lock.acquire()
await _release()
else:
await self._lock.acquire()
asyncio.create_task(_release())

async def __aenter__(self) -> None:
await self._lock.acquire()
await self.acquire()

async def __aexit__(self, *args) -> None:
if self.unlock_on_exit and self._lock.locked():
self.unlock()
self.unlock_on_exit = True
self.release()


class HTTPClient(
Expand Down Expand Up @@ -363,7 +407,7 @@ async def request(
f"Reset in {result.get('retry_after')} seconds",
)
# lock this resource and wait for unlock
await lock.defer_unlock(float(result["retry_after"]))
await lock.lock_for_duration(float(result["retry_after"]), block=True)
else:
# endpoint ratelimit is reached
# 429's are unfortunately unavoidable, but we can attempt to avoid them
Expand All @@ -372,15 +416,17 @@ async def request(
self.logger.warning,
f"{route.endpoint} Has exceeded it's ratelimit ({lock.limit})! Reset in {lock.delta} seconds",
)
await lock.defer_unlock() # lock this route and wait for unlock
await lock.lock_for_duration(lock.delta, block=True)
continue
if lock.remaining == 0:
# Last call available in the bucket, lock until reset
self.log_ratelimit(
self.logger.debug,
f"{route.endpoint} Has exhausted its ratelimit ({lock.limit})! Locking route for {lock.delta} seconds",
)
await lock.blind_defer_unlock() # lock this route, but continue processing the current response
await lock.lock_for_duration(
lock.delta
) # lock this route, but continue processing the current response

elif response.status in {500, 502, 504}:
# Server issues, retry
Expand Down
67 changes: 46 additions & 21 deletions interactions/api/http/http_requests/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def get_channel(self, channel_id: "Snowflake_Type") -> discord_typings.Cha
channel

"""
result = await self.request(Route("GET", f"/channels/{int(channel_id)}"))
result = await self.request(Route("GET", "/channels/{channel_id}", channel_id=channel_id))
return cast(discord_typings.ChannelData, result)

@overload
Expand Down Expand Up @@ -109,7 +109,9 @@ async def get_channel_messages(
}
params = dict_filter_none(params)

result = await self.request(Route("GET", f"/channels/{int(channel_id)}/messages"), params=params)
result = await self.request(
Route("GET", "/channels/{channel_id}/messages", channel_id=channel_id), params=params
)
return cast(list[discord_typings.MessageData], result)

async def create_guild_channel(
Expand Down Expand Up @@ -168,7 +170,9 @@ async def create_guild_channel(
)
payload = dict_filter_none(payload)

result = await self.request(Route("POST", f"/guilds/{int(guild_id)}/channels"), payload=payload, reason=reason)
result = await self.request(
Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), payload=payload, reason=reason
)
return cast(discord_typings.ChannelData, result)

async def move_channel(
Expand Down Expand Up @@ -200,7 +204,9 @@ async def move_channel(
}
payload = dict_filter_none(payload)

await self.request(Route("PATCH", f"/guilds/{int(guild_id)}/channels"), payload=payload, reason=reason)
await self.request(
Route("PATCH", "/guilds/{guild_id}/channels", guild_id=guild_id), payload=payload, reason=reason
)

async def modify_channel(
self, channel_id: "Snowflake_Type", data: dict, reason: str | None = None
Expand All @@ -217,7 +223,9 @@ async def modify_channel(
Channel object on success

"""
result = await self.request(Route("PATCH", f"/channels/{int(channel_id)}"), payload=data, reason=reason)
result = await self.request(
Route("PATCH", "/channels/{channel_id}", channel_id=channel_id), payload=data, reason=reason
)
return cast(discord_typings.ChannelData, result)

async def delete_channel(self, channel_id: "Snowflake_Type", reason: str | None = None) -> None:
Expand All @@ -229,7 +237,7 @@ async def delete_channel(self, channel_id: "Snowflake_Type", reason: str | None
reason: An optional reason for the audit log

"""
await self.request(Route("DELETE", f"/channels/{int(channel_id)}"), reason=reason)
await self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id), reason=reason)

async def get_channel_invites(self, channel_id: "Snowflake_Type") -> list[discord_typings.InviteData]:
"""
Expand All @@ -242,7 +250,7 @@ async def get_channel_invites(self, channel_id: "Snowflake_Type") -> list[discor
List of invite objects

"""
result = await self.request(Route("GET", f"/channels/{int(channel_id)}/invites"))
result = await self.request(Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id))
return cast(list[discord_typings.InviteData], result)

@overload
Expand Down Expand Up @@ -336,7 +344,7 @@ async def create_channel_invite(
payload = dict_filter_none(payload)

result = await self.request(
Route("POST", f"/channels/{int(channel_id)}/invites"), payload=payload, reason=reason
Route("POST", "/channels/{channel_id}/invites", channel_id=channel_id), payload=payload, reason=reason
)
return cast(discord_typings.InviteData, result)

Expand All @@ -361,13 +369,14 @@ async def get_invite(

"""
params: PAYLOAD_TYPE = {
"invite_code": invite_code,
"with_counts": with_counts,
"with_expiration": with_expiration,
"guild_scheduled_event_id": int(scheduled_event_id) if scheduled_event_id else None,
}
params = dict_filter_none(params)

result = await self.request(Route("GET", f"/invites/{invite_code}", params=params))
result = await self.request(Route("GET", "/invites/{invite_code}", params=params))
return cast(discord_typings.InviteData, result)

async def delete_invite(self, invite_code: str, reason: str | None = None) -> discord_typings.InviteData:
Expand All @@ -382,7 +391,7 @@ async def delete_invite(self, invite_code: str, reason: str | None = None) -> di
The deleted invite object

"""
result = await self.request(Route("DELETE", f"/invites/{invite_code}"), reason=reason)
result = await self.request(Route("DELETE", "/invites/{invite_code}", invite_code=invite_code), reason=reason)
return cast(discord_typings.InviteData, result)

async def edit_channel_permission(
Expand All @@ -409,7 +418,12 @@ async def edit_channel_permission(
payload: PAYLOAD_TYPE = {"allow": allow, "deny": deny, "type": perm_type}

await self.request(
Route("PUT", f"/channels/{int(channel_id)}/permissions/{int(overwrite_id)}"),
Route(
"PUT",
"/channels/{channel_id}/permissions/{overwrite_id}",
channel_id=channel_id,
overwrite_id=overwrite_id,
),
payload=payload,
reason=reason,
)
Expand All @@ -429,7 +443,10 @@ async def delete_channel_permission(
reason: An optional reason for the audit log

"""
await self.request(Route("DELETE", f"/channels/{int(channel_id)}/{int(overwrite_id)}"), reason=reason)
await self.request(
Route("DELETE", "/channels/{channel_id}/{overwrite_id}", channel_id=channel_id, overwrite_id=overwrite_id),
reason=reason,
)

async def follow_news_channel(
self, channel_id: "Snowflake_Type", webhook_channel_id: "Snowflake_Type"
Expand All @@ -447,7 +464,9 @@ async def follow_news_channel(
"""
payload = {"webhook_channel_id": int(webhook_channel_id)}

result = await self.request(Route("POST", f"/channels/{int(channel_id)}/followers"), payload=payload)
result = await self.request(
Route("POST", "/channels/{channel_id}/followers", channel_id=channel_id), payload=payload
)
return cast(discord_typings.FollowedChannelData, result)

async def trigger_typing_indicator(self, channel_id: "Snowflake_Type") -> None:
Expand All @@ -458,7 +477,7 @@ async def trigger_typing_indicator(self, channel_id: "Snowflake_Type") -> None:
channel_id: The id of the channel to "type" in

"""
await self.request(Route("POST", f"/channels/{int(channel_id)}/typing"))
await self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id))

async def get_pinned_messages(self, channel_id: "Snowflake_Type") -> list[discord_typings.MessageData]:
"""
Expand All @@ -471,7 +490,7 @@ async def get_pinned_messages(self, channel_id: "Snowflake_Type") -> list[discor
A list of pinned message objects

"""
result = await self.request(Route("GET", f"/channels/{int(channel_id)}/pins"))
result = await self.request(Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id))
return cast(list[discord_typings.MessageData], result)

async def create_stage_instance(
Expand Down Expand Up @@ -514,7 +533,7 @@ async def get_stage_instance(self, channel_id: "Snowflake_Type") -> discord_typi
A stage instance.

"""
result = await self.request(Route("GET", f"/stage-instances/{int(channel_id)}"))
result = await self.request(Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id))
return cast(discord_typings.StageInstanceData, result)

async def modify_stage_instance(
Expand All @@ -540,7 +559,7 @@ async def modify_stage_instance(
payload: PAYLOAD_TYPE = {"topic": topic, "privacy_level": privacy_level}
payload = dict_filter_none(payload)
result = await self.request(
Route("PATCH", f"/stage-instances/{int(channel_id)}"), payload=payload, reason=reason
Route("PATCH", "/stage-instances/{channel_id}", channel_id=channel_id), payload=payload, reason=reason
)
return cast(discord_typings.StageInstanceData, result)

Expand All @@ -553,7 +572,7 @@ async def delete_stage_instance(self, channel_id: "Snowflake_Type", reason: str
reason: The reason for the deletion

"""
await self.request(Route("DELETE", f"/stage-instances/{int(channel_id)}"), reason=reason)
await self.request(Route("DELETE", "/stage-instances/{channel_id}", channel_id=channel_id), reason=reason)

async def create_tag(
self,
Expand Down Expand Up @@ -582,7 +601,9 @@ async def create_tag(
}
payload = dict_filter_none(payload)

result = await self.request(Route("POST", f"/channels/{int(channel_id)}/tags"), payload=payload)
result = await self.request(
Route("POST", "/channels/{channel_id}/tags", channel_id=channel_id), payload=payload
)
return cast(discord_typings.ChannelData, result)

async def edit_tag(
Expand Down Expand Up @@ -614,7 +635,9 @@ async def edit_tag(
}
payload = dict_filter_none(payload)

result = await self.request(Route("PUT", f"/channels/{int(channel_id)}/tags/{int(tag_id)}"), payload=payload)
result = await self.request(
Route("PUT", "/channels/{channel_id}/tags/{tag_id}", channel_id=channel_id, tag_id=tag_id), payload=payload
)
return cast(discord_typings.ChannelData, result)

async def delete_tag(self, channel_id: "Snowflake_Type", tag_id: "Snowflake_Type") -> discord_typings.ChannelData:
Expand All @@ -625,5 +648,7 @@ async def delete_tag(self, channel_id: "Snowflake_Type", tag_id: "Snowflake_Type
channel_id: The ID of the forum channel to delete tag it.
tag_id: The ID of the tag to delete
"""
result = await self.request(Route("DELETE", f"/channels/{int(channel_id)}/tags/{int(tag_id)}"))
result = await self.request(
Route("DELETE", "/channels/{channel_id}/tags/{tag_id}", channel_id=channel_id, tag_id=tag_id)
)
return cast(discord_typings.ChannelData, result)
Loading