Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion azure/datalake/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def concat(self, outfile, filelist, delete_source=False):
self.azure.call('MSCONCAT', outfile.as_posix(),
data=bytearray(json.dumps(sources,separators=(',', ':')), encoding="utf-8"),
deleteSourceDirectory=delete,
headers={'Content-Type': "application/json"},)
headers={'Content-Type': "application/json"})
self.invalidate_cache(outfile)

merge = concat
Expand Down
80 changes: 47 additions & 33 deletions azure/datalake/store/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
else:
import urllib

from .retry import ExponentialRetryPolicy
from .retry import ExponentialRetryPolicy, retry_decorator_for_auth

# 3rd party imports
import adal
Expand Down Expand Up @@ -74,7 +74,7 @@
def auth(tenant_id=None, username=None,
password=None, client_id=default_client,
client_secret=None, resource=DEFAULT_RESOURCE_ENDPOINT,
require_2fa=False, authority=None, **kwargs):
require_2fa=False, authority=None, retry_policy=None, **kwargs):
""" User/password authentication

Parameters
Expand Down Expand Up @@ -103,6 +103,7 @@ def auth(tenant_id=None, username=None,
-------
:type DataLakeCredential :mod: `A DataLakeCredential object`
"""

if not authority:
authority = 'https://login.microsoftonline.com/'

Expand All @@ -124,24 +125,30 @@ def auth(tenant_id=None, username=None,
if not client_secret:
client_secret = os.environ.get('azure_client_secret', None)

# You can explicitly authenticate with 2fa, or pass in nothing to the auth call and
# You can explicitly authenticate with 2fa, or pass in nothing to the auth call
# and the user will be prompted to login interactively through a browser.
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)

elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")

@retry_decorator_for_auth(retry_policy=retry_policy)
def get_token_internal():
# Internal function used so as to use retry decorator
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)

elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")
return out

out = get_token_internal()
out.update({'access': out['accessToken'], 'resource': resource,
'refresh': out.get('refreshToken', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id})
Expand All @@ -152,22 +159,22 @@ class DataLakeCredential:
def __init__(self, token):
self.token = token

def signed_session(self, retry_policy=None):
    """Create requests session with any required auth headers applied.

    If the stored token is within 100 (presumably seconds -- confirm the
    unit of ``expiresIn``) of its recorded lifetime, it is refreshed
    first via ``refresh_token``.

    :param retry_policy: retry policy forwarded to ``refresh_token`` when
        a refresh is needed; ``None`` lets ``refresh_token`` choose.
    :rtype: requests.Session
    """
    session = requests.Session()
    if time.time() - self.token['time'] > self.token['expiresIn'] - 100:
        # The keyword must be spelled ``retry_policy``: refresh_token()
        # has no ``retry_poliy`` parameter, so the misspelling raised
        # TypeError on every refresh.
        self.refresh_token(retry_policy=retry_policy)

    scheme, token = self.token['tokenType'], self.token['access']
    header = "{} {}".format(scheme, token)
    session.headers['Authorization'] = header
    return session

def refresh_token(self, authority=None):
def refresh_token(self, authority=None, retry_policy=None):
""" Refresh an expired authorization token

Parameters
Expand All @@ -183,15 +190,22 @@ def refresh_token(self, authority=None):

context = adal.AuthenticationContext(authority +
self.token['tenant'])
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'], self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
out.update({'refresh': out['refreshToken']})

@retry_decorator_for_auth(retry_policy=retry_policy)
def get_token_internal():
# Internal function used so as to use retry decorator
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'],
self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
return out

out = get_token_internal()
# common items to update
out.update({'access': out['accessToken'],
'time': time.time(), 'tenant': self.token['tenant'],
Expand Down Expand Up @@ -257,7 +271,7 @@ def __init__(self, store_name=default_store, token=None,
# There is a case where the user can opt to exclude an API version, in which case
# the service itself decides on the API version to use (its default).
self.api_version = api_version or None
self.head = {'Authorization': token.signed_session().headers['Authorization']}
self.head = {'Authorization': token.signed_session(retry_policy=None).headers['Authorization']}
self.url = 'https://%s.%s/' % (store_name, url_suffix)
self.webhdfs = 'webhdfs/v1/'
self.extended_operations = 'webhdfsext/'
Expand All @@ -282,8 +296,8 @@ def session(self):
self.local.session = s
return s

def _check_token(self):
cur_session = self.token.signed_session()
def _check_token(self, retry_policy=None):
cur_session = self.token.signed_session(retry_policy=retry_policy)
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None
Expand Down
67 changes: 63 additions & 4 deletions azure/datalake/store/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import sys
import time

from functools import wraps
# local imports

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,6 +45,9 @@ def should_retry(self, response, last_exception, retry_count):
self.__backoff()
return True

if response is None:
return False

status_code = response.status_code

if(status_code == 501
Expand All @@ -58,8 +61,8 @@ def should_retry(self, response, last_exception, retry_count):
if(status_code >= 500
or status_code == 401
or status_code == 408
or status_code == 429):

or status_code == 429
or status_code == 104):
self.__backoff()
return True

Expand All @@ -70,4 +73,60 @@ def should_retry(self, response, last_exception, retry_count):

def __backoff(self):
time.sleep(self.exponential_retry_interval)
self.exponential_retry_interval *= self.exponential_factor
self.exponential_retry_interval *= self.exponential_factor


def retry_decorator_for_auth(retry_policy=None):
    """Build a decorator that retries a token-acquisition callable.

    The wrapped callable is invoked until it returns, until it fails with
    HTTP 401 (invalid credentials, never retried), or until
    ``retry_policy.should_retry`` declines another attempt. Only
    ``adal.adal_error.AdalError`` and ``requests.HTTPError`` are treated as
    retryable; any other exception propagates immediately.

    :param retry_policy: object exposing
        ``should_retry(response, last_exception, retry_count)``. Defaults to
        ``ExponentialRetryPolicy(max_retries=2)``.
    :return: a decorator; the decorated function returns whatever the
        wrapped callable returns (including ``None``).
    :raises: the last ``AdalError``/``HTTPError`` once retries are exhausted
        or the failure is a 401.
    """
    # Kept function-local, as in the original, so these packages are only
    # needed when a decorator is actually built.
    import adal
    from requests import HTTPError

    if retry_policy is None:
        retry_policy = ExponentialRetryPolicy(max_retries=2)

    def deco_retry(func):
        @wraps(func)
        def f_retry(*args, **kwargs):
            retry_count = -1
            while True:
                retry_count += 1
                # Reset per attempt: state left over from an earlier failed
                # attempt must not make a later success look like a failure.
                last_exception = None
                response = None
                try:
                    # Return straight away on success so that a legitimate
                    # ``None`` result is not mistaken for "no result yet".
                    return func(*args, **kwargs)
                except (adal.adal_error.AdalError, HTTPError) as e:
                    # ADAL error corresponds to everything but 429, which bubbles up HTTP error.
                    last_exception = e
                    logger.exception("Retry count " + str(retry_count) + "Exception :" + str(last_exception))

                    if getattr(e, 'error_response', None) is not None:  # ADAL exception
                        response = response_from_adal_exception(e)
                    if getattr(e, 'response', None) is not None:  # HTTP exception i.e 429
                        response = e.response

                # should_retry() reads ``response.status_code`` for any
                # non-None response, so hand it None when no status code
                # could be recovered (e.g. a bare dict payload).
                if not hasattr(response, 'status_code'):
                    response = None

                invalid_credentials = response is not None and response.status_code == 401
                if invalid_credentials or not retry_policy.should_retry(response, last_exception, retry_count):
                    raise last_exception

        return f_retry

    return deco_retry


def response_from_adal_exception(e):
    """Derive a response-like object from an ADAL exception.

    The exception text is searched for ``"http error: <digits>"``. When a
    code is found, the result is a ``Response`` namedtuple holding the
    entries of ``e.error_response`` plus a ``status_code`` field with the
    parsed integer. When no code is found, ``e.error_response`` is returned
    unchanged (it may be ``None``).

    :param e: exception exposing an ``error_response`` attribute.
    :return: ``Response`` namedtuple with ``status_code``, or the raw
        ``error_response`` when no HTTP code is present in the message.
    """
    import re
    from collections import namedtuple

    response = e.error_response
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    http_code = re.search(r"http error: (\d+)", str(e))
    if http_code is None:
        return response

    # error_response may be missing; anything that is not a dict cannot
    # contribute named fields, so start from an empty mapping in that case.
    fields = dict(response) if isinstance(response, dict) else {}
    # Assign via the dict so a payload that already carries 'status_code'
    # is overridden instead of producing a duplicate field name.
    fields['status_code'] = int(http_code.group(1))

    # rename=True keeps construction from failing on payload keys that are
    # not valid identifiers; 'status_code' itself is never renamed.
    Response = namedtuple("Response", list(fields.keys()), rename=True)
    return Response(*fields.values())

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
install_requires=[
'cffi',
'adal>=0.4.2',
'requests>=2.20.0'
],
extras_require={
":python_version<'3.4'": ['pathlib2'],
Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SUBSCRIPTION_ID = fake_settings.SUBSCRIPTION_ID
RESOURCE_GROUP_NAME = fake_settings.RESOURCE_GROUP_NAME
RECORD_MODE = os.environ.get('RECORD_MODE', 'all').lower()
CLIENT_ID = os.environ['azure_service_principal']
'''
RECORD_MODE = os.environ.get('RECORD_MODE', 'none').lower()

Expand Down
Loading