Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions providers/slack/src/airflow/providers/slack/hooks/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import time
import warnings
from collections.abc import Sequence
from functools import cached_property
Expand All @@ -28,7 +29,7 @@
from slack_sdk.errors import SlackApiError
from typing_extensions import NotRequired

from airflow.exceptions import AirflowNotFoundException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.slack.utils import ConnectionExtraConfig
from airflow.utils.helpers import exactly_one
Expand Down Expand Up @@ -291,7 +292,7 @@ def get_channel_id(self, channel_name: str) -> str:
"""
next_cursor = None
while not (channel_id := self._channels_mapping.get(channel_name)):
res = self.client.conversations_list(cursor=next_cursor, types="public_channel,private_channel")
res = self._call_conversations_list(cursor=next_cursor)
if TYPE_CHECKING:
# Slack SDK response type too broad, this should make mypy happy
assert isinstance(res.data, dict)
Expand All @@ -308,6 +309,37 @@ def get_channel_id(self, channel_name: str) -> str:
raise LookupError(msg)
return channel_id

def _call_conversations_list(self, cursor: str | None = None):
"""
Call ``conversations.list`` with automatic 429-retry.

.. versionchanged:: 3.0.0
Automatically retries on 429 responses (up to 5 times, honouring *Retry-After* header).

:param cursor: Pagination cursor returned by the previous ``conversations.list`` call.
Pass ``None`` (default) to start from the first page.
:raises AirflowException: If the method hits the rate-limit 5 times in a row.
:raises SlackApiError: Propagated when errors other than 429 occur.
:return: Slack SDK response for the page requested.
"""
max_retries = 5
for attempt in range(max_retries):
try:
return self.client.conversations_list(cursor=cursor, types="public_channel,private_channel")
except SlackApiError as e:
if e.response.status_code == 429 and attempt < max_retries:
retry_after = int(e.response.headers.get("Retry-After", 30))
self.log.warning(
"Rate limit hit. Retrying in %s seconds. Attempt %s/%s",
retry_after,
attempt + 1,
max_retries,
)
time.sleep(retry_after)
else:
raise
raise AirflowException("Max retries reached for conversations.list")

def test_connection(self):
"""
Tests the Slack API connection.
Expand Down
31 changes: 30 additions & 1 deletion providers/slack/tests/unit/slack/hooks/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from slack_sdk.http_retry.builtin_handlers import ConnectionErrorRetryHandler, RateLimitErrorRetryHandler
from slack_sdk.web.slack_response import SlackResponse

from airflow.exceptions import AirflowNotFoundException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.connection import Connection
from airflow.providers.slack.hooks.slack import SlackHook

Expand Down Expand Up @@ -88,6 +88,13 @@ def fake_slack_response(*, data: dict | bytes, status_code: int = 200, **kwargs)

return SlackResponse(status_code=status_code, data=data, **kwargs)

@staticmethod
def make_429():
resp = mock.MagicMock()
resp.status_code = 429
resp.headers = {"Retry-After": "1"}
return SlackApiError("ratelimited", response=resp)

@pytest.mark.parametrize(
"conn_id",
[
Expand Down Expand Up @@ -389,6 +396,28 @@ def test_get_channel_id(self, mocked_client):
with pytest.raises(LookupError, match="Unable to find slack channel"):
hook.get_channel_id("troubleshooting")

def test_call_conversations_list_retries_then_succeeds(self, monkeypatch):
ok_resp = self.fake_slack_response(data={"channels": []})
monkeypatch.setattr(
"airflow.providers.slack.hooks.slack.WebClient",
lambda **_: mock.MagicMock(
conversations_list=mock.Mock(side_effect=[self.make_429(), self.make_429(), ok_resp])
),
)
with mock.patch("time.sleep") as mocked_sleep:
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
res = hook._call_conversations_list()
assert res is ok_resp
assert mocked_sleep.call_count == 2

def test_call_conversations_list_exceeds_max(self, monkeypatch):
monkeypatch.setattr(
"airflow.providers.slack.hooks.slack.WebClient",
lambda **_: mock.MagicMock(conversations_list=mock.Mock(side_effect=[self.make_429()] * 5)),
)
with pytest.raises(AirflowException, match="Max retries"):
SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)._call_conversations_list()

def test_send_file_v2(self, mocked_client):
SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID).send_file_v2(
channel_id="C00000000", file_uploads={"file": "/foo/bar/file.txt", "filename": "foo.txt"}
Expand Down
Loading