Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in the Authentication client #2635

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import os
import re
import threading
import time
import typing
import urllib.parse as _urlparse
import webbrowser
Expand Down Expand Up @@ -236,6 +238,9 @@ def __init__(
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or requests.Session()
self._lock = threading.Lock()
self._cached_credentials = None
self._cached_credentials_ts = None

self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down Expand Up @@ -339,25 +344,36 @@ def _request_access_token(self, auth_code) -> Credentials:

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
This is the entrypoint method. It will kickoff the full authentication
flow and trigger a web-browser to retrieve credentials. Because this
needs to open a port on localhost and may be called from a
multithreaded context (e.g. pyflyte register), this call may block
multiple threads and return a cached result for up to 60 seconds.
"""
# In the absence of globally-set token values, initiate the token request flow
q = Queue()
with self._lock:
cache_ttl_s = 60
if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic():
self._cached_credentials = None
rdeaton-freenome marked this conversation as resolved.
Show resolved Hide resolved
if self._cached_credentials is not None:
return self._cached_credentials
q = Queue()

# First prepare the callback server in the background
server = self._create_callback_server()
# First prepare the callback server in the background
server = self._create_callback_server()

self._request_authorization_code()
self._request_authorization_code()

server.handle_request(q)
server.server_close()
server.handle_request(q)
server.server_close()

# Send the call to request the authorization code in the background
# Send the call to request the authorization code in the background

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
# Request the access token once the auth code has been received.
auth_code = q.get()
self._cached_credentials = self._request_access_token(auth_code)
self._cached_credentials_ts = time.monotonic()
return self._cached_credentials

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
Expand Down
Loading