@@ -209,6 +209,7 @@ def __init__(
209209 max_attempts = MAX_ATTEMPTS , # type: int
210210 request_timeout = constants .DEFAULT_REQUEST_TIMEOUT , # type: Union[float, Tuple[float, float]]
211211 handle_retry = exceptions .RetryWithExponentialBackoff (),
212+ service_account_file = None ,
212213 ):
213214 # type: (...) -> None
214215 self ._client_session = ClientSession (
@@ -230,6 +231,19 @@ def __init__(
230231 else :
231232 # mypy cannot follow module import
232233 self ._http_session = self .http .Session () # type: ignore
234+
235+ self .credentials = None
236+ self .auth_req = None
237+ if service_account_file is not None :
238+ import google .auth .transport .requests
239+ from google .oauth2 import service_account
240+
241+ self .auth_req = google .auth .transport .requests .Request ()
242+ self .credentials = service_account .Credentials .from_service_account_file (
243+ service_account_file , scopes = [constants .GCS_READ_ONLY ]
244+ )
245+ self ._http_session .headers .update (self .get_oauth_token ())
246+
233247 self ._http_session .headers .update (self .http_headers )
234248 self ._exceptions = self .HTTP_EXCEPTIONS
235249 self ._auth = auth
@@ -422,6 +436,17 @@ def process(self, http_response):
422436 columns = response .get ("columns" ),
423437 )
424438
439+ @property
440+ def http_session (self ):
441+ return self ._http_session
442+
443+ def get_oauth_token (self ):
444+ self .credentials .refresh (self .auth_req )
445+ return {
446+ constants .PRESTO_EXTRA_CREDENTIAL : "%s = %s"
447+ % (constants .GCS_CREDENTIALS_OAUTH_TOKEN_KEY , self .credentials .token )
448+ }
449+
425450
426451class PrestoResult (object ):
427452 """
@@ -466,12 +491,13 @@ def __init__(
466491 sql , # type: Text
467492 ):
468493 # type: (...) -> None
494+ self .auth_req = request .auth_req # type: Optional[Request]
495+ self .credentials = request .credentials # type: Optional[Credentials]
469496 self .query_id = None # type: Optional[Text]
470497
471498 self ._stats = {} # type: Dict[Any, Any]
472499 self ._warnings = [] # type: List[Dict[Any, Any]]
473500 self ._columns = None # type: Optional[List[Text]]
474-
475501 self ._finished = False
476502 self ._cancelled = False
477503 self ._request = request
@@ -506,6 +532,9 @@ def execute(self):
506532 if self ._cancelled :
507533 raise exceptions .PrestoUserError ("Query has been cancelled" , self .query_id )
508534
535+ if self .credentials is not None and not self .credentials .valid :
536+ self ._request .http_session .headers .update (self ._request .get_oauth_token ())
537+
509538 response = self ._request .post (self ._sql )
510539 status = self ._request .process (response )
511540 self .query_id = status .id
0 commit comments