Skip to content

fix: no longer require token expiration from the OAuth resource server #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 16, 2025
4 changes: 2 additions & 2 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
def _has_access_token_been_initialized(self) -> bool:
return self._access_token is not None

def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
self._token_expiry_date = value

def get_assertion_name(self) -> str:
return self.assertion_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
headers = self.get_refresh_request_headers()
return headers if headers else None

def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
"""
Returns the refresh token and its expiration datetime

Expand All @@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
# PRIVATE METHODS
# ----------------

def _default_token_expiry_date(self) -> AirbyteDateTime:
"""
Returns the default token expiry date
"""
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
default_token_expiry_duration_hours = 1 # 1 hour
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
) -> bool:
Expand Down Expand Up @@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->

def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
"""
Return the expiration datetime of the refresh token
Parse a string or integer token expiration date into a datetime object

:return: expiration datetime
"""
if not value and not self.token_has_expired():
# No expiry token was provided but the previous one is not expired so it's fine
return self.get_token_expiry_date()

if self.token_expiry_is_time_of_expiration:
if not self.token_expiry_date_format:
raise ValueError(
Expand Down Expand Up @@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
"""
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
"""
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.

Args:
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

Returns:
str: The extracted token_expiry_date.
The extracted token_expiry_date or None if not found.
"""
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
expires_in = self._find_and_get_value_from_response(
response_data, self.get_expires_in_name()
)
if expires_in is not None:
return self._parse_token_expiration_date(expires_in)

# expires_in is None
existing_expiry_date = self.get_token_expiry_date()
if existing_expiry_date and not self.token_has_expired():
return existing_expiry_date

return self._default_token_expiry_date()

def _find_and_get_value_from_response(
self,
Expand All @@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
"""
if current_depth > max_depth:
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
raise ResponseKeysMaxRecurtionReached(
internal_message=message, message=message, failure_type=FailureType.config_error
)
Expand Down Expand Up @@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
"""Expiration date of the access token"""

@abstractmethod
def set_token_expiry_date(self, value: Union[str, int]) -> None:
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
"""Setter for access token expiration date"""

@abstractmethod
Expand Down
31 changes: 4 additions & 27 deletions airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
def get_token_expiry_date(self) -> AirbyteDateTime:
return self._token_expiry_date

def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
self._token_expiry_date = value

@property
def token_expiry_is_time_of_expiration(self) -> bool:
Expand Down Expand Up @@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return ab_datetime_now() > self.get_token_expiry_date()

@staticmethod
def get_new_token_expiry_date(
access_token_expires_in: str,
token_expiry_date_format: str | None = None,
) -> AirbyteDateTime:
"""
Calculate the new token expiry date based on the provided expiration duration or format.

Args:
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.

Returns:
AirbyteDateTime: The calculated expiry date of the access token.
"""
if token_expiry_date_format:
return ab_datetime_parse(access_token_expires_in)
else:
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
The new refresh token is persisted with the set_refresh_token function
Expand All @@ -346,16 +326,13 @@ def get_access_token(self) -> str:
new_access_token, access_token_expires_in, new_refresh_token = (
self.refresh_access_token()
)
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
access_token_expires_in, self._token_expiry_date_format
)
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
self.set_token_expiry_date(access_token_expires_in)
self._emit_control_message()
return self.access_token

def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
"""
Refreshes the access token by making a handled request and extracting the necessary token information.

Expand Down
73 changes: 66 additions & 7 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self):
grant_type="refresh_token",
)

@freezegun.freeze_time("2022-01-01")
def test_refresh_access_token(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand All @@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker):
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True)
token = oauth.refresh_access_token()
access_token, token_expiry_date = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert access_token == "access_token"
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

filtered = filter_secrets("access_token")
assert filtered == "****"

@freezegun.freeze_time("2022-01-01")
def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
Expand All @@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker):
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token = oauth.refresh_access_token()
access_token, token_expiry_date = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert access_token == "access_token"
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

assert mocked_request.call_args.kwargs["headers"] == expected_headers

Expand Down Expand Up @@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)

@freezegun.freeze_time("2022-01-01")
def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
self,
) -> None:
Expand All @@ -335,12 +340,65 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_
url="https://refresh_endpoint.com/",
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
),
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
HttpResponse(
body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})
),
)
oauth.get_access_token()

assert oauth.access_token == "new_access_token"
assert oauth._token_expiry_date == expiry_date
assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

@freezegun.freeze_time("2022-01-01")
@pytest.mark.parametrize(
"initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token",
[
(timedelta(days=1), timedelta(days=1), "some_access_token"),
(timedelta(days=-1), timedelta(hours=1), "new_access_token"),
(None, timedelta(hours=1), "new_access_token"),
],
ids=[
"initial_expiry_date_in_future",
"initial_expiry_date_in_past",
"no_initial_expiry_date",
],
)
def test_no_expiry_date_provided_by_auth_server(
self,
initial_expiry_date_delta,
expected_new_expiry_date_delta,
expected_access_token,
) -> None:
initial_expiry_date = (
ab_datetime_now().add(initial_expiry_date_delta).isoformat()
if initial_expiry_date_delta
else None
)
expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta)
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com/",
client_id="some_client_id",
client_secret="some_client_secret",
token_expiry_date=initial_expiry_date,
access_token_value="some_access_token",
refresh_token="some_refresh_token",
config={},
parameters={},
grant_type="client",
)

with HttpMocker() as http_mocker:
http_mocker.post(
HttpRequest(
url="https://refresh_endpoint.com/",
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
),
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
)
oauth.get_access_token()

assert oauth.access_token == expected_access_token
assert oauth._token_expiry_date == expected_new_expiry_date

@pytest.mark.parametrize(
"expires_in_response, token_expiry_date_format",
Expand Down Expand Up @@ -443,6 +501,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
assert "access_token" == token
assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day)

@freezegun.freeze_time("2022-01-01")
def test_profile_assertion(self, mocker):
with HttpMocker() as http_mocker:
jwt = JwtAuthenticator(
Expand Down Expand Up @@ -477,7 +536,7 @@ def test_profile_assertion(self, mocker):

token = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token

filtered = filter_secrets("access_token")
assert filtered == "****"
Expand Down
Loading
Loading