Skip to content

feat: caching improvements #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1995,7 +1995,7 @@ def reload_extension(
sys.modules.pop(name, None)
raise ex from e

async def fetch_guild(self, guild_id: "Snowflake_Type") -> Optional[Guild]:
async def fetch_guild(self, guild_id: "Snowflake_Type", *, force: bool = False) -> Optional[Guild]:
"""
Fetch a guild.

Expand All @@ -2005,13 +2005,14 @@ async def fetch_guild(self, guild_id: "Snowflake_Type") -> Optional[Guild]:

Args:
guild_id: The ID of the guild to get
force: Whether to poll the API regardless of cache

Returns:
Guild Object if found, otherwise None

"""
try:
return await self.cache.fetch_guild(guild_id)
return await self.cache.fetch_guild(guild_id, force=force)
except NotFound:
return None

Expand Down Expand Up @@ -2060,7 +2061,7 @@ async def create_guild_from_template(
guild_data = await self.http.create_guild_from_guild_template(template_code, name, icon)
return Guild.from_dict(guild_data, self)

async def fetch_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_ALL_CHANNEL"]:
async def fetch_channel(self, channel_id: "Snowflake_Type", *, force: bool = False) -> Optional["TYPE_ALL_CHANNEL"]:
"""
Fetch a channel.

Expand All @@ -2070,13 +2071,14 @@ async def fetch_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_AL

Args:
channel_id: The ID of the channel to get
force: Whether to poll the API regardless of cache

Returns:
Channel Object if found, otherwise None

"""
try:
return await self.cache.fetch_channel(channel_id)
return await self.cache.fetch_channel(channel_id, force=force)
except NotFound:
return None

Expand All @@ -2096,7 +2098,7 @@ def get_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_ALL_CHANNE
"""
return self.cache.get_channel(channel_id)

async def fetch_user(self, user_id: "Snowflake_Type") -> Optional[User]:
async def fetch_user(self, user_id: "Snowflake_Type", *, force: bool = False) -> Optional[User]:
"""
Fetch a user.

Expand All @@ -2106,13 +2108,14 @@ async def fetch_user(self, user_id: "Snowflake_Type") -> Optional[User]:

Args:
user_id: The ID of the user to get
force: Whether to poll the API regardless of cache

Returns:
User Object if found, otherwise None

"""
try:
return await self.cache.fetch_user(user_id)
return await self.cache.fetch_user(user_id, force=force)
except NotFound:
return None

Expand All @@ -2132,7 +2135,9 @@ def get_user(self, user_id: "Snowflake_Type") -> Optional[User]:
"""
return self.cache.get_user(user_id)

async def fetch_member(self, user_id: "Snowflake_Type", guild_id: "Snowflake_Type") -> Optional[Member]:
async def fetch_member(
self, user_id: "Snowflake_Type", guild_id: "Snowflake_Type", *, force: bool = False
) -> Optional[Member]:
"""
Fetch a member from a guild.

Expand All @@ -2143,13 +2148,14 @@ async def fetch_member(self, user_id: "Snowflake_Type", guild_id: "Snowflake_Typ
Args:
user_id: The ID of the member
guild_id: The ID of the guild to get the member from
force: Whether to poll the API regardless of cache

Returns:
Member object if found, otherwise None

"""
try:
return await self.cache.fetch_member(guild_id, user_id)
return await self.cache.fetch_member(guild_id, user_id, force=force)
except NotFound:
return None

Expand Down Expand Up @@ -2194,20 +2200,23 @@ async def fetch_scheduled_event(
except NotFound:
return None

async def fetch_custom_emoji(self, emoji_id: "Snowflake_Type", guild_id: "Snowflake_Type") -> Optional[CustomEmoji]:
async def fetch_custom_emoji(
self, emoji_id: "Snowflake_Type", guild_id: "Snowflake_Type", *, force: bool = False
) -> Optional[CustomEmoji]:
"""
Fetch a custom emoji by id.

Args:
emoji_id: The id of the custom emoji.
guild_id: The id of the guild the emoji belongs to.
force: Whether to poll the API regardless of cache.

Returns:
The custom emoji if found, otherwise None.

"""
try:
return await self.cache.fetch_emoji(guild_id, emoji_id)
return await self.cache.fetch_emoji(guild_id, emoji_id, force=force)
except NotFound:
return None

Expand Down
63 changes: 34 additions & 29 deletions interactions/client/smart_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,13 @@ def __attrs_post_init__(self) -> None:

# region User cache

async def fetch_user(self, user_id: "Snowflake_Type") -> User:
async def fetch_user(self, user_id: "Snowflake_Type", *, force: bool = False) -> User:
"""
Fetch a user by their ID.

Args:
user_id: The user's ID
force: If the cache should be ignored, and the user should be fetched from the API

Returns:
User object if found
Expand All @@ -113,9 +114,10 @@ async def fetch_user(self, user_id: "Snowflake_Type") -> User:
user_id = to_snowflake(user_id)

user = self.user_cache.get(user_id)
if user is None:
if (user is None or user._fetched is False) or force:
data = await self._client.http.get_user(user_id)
user = self.place_user_data(data)
user._fetched = True # the user object should set this to True, but we do it here just in case
return user

def get_user(self, user_id: Optional["Snowflake_Type"]) -> Optional[User]:
Expand Down Expand Up @@ -164,13 +166,16 @@ def delete_user(self, user_id: "Snowflake_Type") -> None:

# region Member cache

async def fetch_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type") -> Member:
async def fetch_member(
self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type", *, force: bool = False
) -> Member:
"""
Fetch a member by their guild and user IDs.

Args:
guild_id: The ID of the guild this user belongs to
user_id: The ID of the user
force: If the cache should be ignored, and the member should be fetched from the API

Returns:
Member object if found
Expand All @@ -179,7 +184,7 @@ async def fetch_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Typ
guild_id = to_snowflake(guild_id)
user_id = to_snowflake(user_id)
member = self.member_cache.get((guild_id, user_id))
if member is None:
if member is None or force:
data = await self._client.http.get_member(guild_id, user_id)
member = self.place_member_data(guild_id, data)
return member
Expand Down Expand Up @@ -323,15 +328,13 @@ async def is_user_in_guild(

return False

async def fetch_user_guild_ids(
self,
user_id: "Snowflake_Type",
) -> List["Snowflake_Type"]:
async def fetch_user_guild_ids(self, user_id: "Snowflake_Type") -> List["Snowflake_Type"]:
"""
Fetch a list of IDs for the guilds a user has joined.

Args:
user_id: The ID of the user

Returns:
A list of snowflakes for the guilds the client can see the user is within
"""
Expand Down Expand Up @@ -361,16 +364,15 @@ def get_user_guild_ids(self, user_id: "Snowflake_Type") -> List["Snowflake_Type"
# region Message cache

async def fetch_message(
self,
channel_id: "Snowflake_Type",
message_id: "Snowflake_Type",
self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type", *, force: bool = False
) -> Message:
"""
Fetch a message from a channel based on their IDs.

Args:
channel_id: The ID of the channel the message is in
message_id: The ID of the message
force: If the cache should be ignored, and the message should be fetched from the API

Returns:
The message if found
Expand All @@ -379,7 +381,7 @@ async def fetch_message(
message_id = to_snowflake(message_id)
message = self.message_cache.get((channel_id, message_id))

if message is None:
if message is None or force:
data = await self._client.http.get_message(channel_id, message_id)
message = self.place_message_data(data)
if message.channel is None:
Expand Down Expand Up @@ -437,22 +439,20 @@ def delete_message(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Ty
# endregion Message cache

# region Channel cache
async def fetch_channel(
self,
channel_id: "Snowflake_Type",
) -> "TYPE_ALL_CHANNEL":
async def fetch_channel(self, channel_id: "Snowflake_Type", *, force: bool = False) -> "TYPE_ALL_CHANNEL":
"""
Get a channel based on its ID.

Args:
channel_id: The ID of the channel
force: If the cache should be ignored, and the channel should be fetched from the API

Returns:
The channel if found
"""
channel_id = to_snowflake(channel_id)
channel = self.channel_cache.get(channel_id)
if channel is None:
if channel is None or force:
try:
data = await self._client.http.get_channel(channel_id)
channel = self.place_channel_data(data)
Expand Down Expand Up @@ -518,31 +518,33 @@ def place_dm_channel_id(self, user_id: "Snowflake_Type", channel_id: "Snowflake_
"""
self.dm_channels[to_snowflake(user_id)] = to_snowflake(channel_id)

async def fetch_dm_channel_id(self, user_id: "Snowflake_Type") -> "Snowflake_Type":
async def fetch_dm_channel_id(self, user_id: "Snowflake_Type", *, force: bool = False) -> "Snowflake_Type":
"""
Get the DM channel ID for a user.

Args:
user_id: The ID of the user
force: If the cache should be ignored, and the channel should be fetched from the API
"""
user_id = to_snowflake(user_id)
channel_id = self.dm_channels.get(user_id)
if channel_id is None:
if channel_id is None or force:
data = await self._client.http.create_dm(user_id)
channel = self.place_channel_data(data)
channel_id = channel.id
return channel_id

async def fetch_dm_channel(self, user_id: "Snowflake_Type") -> "DM":
async def fetch_dm_channel(self, user_id: "Snowflake_Type", *, force: bool = False) -> "DM":
"""
Fetch the DM channel for a user.

Args:
user_id: The ID of the user
force: If the cache should be ignored, and the channel should be fetched from the API
"""
user_id = to_snowflake(user_id)
channel_id = await self.fetch_dm_channel_id(user_id)
return await self.fetch_channel(channel_id)
channel_id = await self.fetch_dm_channel_id(user_id, force=force)
return await self.fetch_channel(channel_id, force=force)

def get_dm_channel(self, user_id: Optional["Snowflake_Type"]) -> Optional["DM"]:
"""
Expand Down Expand Up @@ -575,19 +577,20 @@ def delete_channel(self, channel_id: "Snowflake_Type") -> None:

# region Guild cache

async def fetch_guild(self, guild_id: "Snowflake_Type") -> Guild:
async def fetch_guild(self, guild_id: "Snowflake_Type", *, force: bool = False) -> Guild:
"""
Fetch a guild based on its ID.

Args:
guild_id: The ID of the guild
force: If the cache should be ignored, and the guild should be fetched from the API

Returns:
The guild if found
"""
guild_id = to_snowflake(guild_id)
guild = self.guild_cache.get(guild_id)
if guild is None:
if guild is None or force:
data = await self._client.http.get_guild(guild_id)
guild = self.place_guild_data(data)
return guild
Expand Down Expand Up @@ -648,21 +651,24 @@ async def fetch_role(
self,
guild_id: "Snowflake_Type",
role_id: "Snowflake_Type",
*,
force: bool = False,
) -> Role:
"""
Fetch a role based on the guild and its own ID.

Args:
guild_id: The ID of the guild this role belongs to
role_id: The ID of the role
force: If the cache should be ignored, and the role should be fetched from the API

Returns:
The role if found
"""
guild_id = to_snowflake(guild_id)
role_id = to_snowflake(role_id)
role = self.role_cache.get(role_id)
if role is None:
if role is None or force:
data = await self._client.http.get_roles(guild_id)
role = self.place_role_data(guild_id, data).get(role_id)
return role
Expand Down Expand Up @@ -830,9 +836,7 @@ def delete_bot_voice_state(self, guild_id: "Snowflake_Type") -> None:
# region Emoji cache

async def fetch_emoji(
self,
guild_id: "Snowflake_Type",
emoji_id: "Snowflake_Type",
self, guild_id: "Snowflake_Type", emoji_id: "Snowflake_Type", *, force: bool = False
) -> "CustomEmoji":
"""
Fetch an emoji based on the guild and its own ID.
Expand All @@ -842,14 +846,15 @@ async def fetch_emoji(
Args:
guild_id: The ID of the guild this emoji belongs to
emoji_id: The ID of the emoji
force: If the cache should be ignored, and the emoji should be fetched from the API

Returns:
The Emoji if found
"""
guild_id = to_snowflake(guild_id)
emoji_id = to_snowflake(emoji_id)
emoji = self.emoji_cache.get(emoji_id) if self.emoji_cache is not None else None
if emoji is None:
if emoji is None or force:
data = await self._client.http.get_guild_emoji(guild_id, emoji_id)
emoji = self.place_emoji_data(guild_id, data)

Expand Down
Loading