Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/providers/snowflake/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@
Changelog
---------

5.0.0
.....

Breaking changes
~~~~~~~~~~~~~~~~

* This release of the provider is only available for Airflow 2.3+, as explained in the Apache Airflow
providers support policy: https://github.com/apache/airflow/blob/main/README.md#support-for-providers
* In ``SnowflakeHook``, if both ``extra__snowflake__foo`` and ``foo`` exist in the connection's extra
dict, the prefixed version used to take precedence; now the non-prefixed version is preferred.

3.3.0
.....

Expand Down
123 changes: 79 additions & 44 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os
from contextlib import closing
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping
Expand All @@ -41,6 +42,34 @@ def _try_to_boolean(value: Any):
return value


def _ensure_prefixes(conn_type):
    """
    Build a decorator that rewrites the ``placeholders`` keys returned by the
    wrapped zero-argument function so that every custom field carries the
    ``extra__<conn_type>__`` prefix.

    Keys naming a standard connection attribute, and keys that already start
    with ``extra__``, are left unchanged. A result without a ``placeholders``
    entry is returned as-is.

    Remove when provider min airflow version >= 2.5.0 since this is handled by
    provider manager from that version.
    """
    standard_attrs = {'host', 'schema', 'login', 'password', 'port', 'extra'}

    def _prefixed(name):
        # Standard attributes and already-prefixed names pass through untouched.
        if name in standard_attrs or name.startswith('extra__'):
            return name
        return f"extra__{conn_type}__{name}"

    def decorator(func):
        @wraps(func)
        def wrapper():
            behaviours = func()
            if 'placeholders' not in behaviours:
                return behaviours
            # The dict produced by ``func`` is updated in place and handed back.
            behaviours['placeholders'] = {
                _prefixed(key): text for key, text in behaviours['placeholders'].items()
            }
            return behaviours

        return wrapper

    return decorator


class SnowflakeHook(DbApiHook):
"""
A client to interact with Snowflake.
Expand Down Expand Up @@ -92,25 +121,22 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import BooleanField, StringField

return {
"extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()),
"extra__snowflake__warehouse": StringField(
lazy_gettext('Warehouse'), widget=BS3TextFieldWidget()
),
"extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()),
"extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()),
"extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()),
"extra__snowflake__private_key_file": StringField(
lazy_gettext('Private key (Path)'), widget=BS3TextFieldWidget()
),
"extra__snowflake__private_key_content": StringField(
"account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()),
"warehouse": StringField(lazy_gettext('Warehouse'), widget=BS3TextFieldWidget()),
"database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()),
"region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()),
"role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()),
"private_key_file": StringField(lazy_gettext('Private key (Path)'), widget=BS3TextFieldWidget()),
"private_key_content": StringField(
lazy_gettext('Private key (Text)'), widget=BS3TextAreaFieldWidget()
),
"extra__snowflake__insecure_mode": BooleanField(
"insecure_mode": BooleanField(
label=lazy_gettext('Insecure mode'), description="Turns off OCSP certificate checks"
),
}

@staticmethod
@_ensure_prefixes(conn_type='snowflake')
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
import json
Expand All @@ -130,14 +156,14 @@ def get_ui_field_behaviour() -> dict[str, Any]:
'schema': 'snowflake schema',
'login': 'snowflake username',
'password': 'snowflake password',
'extra__snowflake__account': 'snowflake account name',
'extra__snowflake__warehouse': 'snowflake warehouse name',
'extra__snowflake__database': 'snowflake db name',
'extra__snowflake__region': 'snowflake hosted region',
'extra__snowflake__role': 'snowflake role',
'extra__snowflake__private_key_file': 'Path of snowflake private key (PEM Format)',
'extra__snowflake__private_key_content': 'Content to snowflake private key (PEM format)',
'extra__snowflake__insecure_mode': 'insecure mode',
'account': 'snowflake account name',
'warehouse': 'snowflake warehouse name',
'database': 'snowflake db name',
'region': 'snowflake hosted region',
'role': 'snowflake role',
'private_key_file': 'Path of snowflake private key (PEM Format)',
'private_key_content': 'Content to snowflake private key (PEM format)',
'insecure_mode': 'insecure mode',
},
}

Expand All @@ -153,31 +179,44 @@ def __init__(self, *args, **kwargs) -> None:
self.session_parameters = kwargs.pop("session_parameters", None)
self.query_ids: list[str] = []

def _get_field(self, extra_dict, field_name):
    """
    Look up ``field_name`` in ``extra_dict``, preferring the short name and
    falling back to the legacy ``extra__snowflake__``-prefixed key.

    Falsy values (e.g. an empty string) are normalised to ``None``. When both
    the short and the prefixed key are present, the short one is used and a
    warning is emitted.

    :param extra_dict: connection extras to search
    :param field_name: un-prefixed field name
    :raises ValueError: if ``field_name`` starts with ``extra__``
    """
    legacy_prefix = 'extra__snowflake__'
    if field_name.startswith('extra__'):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the '{legacy_prefix}' prefix "
            f"when using this method."
        )

    legacy_key = f"{legacy_prefix}{field_name}"
    if field_name not in extra_dict:
        # Only the back-compat spelling (if any) is available.
        return extra_dict.get(legacy_key) or None

    if legacy_key in extra_dict:
        import warnings

        warnings.warn(
            f"Conflicting params `{field_name}` and `{legacy_key}` found in extras. "
            f"Using value for `{field_name}`. Please ensure this is the correct "
            f"value and remove the backcompat key `{legacy_key}`."
        )
    return extra_dict[field_name] or None

def _get_conn_params(self) -> dict[str, str | None]:
"""
One method to fetch connection params as a dict
used in get_uri() and get_connection()
"""
conn = self.get_connection(self.snowflake_conn_id) # type: ignore[attr-defined]
account = conn.extra_dejson.get('extra__snowflake__account', '') or conn.extra_dejson.get(
'account', ''
)
warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '') or conn.extra_dejson.get(
'warehouse', ''
)
database = conn.extra_dejson.get('extra__snowflake__database', '') or conn.extra_dejson.get(
'database', ''
)
region = conn.extra_dejson.get('extra__snowflake__region', '') or conn.extra_dejson.get('region', '')
role = conn.extra_dejson.get('extra__snowflake__role', '') or conn.extra_dejson.get('role', '')
extra_dict = conn.extra_dejson
account = self._get_field(extra_dict, 'account') or ''
warehouse = self._get_field(extra_dict, 'warehouse') or ''
database = self._get_field(extra_dict, 'database') or ''
region = self._get_field(extra_dict, 'region') or ''
role = self._get_field(extra_dict, 'role') or ''
insecure_mode = _try_to_boolean(self._get_field(extra_dict, 'insecure_mode'))
schema = conn.schema or ''
authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
session_parameters = conn.extra_dejson.get('session_parameters')
insecure_mode = _try_to_boolean(
conn.extra_dejson.get(
'extra__snowflake__insecure_mode', conn.extra_dejson.get('insecure_mode', None)
)
)

# authenticator and session_parameters never supported long name so we don't use _get_field
authenticator = extra_dict.get('authenticator', 'snowflake')
session_parameters = extra_dict.get('session_parameters')

conn_config = {
"user": conn.login,
Expand All @@ -202,12 +241,8 @@ def _get_conn_params(self) -> dict[str, str | None]:
# The connection password then becomes the passphrase for the private key.
# If your private key is not encrypted (not recommended), then leave the password empty.

private_key_file = conn.extra_dejson.get(
'extra__snowflake__private_key_file'
) or conn.extra_dejson.get('private_key_file')
private_key_content = conn.extra_dejson.get(
'extra__snowflake__private_key_content'
) or conn.extra_dejson.get('private_key_content')
private_key_file = self._get_field(extra_dict, 'private_key_file')
private_key_content = self._get_field(extra_dict, 'private_key_content')

private_key_pem = None
if private_key_content and private_key_file:
Expand Down
98 changes: 98 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import json
import unittest
from copy import deepcopy
from pathlib import Path
Expand All @@ -30,6 +31,7 @@

from airflow.models import Connection
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from tests.test_utils.providers import get_provider_min_airflow_version, object_exists

_PASSWORD = 'snowflake42'

Expand Down Expand Up @@ -292,6 +294,66 @@ def test_get_conn_params_should_support_private_auth_in_connection(
):
assert 'private_key' in SnowflakeHook(snowflake_conn_id='test_conn')._get_conn_params()

@pytest.mark.parametrize('include_params', [True, False])
def test_hook_param_beats_extra(self, include_params):
    """Values passed to the hook constructor take precedence over connection extras.

    With ``include_params`` false, no hook params are given and the values
    from the connection extras are expected instead.
    """
    hook_params = {
        'account': 'account',
        'warehouse': 'warehouse',
        'database': 'database',
        'region': 'region',
        'role': 'role',
        'authenticator': 'authenticator',
        'session_parameters': 'session_parameters',
    }
    extras = {name: f"{value}_extra" for name, value in hook_params.items()}

    if include_params:
        hook_kwargs, winning_values = hook_params, hook_params
    else:
        hook_kwargs, winning_values = {}, extras

    expected = {
        'user': None,
        'password': '',
        'application': 'AIRFLOW',
        'schema': '',
        **winning_values,
    }
    conn_uri = Connection(conn_type='any', extra=json.dumps(extras)).get_uri()

    with unittest.mock.patch.dict('os.environ', AIRFLOW_CONN_TEST_CONN=conn_uri):
        # Sanity check: the two sources must actually disagree.
        assert hook_params != extras
        hook = SnowflakeHook(snowflake_conn_id='test_conn', **hook_kwargs)
        assert hook._get_conn_params() == expected

@pytest.mark.parametrize('include_unprefixed', [True, False])
def test_extra_short_beats_long(self, include_unprefixed):
    """The short extra name wins over its ``extra__snowflake__`` counterpart.

    When only the prefixed keys are present (``include_unprefixed`` false),
    their values are expected to be used.
    """
    extras = {
        'account': 'account',
        'warehouse': 'warehouse',
        'database': 'database',
        'region': 'region',
        'role': 'role',
    }
    extras_prefixed = {f"extra__snowflake__{name}": f"{value}_prefixed" for name, value in extras.items()}

    conn_extra = dict(extras) if include_unprefixed else {}
    conn_extra.update(extras_prefixed)

    if include_unprefixed:
        expected_fields = extras
    else:
        # Short names mapped onto the values stored under the prefixed keys.
        expected_fields = dict(zip(extras.keys(), extras_prefixed.values()))

    expected = {
        'user': None,
        'password': '',
        'application': 'AIRFLOW',
        'schema': '',
        'authenticator': 'snowflake',
        'session_parameters': None,
        **expected_fields,
    }
    conn_uri = Connection(conn_type='any', extra=json.dumps(conn_extra)).get_uri()

    with unittest.mock.patch.dict('os.environ', AIRFLOW_CONN_TEST_CONN=conn_uri):
        # Sanity check: prefixed and unprefixed values must differ.
        assert list(extras.values()) != list(extras_prefixed.values())
        assert SnowflakeHook(snowflake_conn_id='test_conn')._get_conn_params() == expected

def test_get_conn_params_should_support_private_auth_with_encrypted_key(
self, encrypted_temporary_private_key
):
Expand Down Expand Up @@ -524,3 +586,39 @@ def test_empty_sql_parameter(self):
with pytest.raises(ValueError) as err:
hook.run(sql=empty_statement)
assert err.value.args[0] == "List of SQL statements is empty"

def test__ensure_prefixes_removal(self):
    """Ensure that _ensure_prefixes is removed from snowflake when airflow min version >= 2.5.0.

    This is a reminder test: it fails once the decorator has been deleted
    (so this test gets deleted too), and it fails once the provider's minimum
    Airflow version reaches 2.5 (so the decorator gets deleted).
    """
    path = 'airflow.providers.snowflake.hooks.snowflake._ensure_prefixes'
    if not object_exists(path):
        raise Exception(
            "You must remove this test. It only exists to "
            "remind us to remove decorator `_ensure_prefixes`."
        )

    if get_provider_min_airflow_version('apache-airflow-providers-snowflake') >= (2, 5):
        # The first fragment ends with a space so the joined message reads
        # "taken care of" rather than "takencare of".
        raise Exception(
            "You must now remove `_ensure_prefixes` from SnowflakeHook. The functionality is now taken "
            "care of by providers manager."
        )

def test___ensure_prefixes(self):
    """
    Check that the ensure_prefixes decorator is working properly.

    Note: remove this test when removing ensure_prefixes (after min airflow version >= 2.5.0).
    """
    standard_fields = ['extra', 'schema', 'login', 'password']
    custom_fields = [
        'account',
        'warehouse',
        'database',
        'region',
        'role',
        'private_key_file',
        'private_key_content',
        'insecure_mode',
    ]
    # Standard fields stay as they are; custom ones must come back prefixed, in order.
    expected = standard_fields + [f"extra__snowflake__{name}" for name in custom_fields]

    placeholders = SnowflakeHook.get_ui_field_behaviour()['placeholders']
    assert list(placeholders) == expected
10 changes: 10 additions & 0 deletions tests/test_utils/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,13 @@ def get_provider_version(provider_name):

info = ProvidersManager().providers[provider_name]
return semver.VersionInfo.parse(info.version)


def _min_airflow_version_from_deps(deps):
    """Return the ``apache-airflow`` lower bound found in ``deps`` as a tuple of ints.

    :param deps: iterable of requirement strings such as ``'apache-airflow>=2.3.0'``
    :raises ValueError: if no ``apache-airflow>=X.Y[.Z]`` requirement is present
    """
    import re

    # Anchor on the bare package name followed by ``>=`` so that sibling
    # requirements like ``apache-airflow-providers-common-sql>=1.3.1`` are not
    # mistaken for the core Airflow requirement. Only the leading dotted-number
    # part is captured, so trailing specifiers (``,<3``) are ignored.
    pattern = re.compile(r'^\s*apache-airflow\s*>=\s*(\d+(?:\.\d+)*)')
    for dep in deps:
        found = pattern.match(dep)
        if found:
            return tuple(int(part) for part in found.group(1).split('.'))
    raise ValueError(f"No 'apache-airflow>=<version>' requirement found in {list(deps)!r}")


def get_provider_min_airflow_version(provider_name):
    """Return the minimum Airflow version required by a provider, e.g. ``(2, 3, 0)``.

    The value is read from the ``dependencies`` entry of the provider's data
    as exposed by ``ProvidersManager`` (presumably a list of requirement
    strings; confirm against the provider metadata schema).

    :param provider_name: provider package name, e.g. ``'apache-airflow-providers-snowflake'``
    :raises ValueError: if the provider declares no ``apache-airflow>=...`` requirement
    """
    from airflow.providers_manager import ProvidersManager

    deps = ProvidersManager().providers[provider_name].data['dependencies']
    return _min_airflow_version_from_deps(deps)