Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow/providers/yandex/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
Changelog
---------

4.0.0
.....

Breaking changes
~~~~~~~~~~~~~~~~

* This release of provider is only available for Airflow 2.3+ as explained in the Apache Airflow
providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers

Misc
~~~~

* In YandexCloudBaseHook, non-prefixed extra fields are supported and are preferred. E.g. ``folder_id`` will
be preferred if ``extra__yandexcloud__folder_id`` is also present.

3.1.0
.....

Expand Down
29 changes: 19 additions & 10 deletions airflow/providers/yandex/hooks/yandex.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,33 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import PasswordField, StringField

return {
"extra__yandexcloud__service_account_json": PasswordField(
"service_account_json": PasswordField(
lazy_gettext('Service account auth JSON'),
widget=BS3PasswordFieldWidget(),
description='Service account auth JSON. Looks like '
'{"id", "...", "service_account_id": "...", "private_key": "..."}. '
'Will be used instead of OAuth token and SA JSON file path field if specified.',
),
"extra__yandexcloud__service_account_json_path": StringField(
"service_account_json_path": StringField(
lazy_gettext('Service account auth JSON file path'),
widget=BS3TextFieldWidget(),
description='Service account auth JSON file path. File content looks like '
'{"id", "...", "service_account_id": "...", "private_key": "..."}. '
'Will be used instead of OAuth token if specified.',
),
"extra__yandexcloud__oauth": PasswordField(
"oauth": PasswordField(
lazy_gettext('OAuth Token'),
widget=BS3PasswordFieldWidget(),
description='User account OAuth token. '
'Either this or service account JSON must be specified.',
),
"extra__yandexcloud__folder_id": StringField(
"folder_id": StringField(
lazy_gettext('Default folder ID'),
widget=BS3TextFieldWidget(),
description='Optional. This folder will be used '
'to create all new clusters and nodes by default',
),
"extra__yandexcloud__public_ssh_key": StringField(
"public_ssh_key": StringField(
lazy_gettext('Public SSH key'),
widget=BS3TextFieldWidget(),
description='Optional. This key will be placed to all created Compute nodes'
Expand Down Expand Up @@ -146,9 +146,18 @@ def _get_credentials(self) -> dict[str, Any]:
return {'token': oauth_token}

def _get_field(self, field_name: str, default: Any = None) -> Any:
"""Fetches a field from extras, and returns it."""
long_f = f'extra__yandexcloud__{field_name}'
if hasattr(self, 'extras') and long_f in self.extras:
return self.extras[long_f]
else:
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
if not hasattr(self, 'extras'):
return default
backcompat_prefix = 'extra__yandexcloud__'
if field_name.startswith('extra__'):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
"when using this method."
)
if field_name in self.extras:
return self.extras[field_name]
prefixed_name = f"{backcompat_prefix}{field_name}"
if prefixed_name in self.extras:
return self.extras[prefixed_name]
return default
24 changes: 21 additions & 3 deletions tests/providers/yandex/hooks/test_yandex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
# under the License.
from __future__ import annotations

import unittest
import os
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from pytest import param

from airflow.exceptions import AirflowException
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook


class TestYandexHook(unittest.TestCase):
class TestYandexHook:
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
def test_client_created_without_exceptions(self, get_credentials_mock, get_connection_mock):
Expand Down Expand Up @@ -81,7 +83,7 @@ def test_get_field(self, get_credentials_mock, get_connection_mock):
default_folder_id = 'test_id'
default_public_ssh_key = 'test_key'

extra_dejson = {"extra__yandexcloud__one": "value_one"}
extra_dejson = {"one": "value_one"}
get_connection_mock['extra_dejson'] = "sdsd"
get_connection_mock.extra_dejson = '{"extras": "extra"}'
get_connection_mock.return_value = mock.Mock(
Expand All @@ -96,3 +98,19 @@ def test_get_field(self, get_credentials_mock, get_connection_mock):
)

assert hook._get_field('one') == 'value_one'

@pytest.mark.parametrize(
'uri',
[
param(
'a://?extra__yandexcloud__folder_id=abc&extra__yandexcloud__public_ssh_key=abc', id='prefix'
),
param('a://?folder_id=abc&public_ssh_key=abc', id='no-prefix'),
],
)
@patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials', new=MagicMock())
def test_backcompat_prefix_works(self, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = YandexCloudBaseHook('my_conn')
assert hook.default_folder_id == 'abc'
assert hook.default_public_ssh_key == 'abc'
2 changes: 1 addition & 1 deletion tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _init_hook(self):
self.hook = DataprocHook()

def setUp(self):
self.connection = Connection(extra=json.dumps({'extra__yandexcloud__oauth': OAUTH_TOKEN}))
self.connection = Connection(extra=json.dumps({'oauth': OAUTH_TOKEN}))
self._init_hook()

@patch('yandexcloud.SDK.create_operation_and_get_result')
Expand Down