2626else :
2727 import urllib
2828
29- from .retry import ExponentialRetryPolicy
29+ from .retry import ExponentialRetryPolicy , retry_decorator_for_auth
3030
3131# 3rd party imports
3232import adal
7474def auth (tenant_id = None , username = None ,
7575 password = None , client_id = default_client ,
7676 client_secret = None , resource = DEFAULT_RESOURCE_ENDPOINT ,
77- require_2fa = False , authority = None , ** kwargs ):
77+ require_2fa = False , authority = None , retry_policy = None , ** kwargs ):
7878 """ User/password authentication
7979
8080 Parameters
@@ -103,6 +103,7 @@ def auth(tenant_id=None, username=None,
103103 -------
104104 :type DataLakeCredential :mod: `A DataLakeCredential object`
105105 """
106+
106107 if not authority :
107108 authority = 'https://login.microsoftonline.com/'
108109
@@ -124,24 +125,30 @@ def auth(tenant_id=None, username=None,
124125 if not client_secret :
125126 client_secret = os .environ .get ('azure_client_secret' , None )
126127
127- # You can explicitly authenticate with 2fa, or pass in nothing to the auth call and
128+ # You can explicitly authenticate with 2fa, or pass in nothing to the auth call
128129 # and the user will be prompted to login interactively through a browser.
129- if require_2fa or (username is None and password is None and client_secret is None ):
130- code = context .acquire_user_code (resource , client_id )
131- print (code ['message' ])
132- out = context .acquire_token_with_device_code (resource , code , client_id )
133-
134- elif username and password :
135- out = context .acquire_token_with_username_password (resource , username ,
136- password , client_id )
137- elif client_id and client_secret :
138- out = context .acquire_token_with_client_credentials (resource , client_id ,
139- client_secret )
140- # for service principal, we store the secret in the credential object for use when refreshing.
141- out .update ({'secret' : client_secret })
142- else :
143- raise ValueError ("No authentication method found for credentials" )
144130
131+ @retry_decorator_for_auth (retry_policy = retry_policy )
132+ def get_token_internal ():
133+ # Internal function used so as to use retry decorator
134+ if require_2fa or (username is None and password is None and client_secret is None ):
135+ code = context .acquire_user_code (resource , client_id )
136+ print (code ['message' ])
137+ out = context .acquire_token_with_device_code (resource , code , client_id )
138+
139+ elif username and password :
140+ out = context .acquire_token_with_username_password (resource , username ,
141+ password , client_id )
142+ elif client_id and client_secret :
143+ out = context .acquire_token_with_client_credentials (resource , client_id ,
144+ client_secret )
145+ # for service principal, we store the secret in the credential object for use when refreshing.
146+ out .update ({'secret' : client_secret })
147+ else :
148+ raise ValueError ("No authentication method found for credentials" )
149+ return out
150+
151+ out = get_token_internal ()
145152 out .update ({'access' : out ['accessToken' ], 'resource' : resource ,
146153 'refresh' : out .get ('refreshToken' , False ),
147154 'time' : time .time (), 'tenant' : tenant_id , 'client' : client_id })
@@ -152,22 +159,22 @@ class DataLakeCredential:
152159 def __init__ (self , token ):
153160 self .token = token
154161
155- def signed_session (self ):
162+ def signed_session (self , retry_policy = None ):
156163 # type: () -> requests.Session
157164 """Create requests session with any required auth headers applied.
158165
159166 :rtype: requests.Session
160167 """
161168 session = requests .Session ()
162169 if time .time () - self .token ['time' ] > self .token ['expiresIn' ] - 100 :
163- self .refresh_token ()
170+ self .refresh_token (retry_poliy = retry_policy )
164171
165172 scheme , token = self .token ['tokenType' ], self .token ['access' ]
166173 header = "{} {}" .format (scheme , token )
167174 session .headers ['Authorization' ] = header
168175 return session
169176
170- def refresh_token (self , authority = None ):
177+ def refresh_token (self , authority = None , retry_policy = None ):
171178 """ Refresh an expired authorization token
172179
173180 Parameters
@@ -183,15 +190,22 @@ def refresh_token(self, authority=None):
183190
184191 context = adal .AuthenticationContext (authority +
185192 self .token ['tenant' ])
186- if self .token .get ('secret' ) and self .token .get ('client' ):
187- out = context .acquire_token_with_client_credentials (self .token ['resource' ], self .token ['client' ],
188- self .token ['secret' ])
189- out .update ({'secret' : self .token ['secret' ]})
190- else :
191- out = context .acquire_token_with_refresh_token (self .token ['refresh' ],
192- client_id = self .token ['client' ],
193- resource = self .token ['resource' ])
194- out .update ({'refresh' : out ['refreshToken' ]})
193+
194+ @retry_decorator_for_auth (retry_policy = retry_policy )
195+ def get_token_internal ():
196+ # Internal function used so as to use retry decorator
197+ if self .token .get ('secret' ) and self .token .get ('client' ):
198+ out = context .acquire_token_with_client_credentials (self .token ['resource' ],
199+ self .token ['client' ],
200+ self .token ['secret' ])
201+ out .update ({'secret' : self .token ['secret' ]})
202+ else :
203+ out = context .acquire_token_with_refresh_token (self .token ['refresh' ],
204+ client_id = self .token ['client' ],
205+ resource = self .token ['resource' ])
206+ return out
207+
208+ out = get_token_internal ()
195209 # common items to update
196210 out .update ({'access' : out ['accessToken' ],
197211 'time' : time .time (), 'tenant' : self .token ['tenant' ],
@@ -257,7 +271,7 @@ def __init__(self, store_name=default_store, token=None,
257271 # There is a case where the user can opt to exclude an API version, in which case
258272 # the service itself decides on the API version to use (it's default).
259273 self .api_version = api_version or None
260- self .head = {'Authorization' : token .signed_session ().headers ['Authorization' ]}
274+ self .head = {'Authorization' : token .signed_session (retry_policy = None ).headers ['Authorization' ]}
261275 self .url = 'https://%s.%s/' % (store_name , url_suffix )
262276 self .webhdfs = 'webhdfs/v1/'
263277 self .extended_operations = 'webhdfsext/'
@@ -282,8 +296,8 @@ def session(self):
282296 self .local .session = s
283297 return s
284298
285- def _check_token (self ):
286- cur_session = self .token .signed_session ()
299+ def _check_token (self , retry_policy = None ):
300+ cur_session = self .token .signed_session (retry_policy = retry_policy )
287301 if not self .head or self .head .get ('Authorization' ) != cur_session .headers ['Authorization' ]:
288302 self .head = {'Authorization' : cur_session .headers ['Authorization' ]}
289303 self .local .session = None
0 commit comments