Skip to content

Commit

Permalink
Get boto3.session.Session by appropriate method
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored and potiuk committed Aug 6, 2022
1 parent 5a68213 commit a7160c2
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 62 deletions.
62 changes: 34 additions & 28 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,26 +329,26 @@ def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:

def _get_region_name(self) -> Optional[str]:
warnings.warn(
"`BaseSessionFactory._get_region_name` method will be deprecated in the future."
"Please use `BaseSessionFactory.region_name` property instead.",
"`BaseSessionFactory._get_region_name` method deprecated and will be removed "
"in a future releases. Please use `BaseSessionFactory.region_name` property instead.",
PendingDeprecationWarning,
stacklevel=2,
)
return self.region_name

def _read_role_arn_from_extra_config(self) -> Optional[str]:
warnings.warn(
"`BaseSessionFactory._read_role_arn_from_extra_config` method will be deprecated in the future."
"Please use `BaseSessionFactory.role_arn` property instead.",
"`BaseSessionFactory._read_role_arn_from_extra_config` method deprecated and will be removed "
"in a future releases. Please use `BaseSessionFactory.role_arn` property instead.",
PendingDeprecationWarning,
stacklevel=2,
)
return self.role_arn

def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
warnings.warn(
"`BaseSessionFactory._read_credentials_from_connection` method will be deprecated in the future."
"Please use `BaseSessionFactory.conn.aws_access_key_id` and "
"`BaseSessionFactory._read_credentials_from_connection` method deprecated and will be removed "
"in a future releases. Please use `BaseSessionFactory.conn.aws_access_key_id` and "
"`BaseSessionFactory.aws_secret_access_key` properties instead.",
PendingDeprecationWarning,
stacklevel=2,
Expand Down Expand Up @@ -430,24 +430,19 @@ def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn_config.botocore_config

def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)

session = SessionFactory(
def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session.Session(region_name=region_name)."""
return SessionFactory(
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()

return session, self.conn_config.endpoint_url

def get_client_type(
self,
client_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name=region_name)

if client_type:
warnings.warn(
"client_type is deprecated. Set client_type from class attribute.",
Expand All @@ -462,7 +457,10 @@ def get_client_type(
if config is None:
config = self.config

return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
session = self.get_session(region_name=region_name)
return session.client(
client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
)

def get_resource_type(
self,
Expand All @@ -471,8 +469,6 @@ def get_resource_type(
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name=region_name)

if resource_type:
warnings.warn(
"resource_type is deprecated. Set resource_type from class attribute.",
Expand All @@ -487,10 +483,13 @@ def get_resource_type(
if config is None:
config = self.config

return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
session = self.get_session(region_name=region_name)
return session.resource(
resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
)

@cached_property
def conn(self) -> Union[boto3.client, boto3.resource]:
def conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
Expand Down Expand Up @@ -538,22 +537,16 @@ def get_conn(self) -> BaseAwsConnection:
# Compat shim
return self.conn

def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session."""
session, _ = self._get_credentials(region_name=region_name)
return session

def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
This contains the following authentication attributes: access_key, secret_key and token.
"""
session, _ = self._get_credentials(region_name=region_name)
# Credentials are refreshable, so accessing your access key and
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()
return self.get_session(region_name=region_name).get_credentials().get_frozen_credentials()

def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
"""
Expand All @@ -567,8 +560,10 @@ def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
if "/" in role:
return role
else:
session, endpoint_url = self._get_credentials(region_name=region_name)
_client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)
session = self.get_session(region_name=region_name)
_client = session.client(
'iam', endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
)
return _client.get_role(RoleName=role)["Role"]["Arn"]

@staticmethod
Expand Down Expand Up @@ -603,6 +598,17 @@ def decorator_f(self, *args, **kwargs):

return retry_decorator

def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
warnings.warn(
"`AwsGenericHook._get_credentials` method deprecated and will be removed in a future releases. "
"Please use `AwsGenericHook.get_session` method and "
"`AwsGenericHook.conn_config.endpoint_url` property instead.",
DeprecationWarning,
stacklevel=2,
)

return self.get_session(region_name=region_name), self.conn_config.endpoint_url

@staticmethod
def get_ui_field_behaviour() -> Dict[str, Any]:
"""Returns custom UI field behaviour for AWS Connection."""
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def list_jobs(self) -> List:

def get_iam_execution_role(self) -> Dict:
""":return: iam role for job execution"""
session, endpoint_url = self._get_credentials(region_name=self.region_name)
iam_client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)

try:
iam_client = self.get_session(region_name=self.region_name).client(
'iam', endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
)
glue_execution_role = iam_client.get_role(RoleName=self.role_name)
self.log.info("Iam Role Name: %s", self.role_name)
return glue_execution_role
Expand Down
14 changes: 4 additions & 10 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,9 @@ def get_bucket(self, bucket_name: Optional[str] = None) -> object:
:return: the bucket object to the bucket name.
:rtype: boto3.S3.Bucket
"""
# Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
# break compatibility, so we set it explicitly to None.
session, endpoint_url = self._get_credentials(region_name=None)
s3_resource = session.resource(
s3_resource = self.get_session().resource(
"s3",
endpoint_url=endpoint_url,
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
Expand Down Expand Up @@ -465,12 +462,9 @@ def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer:
:return: the key object from the bucket
:rtype: boto3.s3.Object
"""
# Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
# break compatibility, so we set it explicitly to None.
session, endpoint_url = self._get_credentials(region_name=None)
s3_resource = session.resource(
s3_resource = self.get_session().resource(
"s3",
endpoint_url=endpoint_url,
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
Expand Down
43 changes: 43 additions & 0 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,49 @@ def test_conn_config_conn_id_not_exists(self):
assert isinstance(conn_config_fallback_not_exists, AwsConnectionWrapper)
assert not conn_config_fallback_not_exists

@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.SessionFactory')
@pytest.mark.parametrize("hook_region_name", [None, "eu-west-1"])
@pytest.mark.parametrize(
"hook_botocore_config", [None, Config(s3={"us_east_1_regional_endpoint": "regional"})]
)
@pytest.mark.parametrize("method_region_name", [None, "cn-north-1"])
def test_get_session(
self, mock_session_factory, hook_region_name, hook_botocore_config, method_region_name
):
"""Test get boto3 Session by hook."""
mock_session_factory_instance = mock_session_factory.return_value
mock_session_factory_instance.create_session.return_value = MOCK_BOTO3_SESSION

hook = AwsBaseHook(aws_conn_id=None, region_name=hook_region_name, config=hook_botocore_config)
session = hook.get_session(region_name=method_region_name)
mock_session_factory.assert_called_once_with(
conn=hook.conn_config,
region_name=method_region_name,
config=hook_botocore_config,
)
assert mock_session_factory_instance.create_session.assert_called_once
assert session == MOCK_BOTO3_SESSION

@mock.patch(
'airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.get_session',
return_value=MOCK_BOTO3_SESSION,
)
@pytest.mark.parametrize("region_name", [None, "aws-global", "eu-west-1"])
def test_deprecate_private_method__get_credentials(self, mock_boto3_session, region_name):
"""Test deprecated method AwsGenericHook._get_credentials."""
hook = AwsBaseHook(aws_conn_id=None)
warning_message = (
r"`AwsGenericHook._get_credentials` method deprecated and will be removed in a future releases\. "
r"Please use `AwsGenericHook.get_session` method and "
r"`AwsGenericHook.conn_config.endpoint_url` property instead\."
)
with pytest.warns(DeprecationWarning, match=warning_message):
session, endpoint = hook._get_credentials(region_name)

mock_boto3_session.assert_called_once_with(region_name=region_name)
assert session == MOCK_BOTO3_SESSION
assert endpoint == hook.conn_config.endpoint_url


class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
Expand Down
44 changes: 23 additions & 21 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# specific language governing permissions and limitations
# under the License.
import json
import unittest
from unittest import mock

import boto3
import pytest

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

try:
Expand All @@ -27,19 +29,19 @@
mock_iam = mock_glue = None


class TestGlueJobHook(unittest.TestCase):
def setUp(self):
class TestGlueJobHook:
@pytest.fixture(autouse=True)
def setup(self):
self.some_aws_region = "us-west-2"

@unittest.skipIf(mock_iam is None, 'mock_iam package not present')
@pytest.mark.skipif(mock_glue is None, reason="mock_glue package not present")
@mock_iam
def test_get_iam_execution_role(self):
hook = GlueJobHook(
job_name='aws_test_glue_job', s3_bucket='some_bucket', iam_role_name='my_test_role'
)
iam_role = hook.get_client_type('iam').create_role(
Path="/",
RoleName='my_test_role',
@pytest.mark.parametrize("role_path", ["/", "/custom-path/"])
def test_get_iam_execution_role(self, role_path):
expected_role = "my_test_role"
boto3.client("iam").create_role(
Path=role_path,
RoleName=expected_role,
AssumeRolePolicyDocument=json.dumps(
{
"Version": "2012-10-17",
Expand All @@ -51,11 +53,18 @@ def test_get_iam_execution_role(self):
}
),
)

hook = GlueJobHook(
aws_conn_id=None,
job_name='aws_test_glue_job',
s3_bucket='some_bucket',
iam_role_name=expected_role,
)
iam_role = hook.get_iam_execution_role()
assert iam_role is not None
assert "Role" in iam_role
assert "Arn" in iam_role['Role']
assert iam_role['Role']['Arn'] == "arn:aws:iam::123456789012:role/my_test_role"
assert iam_role['Role']['Arn'] == f"arn:aws:iam::123456789012:role{role_path}{expected_role}"

@mock.patch.object(GlueJobHook, "get_conn")
def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
Expand Down Expand Up @@ -84,7 +93,7 @@ def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name)
assert result == expected_job_name

@unittest.skipIf(mock_glue is None, "mock_glue package not present")
@pytest.mark.skipif(mock_glue is None, reason="mock_glue package not present")
@mock_glue
@mock.patch.object(GlueJobHook, "get_iam_execution_role")
def test_get_or_create_glue_job_create_new_job(self, mock_get_iam_execution_role):
Expand Down Expand Up @@ -135,10 +144,7 @@ def test_init_worker_type_value_error(self, mock_get_conn, mock_get_iam_executio
some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
some_s3_bucket = "my-includes"

with self.assertRaises(
ValueError,
msg="ValueError should be raised for specifying the num_of_dpus and worker type together!",
):
with pytest.raises(ValueError, match="Cannot specify num_of_dpus with custom WorkerType"):
GlueJobHook(
job_name='aws_test_glue_job',
desc='This is test case job from Airflow',
Expand Down Expand Up @@ -175,7 +181,3 @@ def test_initialize_job(self, mock_get_conn, mock_get_or_create_glue_job, mock_g
glue_job_run = glue_job_hook.initialize_job(some_script_arguments, some_run_kwargs)
glue_job_run_state = glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId'])
assert glue_job_run_state == mock_job_run_state, 'Mocks but be equal'


if __name__ == '__main__':
unittest.main()

0 comments on commit a7160c2

Please sign in to comment.