|
| 1 | +import datetime |
| 2 | +from collections.abc import Callable, Container, Iterable |
| 3 | + |
| 4 | +from discord.ext.commands import ( |
| 5 | + BucketType, |
| 6 | + CheckFailure, |
| 7 | + Cog, |
| 8 | + Command, |
| 9 | + CommandOnCooldown, |
| 10 | + Context, |
| 11 | + Cooldown, |
| 12 | + CooldownMapping, |
| 13 | +) |
| 14 | + |
| 15 | +from pydis_core.utils.logging import get_logger |
| 16 | + |
| 17 | +log = get_logger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class ContextCheckFailure(CheckFailure): |
| 21 | + """Raised when a context-specific check fails.""" |
| 22 | + |
| 23 | + def __init__(self, redirect_channel: int | None) -> None: |
| 24 | + self.redirect_channel = redirect_channel |
| 25 | + |
| 26 | + if redirect_channel: |
| 27 | + redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" |
| 28 | + else: |
| 29 | + redirect_message = "" |
| 30 | + |
| 31 | + error_message = f"You are not allowed to use that command{redirect_message}." |
| 32 | + |
| 33 | + super().__init__(error_message) |
| 34 | + |
| 35 | + |
| 36 | +class InWhitelistCheckFailure(ContextCheckFailure): |
| 37 | + """Raised when the `in_whitelist` check fails.""" |
| 38 | + |
| 39 | + |
| 40 | +def in_whitelist_check( |
| 41 | + ctx: Context, |
| 42 | + redirect: int, |
| 43 | + channels: Container[int] = (), |
| 44 | + categories: Container[int] = (), |
| 45 | + roles: Container[int] = (), |
| 46 | + fail_silently: bool = False, |
| 47 | +) -> bool: |
| 48 | + """ |
| 49 | + Check if a command was issued in a whitelisted context. |
| 50 | +
|
| 51 | + The whitelists that can be provided are: |
| 52 | +
|
| 53 | + - `channels`: a container with channel ids for whitelisted channels |
| 54 | + - `categories`: a container with category ids for whitelisted categories |
| 55 | + - `roles`: a container with with role ids for whitelisted roles |
| 56 | +
|
| 57 | + If the command was invoked in a context that was not whitelisted, the member is either |
| 58 | + redirected to the `redirect` channel that was passed (default: #bot-commands) or simply |
| 59 | + told that they're not allowed to use this particular command (if `None` was passed). |
| 60 | + """ |
| 61 | + if redirect not in channels: |
| 62 | + # It does not make sense for the channel whitelist to not contain the redirection |
| 63 | + # channel (if applicable). That's why we add the redirection channel to the `channels` |
| 64 | + # container if it's not already in it. As we allow any container type to be passed, |
| 65 | + # we first create a tuple in order to safely add the redirection channel. |
| 66 | + # |
| 67 | + # Note: It's possible for the redirect channel to be in a whitelisted category, but |
| 68 | + # there's no easy way to check that and as a channel can easily be moved in and out of |
| 69 | + # categories, it's probably not wise to rely on its category in any case. |
| 70 | + channels = tuple(channels) + (redirect,) |
| 71 | + |
| 72 | + if channels and ctx.channel.id in channels: |
| 73 | + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") |
| 74 | + return True |
| 75 | + |
| 76 | + # Only check the category id if we have a category whitelist and the channel has a `category_id` |
| 77 | + if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: |
| 78 | + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") |
| 79 | + return True |
| 80 | + |
| 81 | + # Only check the roles whitelist if we have one and ensure the author's roles attribute returns |
| 82 | + # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). |
| 83 | + if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): |
| 84 | + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") |
| 85 | + return True |
| 86 | + |
| 87 | + log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") |
| 88 | + |
| 89 | + # Some commands are secret, and should produce no feedback at all. |
| 90 | + if not fail_silently: |
| 91 | + raise InWhitelistCheckFailure(redirect) |
| 92 | + return False |
| 93 | + |
| 94 | + |
| 95 | +def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, |
| 96 | + bypass_roles: Iterable[int]) -> Callable: |
| 97 | + """ |
| 98 | + Applies a cooldown to a command, but allows members with certain roles to be ignored. |
| 99 | +
|
| 100 | + NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future. |
| 101 | + """ |
| 102 | + # Make it a set so lookup is hash based. |
| 103 | + bypass = set(bypass_roles) |
| 104 | + |
| 105 | + # This handles the actual cooldown logic. |
| 106 | + buckets = CooldownMapping(Cooldown(rate, per, type)) |
| 107 | + |
| 108 | + # Will be called after the command has been parse but before it has been invoked, ensures that |
| 109 | + # the cooldown won't be updated if the user screws up their input to the command. |
| 110 | + async def predicate(cog: Cog, ctx: Context) -> None: |
| 111 | + nonlocal bypass, buckets |
| 112 | + |
| 113 | + if any(role.id in bypass for role in ctx.author.roles): |
| 114 | + return |
| 115 | + |
| 116 | + # Cooldown logic, taken from discord.py internals. |
| 117 | + current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp() |
| 118 | + bucket = buckets.get_bucket(ctx.message) |
| 119 | + retry_after = bucket.update_rate_limit(current) |
| 120 | + if retry_after: |
| 121 | + raise CommandOnCooldown(bucket, retry_after) |
| 122 | + |
| 123 | + def wrapper(command: Command) -> Command: |
| 124 | + # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it |
| 125 | + # so I just made it raise an error when the decorator is applied before the actual command object exists. |
| 126 | + # |
| 127 | + # If the `before_invoke` detail is ever a problem then I can quickly just swap over. |
| 128 | + if not isinstance(command, Command): |
| 129 | + raise TypeError( |
| 130 | + "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. " |
| 131 | + "This means it has to be above the command decorator in the code." |
| 132 | + ) |
| 133 | + |
| 134 | + command._before_invoke = predicate |
| 135 | + |
| 136 | + return command |
| 137 | + |
| 138 | + return wrapper |
| 139 | + |
| 140 | + |
| 141 | +async def has_any_role_check(ctx: Context, *roles: str | int) -> bool: |
| 142 | + """ |
| 143 | + Returns True if the context's author has any of the specified roles. |
| 144 | +
|
| 145 | + `roles` are the names or IDs of the roles for which to check. |
| 146 | + False is always returns if the context is outside a guild. |
| 147 | + """ |
| 148 | + if not ctx.guild: # Return False in a DM |
| 149 | + log.trace( |
| 150 | + f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " |
| 151 | + "This command is restricted by the with_role decorator. Rejecting request." |
| 152 | + ) |
| 153 | + return False |
| 154 | + |
| 155 | + for role in ctx.author.roles: |
| 156 | + if role.id in roles: |
| 157 | + log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") |
| 158 | + return True |
| 159 | + |
| 160 | + log.trace( |
| 161 | + f"{ctx.author} does not have the required role to use " |
| 162 | + f"the '{ctx.command.name}' command, so the request is rejected." |
| 163 | + ) |
| 164 | + return False |
| 165 | + |
| 166 | + |
| 167 | +async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool: |
| 168 | + """ |
| 169 | + Returns True if the context's author doesn't have any of the specified roles. |
| 170 | +
|
| 171 | + `roles` are the names or IDs of the roles for which to check. |
| 172 | + False is always returns if the context is outside a guild. |
| 173 | + """ |
| 174 | + if not ctx.guild: # Return False in a DM |
| 175 | + log.trace( |
| 176 | + f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " |
| 177 | + "This command is restricted by the without_role decorator. Rejecting request." |
| 178 | + ) |
| 179 | + return False |
| 180 | + |
| 181 | + author_roles = [role.id for role in ctx.author.roles] |
| 182 | + check = all(role not in author_roles for role in roles) |
| 183 | + log.trace( |
| 184 | + f"{ctx.author} tried to call the '{ctx.command.name}' command. " |
| 185 | + f"The result of the without_role check was {check}." |
| 186 | + ) |
| 187 | + return check |
0 commit comments