Skip to content

Commit

Permalink
Remove use of multiprocessing from the OAuth client (flyteorg#2626)
Browse files Browse the repository at this point in the history
* Remove use of multiprocessing from the OAuth client

Signed-off-by: Robert Deaton <robert.deaton@freenome.com>

* Lint

Signed-off-by: Robert Deaton <robert.deaton@freenome.com>

---------

Signed-off-by: Robert Deaton <robert.deaton@freenome.com>
Signed-off-by: mao3267 <chenvincent610@gmail.com>
  • Loading branch information
rdeaton-freenome authored and mao3267 committed Aug 2, 2024
1 parent 1e7bfbe commit 9129608
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import hashlib
import http.server as _BaseHTTPServer
import logging
import multiprocessing
import os
import re
import typing
import urllib.parse as _urlparse
import webbrowser
from dataclasses import dataclass
from http import HTTPStatus as _StatusCodes
from multiprocessing import get_context
from queue import Queue
from urllib.parse import urlencode as _urlencode

import click
Expand Down Expand Up @@ -124,7 +123,7 @@ def __init__(
request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler],
bind_and_activate: bool = True,
redirect_path: str = None,
queue: multiprocessing.Queue = None,
queue: Queue = None,
):
_BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate)
self._redirect_path = redirect_path
Expand All @@ -142,9 +141,8 @@ def remote_metadata(self) -> EndpointMetadata:

def handle_authorization_code(self, auth_code: str):
self._queue.put(auth_code)
self.server_close()

def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any:
def handle_request(self, queue: Queue = None) -> typing.Any:
self._queue = queue
return super().handle_request()

Expand Down Expand Up @@ -345,26 +343,21 @@ def get_creds_from_remote(self) -> Credentials:
retrieve credentials
"""
# In the absence of globally-set token values, initiate the token request flow
ctx = get_context("fork")
q = ctx.Queue()
q = Queue()

# First prepare the callback server in the background
server = self._create_callback_server()

server_process = ctx.Process(target=server.handle_request, args=(q,))
server_process.daemon = True
self._request_authorization_code()

try:
server_process.start()
server.handle_request(q)
server.server_close()

# Send the call to request the authorization code in the background
self._request_authorization_code()
# Send the call to request the authorization code in the background

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
finally:
server_process.terminate()
# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
Expand Down

0 comments on commit 9129608

Please sign in to comment.