Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2331,13 +2331,15 @@ definitions:
- IGNORE
- RESET_PAGINATION
- RATE_LIMITED
- REFRESH_TOKEN_THEN_RETRY
examples:
- SUCCESS
- FAIL
- RETRY
- IGNORE
- RESET_PAGINATION
- RATE_LIMITED
- REFRESH_TOKEN_THEN_RETRY
failure_type:
title: Failure Type
description: Failure type of traced exception if a response matches the filter.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.

# generated by datamodel-codegen:
# filename: declarative_component_schema.yaml

Expand Down Expand Up @@ -543,6 +541,7 @@ class Action(Enum):
IGNORE = "IGNORE"
RESET_PAGINATION = "RESET_PAGINATION"
RATE_LIMITED = "RATE_LIMITED"
REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY"


class FailureType(Enum):
Expand All @@ -563,6 +562,7 @@ class HttpResponseFilter(BaseModel):
"IGNORE",
"RESET_PAGINATION",
"RATE_LIMITED",
"REFRESH_TOKEN_THEN_RETRY",
],
title="Action",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ResponseAction(Enum):
IGNORE = "IGNORE"
RESET_PAGINATION = "RESET_PAGINATION"
RATE_LIMITED = "RATE_LIMITED"
REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY"


@dataclass
Expand Down
38 changes: 34 additions & 4 deletions airbyte_cdk/sources/streams/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def __str__(self) -> str:
class HttpClient:
_DEFAULT_MAX_RETRY: int = 5
_DEFAULT_MAX_TIME: int = 60 * 10
_ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED}
_ACTIONS_TO_RETRY_ON = {
ResponseAction.RETRY,
ResponseAction.RATE_LIMITED,
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
}

def __init__(
self,
Expand Down Expand Up @@ -452,6 +456,31 @@ def _handle_error_resolution(
# backoff retry loop. Adding `\n` to the message and ignore 'end' ensure that few messages are printed at the same time.
print(f"{message}\n", end="", flush=True)

# Handle REFRESH_TOKEN_THEN_RETRY: Force refresh the OAuth token before retry
# This is useful when the API returns 401 but the stored token expiry hasn't been reached yet
# Only OAuth authenticators have refresh_and_set_access_token method
# Non-OAuth auth types (e.g., BearerAuthenticator) will fall through to normal retry
if error_resolution.response_action == ResponseAction.REFRESH_TOKEN_THEN_RETRY:
if (
hasattr(self._session, "auth")
and self._session.auth is not None
and hasattr(self._session.auth, "refresh_and_set_access_token")
):
try:
self._session.auth.refresh_and_set_access_token() # type: ignore[union-attr]
self._logger.info(
"Refreshed OAuth token due to REFRESH_TOKEN_THEN_RETRY response action"
)
except Exception as refresh_error:
self._logger.warning(
f"Failed to refresh OAuth token: {refresh_error}. Proceeding with retry using existing token."
)
else:
self._logger.warning(
"REFRESH_TOKEN_THEN_RETRY action received but authenticator does not support token refresh. "
"Proceeding with normal retry."
)

if error_resolution.response_action == ResponseAction.FAIL:
if response is not None:
filtered_response_message = filter_secrets(
Expand Down Expand Up @@ -481,9 +510,10 @@ def _handle_error_resolution(
self._logger.info(error_resolution.error_message or log_message)

# TODO: Consider dynamic retry count depending on subsequent error codes
elif (
error_resolution.response_action == ResponseAction.RETRY
or error_resolution.response_action == ResponseAction.RATE_LIMITED
elif error_resolution.response_action in (
ResponseAction.RETRY,
ResponseAction.RATE_LIMITED,
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
):
user_defined_backoff_time = None
for backoff_strategy in self._backoff_strategies:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,21 @@ def get_auth_header(self) -> Mapping[str, Any]:
def get_access_token(self) -> str:
"""Returns the access token"""
if self.token_has_expired():
token, expires_in = self.refresh_access_token()
self.access_token = token
self.set_token_expiry_date(expires_in)
self.refresh_and_set_access_token()

return self.access_token

def refresh_and_set_access_token(self) -> None:
"""Force refresh the access token and update internal state.

This method refreshes the access token regardless of whether it has expired,
and updates the internal token and expiry date. Subclasses may override this
to handle additional state updates (e.g., persisting new refresh tokens).
"""
token, expires_in = self.refresh_access_token()
self.access_token = token
self.set_token_expiry_date(expires_in)

def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return ab_datetime_now() > self.get_token_expiry_date()
Expand Down
24 changes: 16 additions & 8 deletions airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,20 +318,28 @@ def token_has_expired(self) -> bool:

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
The new refresh token is persisted with the set_refresh_token function

The new refresh token is persisted with the set_refresh_token function.

Returns:
str: The current access_token, updated if it was previously expired.
"""
if self.token_has_expired():
new_access_token, access_token_expires_in, new_refresh_token = (
self.refresh_access_token()
)
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(access_token_expires_in)
self._emit_control_message()
self.refresh_and_set_access_token()
return self.access_token

def refresh_and_set_access_token(self) -> None:
"""Force refresh the access token and update internal state.

For single-use refresh tokens, this also persists the new refresh token
and emits a control message to update the connector config.
"""
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(access_token_expires_in)
self._emit_control_message()

def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
"""
Refreshes the access token by making a handled request and extracting the necessary token information.
Expand Down
222 changes: 222 additions & 0 deletions unit_tests/sources/streams/http/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,3 +837,225 @@ def backoff_time(self, response_or_exception, attempt_count):
with pytest.raises(AirbyteTracedException) as e:
http_client.send_request(http_method="get", url="https://airbyte.io/", request_kwargs={})
assert e.value.failure_type == expected_failure_type


class MockOAuthAuthenticator:
def __init__(self):
self.access_token = "old_token"
self._token_expiry_date = None
self.refresh_called = False

def refresh_and_set_access_token(self):
self.refresh_called = True
self.access_token = "new_refreshed_token"
self._token_expiry_date = "2099-01-01T00:00:00Z"

def __call__(self, request):
request.headers["Authorization"] = f"Bearer {self.access_token}"
return request


def test_refresh_token_then_retry_action_refreshes_oauth_token(mocker):
mock_authenticator = MockOAuthAuthenticator()
mocked_session = MagicMock(spec=requests.Session)
mocked_session.auth = mock_authenticator

http_client = HttpClient(
name="test",
logger=MagicMock(),
error_handler=HttpStatusErrorHandler(
logger=MagicMock(),
error_mapping={
401: ErrorResolution(
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
FailureType.transient_error,
"Token expired, refreshing",
)
},
),
session=mocked_session,
)

prepared_request = requests.PreparedRequest()
mocked_response = MagicMock(spec=requests.Response)
mocked_response.status_code = 401
mocked_response.headers = {}
mocked_response.ok = False
mocked_session.send.return_value = mocked_response

with pytest.raises(DefaultBackoffException):
http_client._send(prepared_request, {})

assert mock_authenticator.refresh_called
assert mock_authenticator.access_token == "new_refreshed_token"
assert mock_authenticator._token_expiry_date == "2099-01-01T00:00:00Z"


def test_refresh_token_then_retry_action_without_oauth_authenticator_proceeds_with_retry(mocker):
mocked_session = MagicMock(spec=requests.Session)
mocked_session.auth = None

mocked_logger = MagicMock()
http_client = HttpClient(
name="test",
logger=mocked_logger,
error_handler=HttpStatusErrorHandler(
logger=MagicMock(),
error_mapping={
401: ErrorResolution(
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
FailureType.transient_error,
"Token expired, refreshing",
)
},
),
session=mocked_session,
)

prepared_request = requests.PreparedRequest()
mocked_response = MagicMock(spec=requests.Response)
mocked_response.status_code = 401
mocked_response.headers = {}
mocked_response.ok = False
mocked_session.send.return_value = mocked_response

with pytest.raises(DefaultBackoffException):
http_client._send(prepared_request, {})

mocked_logger.warning.assert_called()


def test_refresh_token_then_retry_action_handles_refresh_failure_gracefully(mocker):
class FailingOAuthAuthenticator:
def __init__(self):
self.access_token = "old_token"

def refresh_and_set_access_token(self):
raise Exception("Token refresh failed")

def __call__(self, request):
return request

mock_authenticator = FailingOAuthAuthenticator()
mocked_session = MagicMock(spec=requests.Session)
mocked_session.auth = mock_authenticator

mocked_logger = MagicMock()
http_client = HttpClient(
name="test",
logger=mocked_logger,
error_handler=HttpStatusErrorHandler(
logger=MagicMock(),
error_mapping={
401: ErrorResolution(
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
FailureType.transient_error,
"Token expired, refreshing",
)
},
),
session=mocked_session,
)

prepared_request = requests.PreparedRequest()
mocked_response = MagicMock(spec=requests.Response)
mocked_response.status_code = 401
mocked_response.headers = {}
mocked_response.ok = False
mocked_session.send.return_value = mocked_response

with pytest.raises(DefaultBackoffException):
http_client._send(prepared_request, {})

mocked_logger.warning.assert_called()


def test_refresh_token_then_retry_action_with_single_use_refresh_token_authenticator(mocker):
from airbyte_cdk.sources.streams.http.requests_native_auth import (
SingleUseRefreshTokenOauth2Authenticator,
)

mock_authenticator = MagicMock(spec=SingleUseRefreshTokenOauth2Authenticator)

mocked_session = MagicMock(spec=requests.Session)
mocked_session.auth = mock_authenticator

http_client = HttpClient(
name="test",
logger=MagicMock(),
error_handler=HttpStatusErrorHandler(
logger=MagicMock(),
error_mapping={
401: ErrorResolution(
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
FailureType.transient_error,
"Token expired, refreshing",
)
},
),
session=mocked_session,
)

prepared_request = requests.PreparedRequest()
mocked_response = MagicMock(spec=requests.Response)
mocked_response.status_code = 401
mocked_response.headers = {}
mocked_response.ok = False
mocked_session.send.return_value = mocked_response

with pytest.raises(DefaultBackoffException):
http_client._send(prepared_request, {})

mock_authenticator.refresh_and_set_access_token.assert_called_once()


@pytest.mark.usefixtures("mock_sleep")
def test_refresh_token_then_retry_action_retries_and_succeeds_after_token_refresh():
mock_authenticator = MockOAuthAuthenticator()
mocked_session = MagicMock(spec=requests.Session)
mocked_session.auth = mock_authenticator

valid_response = MagicMock(spec=requests.Response)
valid_response.status_code = 200
valid_response.ok = True
valid_response.headers = {}

call_count = 0

def update_response(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
retry_response = MagicMock(spec=requests.Response)
retry_response.ok = False
retry_response.status_code = 401
retry_response.headers = {}
return retry_response
else:
return valid_response

mocked_session.send.side_effect = update_response

http_client = HttpClient(
name="test",
logger=MagicMock(),
error_handler=HttpStatusErrorHandler(
logger=MagicMock(),
error_mapping={
401: ErrorResolution(
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
FailureType.transient_error,
"Token expired, refreshing",
)
},
),
session=mocked_session,
)

prepared_request = requests.PreparedRequest()
returned_response = http_client._send_with_retry(prepared_request, request_kwargs={})

assert mock_authenticator.refresh_called
assert mock_authenticator.access_token == "new_refreshed_token"
assert returned_response == valid_response
assert call_count == 2