Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in the Authentication client #2635

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import os
import re
import threading
import time
import typing
import urllib.parse as _urlparse
import webbrowser
Expand Down Expand Up @@ -236,6 +238,9 @@
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or requests.Session()
self._lock = threading.Lock()
self._cached_credentials = None
self._cached_credentials_ts = None

self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down Expand Up @@ -339,25 +344,38 @@

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
This is the entrypoint method. It will kickoff the full authentication
flow and trigger a web-browser to retrieve credentials. Because this
needs to open a port on localhost and may be called from a
multithreaded context (e.g. pyflyte register), this call may block
multiple threads and return a cached result for up to 60 seconds.
"""
# In the absence of globally-set token values, initiate the token request flow
q = Queue()
with self._lock:
# Clear cache if it's been more than 60 seconds since the last check
cache_ttl_s = 60

Check warning on line 356 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L356

Added line #L356 was not covered by tests
if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic():
self._cached_credentials = None

Check warning on line 358 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L358

Added line #L358 was not covered by tests
rdeaton-freenome marked this conversation as resolved.
Show resolved Hide resolved

# First prepare the callback server in the background
server = self._create_callback_server()
if self._cached_credentials is not None:
return self._cached_credentials
q = Queue()

Check warning on line 362 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L361-L362

Added lines #L361 - L362 were not covered by tests

self._request_authorization_code()
# First prepare the callback server in the background
server = self._create_callback_server()

Check warning on line 365 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L365

Added line #L365 was not covered by tests

server.handle_request(q)
server.server_close()
self._request_authorization_code()

Check warning on line 367 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L367

Added line #L367 was not covered by tests

# Send the call to request the authorization code in the background
server.handle_request(q)
server.server_close()

Check warning on line 370 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L369-L370

Added lines #L369 - L370 were not covered by tests

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
# Send the call to request the authorization code in the background

# Request the access token once the auth code has been received.
auth_code = q.get()
self._cached_credentials = self._request_access_token(auth_code)
self._cached_credentials_ts = time.monotonic()
return self._cached_credentials

Check warning on line 378 in flytekit/clients/auth/auth_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/auth_client.py#L375-L378

Added lines #L375 - L378 were not covered by tests

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
Expand Down
Loading