Skip to content

Commit

Permalink
Add AAD authentication for DatabricksAsyncHook
Browse files Browse the repository at this point in the history
  • Loading branch information
eskarimov committed Dec 14, 2021
1 parent 851adaf commit 44ef54b
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 15 deletions.
157 changes: 142 additions & 15 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import asyncio
import json
import time
from functools import partial
from time import sleep
from typing import Any, Tuple
from urllib.parse import urlparse
Expand Down Expand Up @@ -590,6 +591,104 @@ async def __aexit__(self, *err):
await self._session.close()
self._session = None

async def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed identity or service principal auth
:param resource: resource to issue token to
:return: AAD token, or raise an exception
"""
aad_token = self.aad_tokens.get(resource)
if aad_token and self._is_aad_token_valid(aad_token):
return aad_token['token']

self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
attempt_num = 1
while True:
try:
if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
async with self._session.get(
url=AZURE_TOKEN_SERVICE_URL,
params=params,
headers={**USER_AGENT_HEADER, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
else:
tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
async with self._session.post(
url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
if 'access_token' not in jsn or jsn.get('token_type') != 'Bearer' or 'expires_on' not in jsn:
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

token = jsn['access_token']
self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
return token
except requests_exceptions.RequestException as e:
if not _retryable_error(e):
raise AirflowException(
f'Response: {e.response.content}, Status Code: {e.response.status_code}'
)

self._log_request_error(attempt_num, e)

if attempt_num == self.retry_limit:
raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')

attempt_num += 1
await asyncio.sleep(self.retry_delay)

async def _get_aad_headers(self) -> dict:
"""
Fills AAD headers if necessary (SPN is outside of the workspace)
:return: dictionary with filled AAD headers
"""
headers = {}
if 'azure_resource_id' in self.databricks_conn.extra_dejson:
mgmt_token = await self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
'azure_resource_id'
]
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers

async def _check_azure_metadata_service(self):
# check for Azure Metadata Service
# https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
try:
async with self._session.get(
url=AZURE_METADATA_SERVICE_TOKEN_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
) as resp:
jsn = await resp.json()
if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

async def _do_api_call(self, endpoint_info: Tuple[str, str], json: dict) -> dict:
"""
Utility function to perform an async API call with retries
Expand All @@ -604,23 +703,41 @@ async def _do_api_call(self, endpoint_info: Tuple[str, str], json: dict) -> dict
"""
method, endpoint = endpoint_info

self.databricks_conn = self.get_connection(self.databricks_conn_id)
if self.databricks_conn is None:
loop = asyncio.get_event_loop()
pfunc = partial(self.get_connection, self.databricks_conn_id)
self.databricks_conn = await loop.run_in_executor(None, pfunc)

auth = None
headers = {}
if 'token' in self.databricks_conn.extra_dejson:
self.log.info('Using token auth.')
headers["Authorization"] = f'Bearer {self.databricks_conn.extra_dejson["token"]}'
if 'host' in self.databricks_conn.extra_dejson:
host = self._parse_host(self.databricks_conn.extra_dejson['host'])
self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
else:
host = self.databricks_conn.host
self.host = self._parse_host(self.databricks_conn.host)

url = f'https://{self.host}/{endpoint}'

aad_headers = await self._get_aad_headers()
headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
'Using token auth. For security reasons, please set token in Password field instead of extra'
)
auth = BearerAuth(self.databricks_conn.extra_dejson["token"])
elif not self.databricks_conn.login and self.databricks_conn.password:
self.log.info('Using token auth.')
auth = BearerAuth(self.databricks_conn.password)
elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")
self.log.info('Using AAD Token for SPN.')
auth = BearerAuth(await self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
self.log.info('Using AAD Token for managed identity.')
await self._check_azure_metadata_service()
auth = BearerAuth(await self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
else:
self.log.info('Using basic auth.')
auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
host = self.databricks_conn.host

url = f'https://{self._parse_host(host)}/{endpoint}'

if method == 'GET':
request_func = self._session.get
Expand Down Expand Up @@ -684,10 +801,7 @@ async def get_run_state(self, run_id: int) -> RunState:
json = {'run_id': run_id}
response = await self._do_api_call(GET_RUN_ENDPOINT, json)
state = response['state']
life_cycle_state = state['life_cycle_state']
result_state = state.get('result_state', None)
state_message = state['state_message']
return RunState(life_cycle_state, result_state, state_message)
return RunState(**state)

async def run_now(self, json: dict) -> int:
raise NotImplementedError('Please use run_now() in regular DatabricksHook class')
Expand Down Expand Up @@ -727,3 +841,16 @@ async def install(self, json: dict) -> None:

async def uninstall(self, json: dict) -> None:
raise NotImplementedError('Please use uninstall() in regular DatabricksHook class')


class BearerAuth(aiohttp.BasicAuth):
"""aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth."""

def __new__(cls, token: str) -> 'BearerAuth':
return super().__new__(cls, token) # type: ignore

def __init__(self, token: str) -> None:
self.token = token

def encode(self) -> str:
return f'Bearer {self.token}'
Loading

0 comments on commit 44ef54b

Please sign in to comment.