From 912960835d12d07dea5ab961a95109a58421470c Mon Sep 17 00:00:00 2001 From: rdeaton-freenome <134093844+rdeaton-freenome@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:04:56 -0700 Subject: [PATCH] Remove use of multiprocessing from the OAuth client (#2626) * Remove use of multiprocessing from the OAuth client Signed-off-by: Robert Deaton * Lint Signed-off-by: Robert Deaton --------- Signed-off-by: Robert Deaton Signed-off-by: mao3267 --- flytekit/clients/auth/auth_client.py | 29 +++++++++++----------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index 8e0f383075..cb77d4a2cf 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -4,7 +4,6 @@ import hashlib import http.server as _BaseHTTPServer import logging -import multiprocessing import os import re import typing @@ -12,7 +11,7 @@ import webbrowser from dataclasses import dataclass from http import HTTPStatus as _StatusCodes -from multiprocessing import get_context +from queue import Queue from urllib.parse import urlencode as _urlencode import click @@ -124,7 +123,7 @@ def __init__( request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler], bind_and_activate: bool = True, redirect_path: str = None, - queue: multiprocessing.Queue = None, + queue: Queue = None, ): _BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate) self._redirect_path = redirect_path @@ -142,9 +141,8 @@ def remote_metadata(self) -> EndpointMetadata: def handle_authorization_code(self, auth_code: str): self._queue.put(auth_code) - self.server_close() - def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any: + def handle_request(self, queue: Queue = None) -> typing.Any: self._queue = queue return super().handle_request() @@ -345,26 +343,21 @@ def get_creds_from_remote(self) -> Credentials: retrieve credentials """ # In the absence of globally-set token values, initiate the token request flow - ctx = get_context("fork") - q = ctx.Queue() + q = Queue() # First prepare the callback server in the background server = self._create_callback_server() - server_process = ctx.Process(target=server.handle_request, args=(q,)) - server_process.daemon = True + self._request_authorization_code() - try: - server_process.start() + server.handle_request(q) + server.server_close() - # Send the call to request the authorization code in the background - self._request_authorization_code() + # Send the call to request the authorization code in the background - # Request the access token once the auth code has been received. - auth_code = q.get() - return self._request_access_token(auth_code) - finally: - server_process.terminate() + # Request the access token once the auth code has been received. + auth_code = q.get() + return self._request_access_token(auth_code) def refresh_access_token(self, credentials: Credentials) -> Credentials: if credentials.refresh_token is None: