@@ -155,7 +155,9 @@ def get_token_internal():
155155
156156 return DataLakeCredential (out )
157157
158+
158159class DataLakeCredential :
160+ # Be careful modifying this. DataLakeCredential is a general class in azure, and we have to maintain parity.
159161 def __init__ (self , token ):
160162 self .token = token
161163
@@ -191,21 +193,15 @@ def refresh_token(self, authority=None, retry_policy=None):
191193 context = adal .AuthenticationContext (authority +
192194 self .token ['tenant' ])
193195
194- @retry_decorator_for_auth (retry_policy = retry_policy )
195- def get_token_internal ():
196- # Internal function used so as to use retry decorator
197- if self .token .get ('secret' ) and self .token .get ('client' ):
198- out = context .acquire_token_with_client_credentials (self .token ['resource' ],
199- self .token ['client' ],
200- self .token ['secret' ])
201- out .update ({'secret' : self .token ['secret' ]})
202- else :
203- out = context .acquire_token_with_refresh_token (self .token ['refresh' ],
204- client_id = self .token ['client' ],
205- resource = self .token ['resource' ])
206- return out
207-
208- out = get_token_internal ()
196+ if self .token .get ('secret' ) and self .token .get ('client' ):
197+ out = context .acquire_token_with_client_credentials (self .token ['resource' ],
198+ self .token ['client' ],
199+ self .token ['secret' ])
200+ out .update ({'secret' : self .token ['secret' ]})
201+ else :
202+ out = context .acquire_token_with_refresh_token (self .token ['refresh' ],
203+ client_id = self .token ['client' ],
204+ resource = self .token ['resource' ])
209205 # common items to update
210206 out .update ({'access' : out ['accessToken' ],
211207 'time' : time .time (), 'tenant' : self .token ['tenant' ],
@@ -271,7 +267,9 @@ def __init__(self, store_name=default_store, token=None,
271267 # There is a case where the user can opt to exclude an API version, in which case
272268 # the service itself decides on the API version to use (it's default).
273269 self .api_version = api_version or None
274- self .head = {'Authorization' : token .signed_session (retry_policy = None ).headers ['Authorization' ]}
270+ self .head = None
271+ self ._check_token () # Retryable method. Will ensure that signed_session token is current when we set it on next line
272+ self .head = {'Authorization' : token .signed_session ().headers ['Authorization' ]}
275273 self .url = 'https://%s.%s/' % (store_name , url_suffix )
276274 self .webhdfs = 'webhdfs/v1/'
277275 self .extended_operations = 'webhdfsext/'
@@ -296,11 +294,15 @@ def session(self):
296294 self .local .session = s
297295 return s
298296
299- def _check_token (self , retry_policy = None ):
300- cur_session = self .token .signed_session (retry_policy = retry_policy )
301- if not self .head or self .head .get ('Authorization' ) != cur_session .headers ['Authorization' ]:
302- self .head = {'Authorization' : cur_session .headers ['Authorization' ]}
303- self .local .session = None
297+
298+ def _check_token (self , retry_policy = None ):
299+ @retry_decorator_for_auth (retry_policy = retry_policy )
300+ def check_token_internal ():
301+ cur_session = self .token .signed_session ()
302+ if not self .head or self .head .get ('Authorization' ) != cur_session .headers ['Authorization' ]:
303+ self .head = {'Authorization' : cur_session .headers ['Authorization' ]}
304+ self .local .session = None
305+ check_token_internal ()
304306
305307 def _log_request (self , method , url , op , path , params , headers , retry_count ):
306308 msg = "HTTP Request\n {} {}\n " .format (method .upper (), url )
0 commit comments