Skip to content

Commit 1cba6a4

Browse files
feat: support regex component callbacks (#1332)
* feat: support regex component callbacks * docs: document regex matching component callback * ci: correct from checks. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ba44e99 commit 1cba6a4

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

docs/src/Guides/05 Components.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,17 @@ When responding to a component you need to satisfy discord either by responding
250250
)
251251
)
252252
```
253+
254+
=== ":four: Persistent Callbacks, with regex"
255+
Ah, I see you are a masochist. You want to use regex to match your custom_ids. Well who am I to stop you?
256+
257+
```python
258+
@component_callback(re.compile(r"\w*"))
259+
async def test_callback(ctx: interactions.ComponentContext):
260+
await ctx.send(f"Clicked {ctx.custom_id}")
261+
```
262+
263+
Just like normal `@component_callback`, you can specify a regex pattern to match your custom_ids, instead of explicitly passing strings.
264+
This is useful if you have a lot of components with similar custom_ids, and you want to handle them all in the same callback.
265+
266+
Please do bare in mind that using regex patterns can be a bit slower than using strings, especially if you have a lot of components.

interactions/client/client.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def __init__(
383383
] = {}
384384
"""A dictionary of registered application commands in a tree"""
385385
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
386+
self._regex_component_callbacks: Dict[re.Pattern, Callable[..., Coroutine]] = {}
386387
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
387388
self._global_autocompletes: Dict[str, GlobalAutoComplete] = {}
388389
self.processors: Dict[str, Callable[..., Coroutine]] = {}
@@ -1270,10 +1271,15 @@ def add_component_callback(self, command: ComponentCommand) -> None:
12701271
12711272
"""
12721273
for listener in command.listeners:
1273-
# I know this isn't an ideal solution, but it means we can lookup callbacks with O(1)
1274-
if listener in self._component_callbacks.keys():
1275-
raise ValueError(f"Duplicate Component! Multiple component callbacks for `{listener}`")
1276-
self._component_callbacks[listener] = command
1274+
if isinstance(listener, re.Pattern):
1275+
if listener in self._regex_component_callbacks.keys():
1276+
raise ValueError(f"Duplicate Component! Multiple component callbacks for `{listener}`")
1277+
self._regex_component_callbacks[listener] = command
1278+
else:
1279+
# I know this isn't an ideal solution, but it means we can lookup callbacks with O(1)
1280+
if listener in self._component_callbacks.keys():
1281+
raise ValueError(f"Duplicate Component! Multiple component callbacks for `{listener}`")
1282+
self._component_callbacks[listener] = command
12771283
continue
12781284

12791285
def add_modal_callback(self, command: ModalCommand) -> None:
@@ -1410,7 +1416,7 @@ async def wrap(*args, **kwargs) -> Absent[List[Dict]]:
14101416
if cmd_name not in found and warn_missing:
14111417
self.logger.error(
14121418
f'Detected yet to sync slash command "/{cmd_name}" for scope '
1413-
f"{'global' if scope == GLOBAL_SCOPE else scope}"
1419+
f'{"global" if scope == GLOBAL_SCOPE else scope}'
14141420
)
14151421
continue
14161422
found.add(cmd_name)
@@ -1668,7 +1674,15 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None:
16681674
component_type = interaction_data["data"]["component_type"]
16691675

16701676
self.dispatch(events.Component(ctx=ctx))
1671-
if callback := self._component_callbacks.get(ctx.custom_id):
1677+
component_callback = self._component_callbacks.get(ctx.custom_id)
1678+
if not component_callback:
1679+
# evaluate regex component callbacks
1680+
for regex, callback in self._regex_component_callbacks.items():
1681+
if regex.match(ctx.custom_id):
1682+
component_callback = callback
1683+
break
1684+
1685+
if component_callback:
16721686
await self.__dispatch_interaction(
16731687
ctx=ctx,
16741688
callback=callback(ctx),

interactions/models/internal/application_commands.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ class ComponentCommand(InteractionCommand):
807807
name: str = attrs.field(
808808
repr=False,
809809
)
810-
listeners: list[str] = attrs.field(repr=False, factory=list)
810+
listeners: list[str | re.Pattern] = attrs.field(repr=False, factory=list)
811811

812812

813813
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
@@ -1122,13 +1122,16 @@ def message_context_menu(
11221122
)
11231123

11241124

1125-
def component_callback(*custom_id: str) -> Callable[[AsyncCallable], ComponentCommand]:
1125+
def component_callback(*custom_id: str | re.Pattern) -> Callable[[AsyncCallable], ComponentCommand]:
11261126
"""
11271127
Register a coroutine as a component callback.
11281128
11291129
Component callbacks work the same way as commands, just using components as a way of invoking, instead of messages.
11301130
Your callback will be given a single argument, `ComponentContext`
11311131
1132+
Note:
1133+
This can optionally take a regex pattern, which will be used to match against the custom ID of the component
1134+
11321135
Args:
11331136
*custom_id: The custom ID of the component to wait for
11341137
@@ -1141,6 +1144,8 @@ def wrapper(func: AsyncCallable) -> ComponentCommand:
11411144
return ComponentCommand(name=f"ComponentCallback::{custom_id}", callback=func, listeners=custom_id)
11421145

11431146
custom_id = _unpack_helper(custom_id)
1147+
if not all(isinstance(i, re.Pattern) for i in custom_id) or all(isinstance(i, str) for i in custom_id):
1148+
raise ValueError("All custom IDs be either a string or a regex pattern, not a mix of both.")
11441149
return wrapper
11451150

11461151

0 commit comments

Comments
 (0)