Skip to content

Commit

Permalink
Allow sending multiple roles to UNLOAD and COPY commands (sqlalchemy-…
Browse files Browse the repository at this point in the history
…redshift#212)

* Allow sending multiple roles to UNLOAD and COPY commands

* Code review fixes and linting
  • Loading branch information
alexcarruthers authored Jan 7, 2021
1 parent 40bddca commit 4f59b4b
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 16 deletions.
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
0.8.2 (unreleased)
------------------

- Nothing changed yet.
- Allow supplying multiple role ARNs in COPY and UNLOAD commands. This allows
the first role to assume other roles as explained
`here <https://docs.aws.amazon.com/redshift/latest/mgmt/authorizing-redshift-service.html#authorizing-redshift-service-chaining-roles>`_.


0.8.1 (2020-07-15)
Expand Down
81 changes: 66 additions & 15 deletions sqlalchemy_redshift/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,20 @@
TOKEN_RE = re.compile('[A-Za-z0-9/+=]+')
AWS_PARTITIONS = frozenset({'aws', 'aws-cn', 'aws-us-gov'})
AWS_ACCOUNT_ID_RE = re.compile('[0-9]{12}')
IAM_ROLE_NAME_RE = re.compile('[A-Za-z0-9+=,.@-_]{1,64}')
IAM_ROLE_NAME_RE = re.compile('[A-Za-z0-9+=,.@\-_]{1,64}')
IAM_ROLE_ARN_RE = re.compile('arn:(aws|aws-cn|aws-us-gov):iam::'
'[0-9]{12}:role/[A-Za-z0-9+=,.@\-_]{1,64}')


def _process_aws_credentials(access_key_id=None, secret_access_key=None,
session_token=None, aws_partition='aws',
aws_account_id=None, iam_role_name=None):
aws_account_id=None, iam_role_name=None,
iam_role_arns=None):
uses_iam_role = aws_account_id is not None and iam_role_name is not None
uses_iam_roles = iam_role_arns is not None
uses_key = access_key_id is not None and secret_access_key is not None

if (access_key_id is not None and secret_access_key is not None and
aws_account_id is not None and iam_role_name is not None):
if uses_iam_role + uses_iam_roles + uses_key > 1:
raise TypeError(
'Either access key based credentials or role based credentials '
'should be specified, but not both'
Expand Down Expand Up @@ -66,6 +71,21 @@ def _process_aws_credentials(access_key_id=None, secret_access_key=None,
iam_role_name,
)

if iam_role_arns is not None:
if isinstance(iam_role_arns, str):
iam_role_arns = [iam_role_arns]
if not isinstance(iam_role_arns, list):
raise ValueError('iam_role_arns must be a list')
for arn in iam_role_arns:
if not IAM_ROLE_ARN_RE.match(arn):
raise ValueError(
'invalid AWS account ID; does not match {pattern}'.format(
pattern=IAM_ROLE_ARN_RE.pattern,
)
)

credentials = 'aws_iam_role=' + ','.join(iam_role_arns)

if access_key_id is not None and secret_access_key is not None:
if not ACCESS_KEY_ID_RE.match(access_key_id):
raise ValueError(
Expand Down Expand Up @@ -185,21 +205,29 @@ class UnloadFromSelect(_ExecutableClause):
file if the `manifest` option is used
access_key_id: str, optional
Access Key. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
secret_access_key: str, optional
Secret Access Key ID. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
session_token : str, optional
iam_role_arns : str or list of strings, optional
Either a single arn or a list of arns of roles to assume when unloading
Required unless you supply key based credentials (``access_key_id`` and
``secret_access_key``) or (``aws_account_id`` and ``iam_role_name``)
separately.
aws_partition: str, optional
AWS partition to use with role-based credentials. Defaults to
``'aws'``. Not applicable when using key based credentials
(``access_key_id`` and ``secret_access_key``).
(``access_key_id`` and ``secret_access_key``) or role arns
(``iam_role_arns``) directly.
aws_account_id: str, optional
AWS account ID for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
iam_role_name: str, optional
IAM role name for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
manifest: bool, optional
Boolean value denoting whether data_location is a manifest file.
delimiter: File delimiter, optional
Expand Down Expand Up @@ -248,7 +276,7 @@ def __init__(self, select, unload_location, access_key_id=None,
encrypted=False, gzip=False, add_quotes=False, null=None,
escape=False, allow_overwrite=False, parallel=True,
header=False, region=None, max_file_size=None,
format=None):
format=None, iam_role_arns=None):

if delimiter is not None and len(delimiter) != 1:
raise ValueError(
Expand All @@ -267,6 +295,7 @@ def __init__(self, select, unload_location, access_key_id=None,
aws_partition=aws_partition,
aws_account_id=aws_account_id,
iam_role_name=iam_role_name,
iam_role_arns=iam_role_arns,
)

self.select = select
Expand Down Expand Up @@ -453,21 +482,29 @@ class CopyCommand(_ExecutableClause):
the `manifest` option is used
access_key_id: str, optional
Access Key. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
secret_access_key: str, optional
Secret Access Key ID. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
session_token : str, optional
iam_role_arns : str or list of strings, optional
Either a single arn or a list of arns of roles to assume when unloading
Required unless you supply key based credentials (``access_key_id`` and
``secret_access_key``) or (``aws_account_id`` and ``iam_role_name``)
separately.
aws_partition: str, optional
AWS partition to use with role-based credentials. Defaults to
``'aws'``. Not applicable when using key based credentials
(``access_key_id`` and ``secret_access_key``).
(``access_key_id`` and ``secret_access_key``) or role arns
(``iam_role_arns``) directly.
aws_account_id: str, optional
AWS account ID for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
iam_role_name: str, optional
IAM role name for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
format : Format, optional
Indicates the type of file to copy from
quote : str, optional
Expand Down Expand Up @@ -587,7 +624,7 @@ def __init__(self, to, data_location, access_key_id=None,
roundec=False, time_format=None, trim_blanks=False,
truncate_columns=False, comp_rows=None, comp_update=None,
max_error=None, no_load=False, stat_update=None,
manifest=False, region=None):
manifest=False, region=None, iam_role_arns=None):

credentials = _process_aws_credentials(
access_key_id=access_key_id,
Expand All @@ -596,6 +633,7 @@ def __init__(self, to, data_location, access_key_id=None,
aws_partition=aws_partition,
aws_account_id=aws_account_id,
iam_role_name=iam_role_name,
iam_role_arns=iam_role_arns,
)

if delimiter is not None and len(delimiter) != 1:
Expand Down Expand Up @@ -873,17 +911,29 @@ class CreateLibraryCommand(_ExecutableClause):
S3 location.
access_key_id: str, optional
Access Key. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
secret_access_key: str, optional
Secret Access Key ID. Required unless you supply role-based credentials
(``aws_account_id`` and ``iam_role_name``)
(``aws_account_id`` and ``iam_role_name`` or ``iam_role_arns``)
session_token : str, optional
iam_role_arns : str or list of strings, optional
Either a single arn or a list of arns of roles to assume when unloading
Required unless you supply key based credentials (``access_key_id`` and
``secret_access_key``) or (``aws_account_id`` and ``iam_role_name``)
separately.
aws_partition: str, optional
AWS partition to use with role-based credentials. Defaults to
``'aws'``. Not applicable when using key based credentials
(``access_key_id`` and ``secret_access_key``) or role arns
(``iam_role_arns``) directly.
aws_account_id: str, optional
AWS account ID for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
iam_role_name: str, optional
IAM role name for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
or role arns (``iam_role_arns``) directly.
replace: bool, optional, default False
Controls the presence of ``OR REPLACE`` in the compiled statement. See
the command documentation for details.
Expand All @@ -894,7 +944,7 @@ class CreateLibraryCommand(_ExecutableClause):
def __init__(self, library_name, location, access_key_id=None,
secret_access_key=None, session_token=None,
aws_account_id=None, iam_role_name=None, replace=False,
region=None):
region=None, iam_role_arns=None):
self.library_name = library_name
self.location = location
self.credentials = _process_aws_credentials(
Expand All @@ -903,6 +953,7 @@ def __init__(self, library_name, location, access_key_id=None,
session_token=session_token,
aws_account_id=aws_account_id,
iam_role_name=iam_role_name,
iam_role_arns=iam_role_arns,
)
self.replace = replace
self.region = region
Expand Down
42 changes: 42 additions & 0 deletions tests/test_copy_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,48 @@ def test_iam_role_partition_validation():
)


def test_iam_role_arns_list():
"""Tests the use of multiple iam role arns instead of access keys."""

iam_role_arns = [
'arn:aws:iam::000123456789:role/redshiftrole',
'arn:aws:iam::000123456789:role/redshiftrole2',
]
creds = 'aws_iam_role=arn:aws:iam::000123456789:role/redshiftrole,' \
'arn:aws:iam::000123456789:role/redshiftrole2'

expected_result = """
COPY schema1.t1 FROM 's3://mybucket/data/listing/'
WITH CREDENTIALS AS '{creds}'
""".format(creds=creds)

copy = dialect.CopyCommand(
tbl,
data_location='s3://mybucket/data/listing/',
iam_role_arns=iam_role_arns,
)
assert clean(expected_result) == clean(compile_query(copy))


def test_iam_role_arns_single():
"""Tests the use of a single iam role arn instead of access keys."""

iam_role_arns = 'arn:aws:iam::000123456789:role/redshiftrole'
creds = 'aws_iam_role=arn:aws:iam::000123456789:role/redshiftrole'

expected_result = """
COPY schema1.t1 FROM 's3://mybucket/data/listing/'
WITH CREDENTIALS AS '{creds}'
""".format(creds=creds)

copy = dialect.CopyCommand(
tbl,
data_location='s3://mybucket/data/listing/',
iam_role_arns=iam_role_arns,
)
assert clean(expected_result) == clean(compile_query(copy))


def test_format():
expected_result = """
COPY t1 FROM 's3://mybucket/data/listing/'
Expand Down
46 changes: 46 additions & 0 deletions tests/test_unload_from_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,52 @@ def test_iam_role_partition_validation():
)


def test_iam_role_arns_list():
"""Tests the use of multiple iam role arns instead of access keys."""

iam_role_arns = [
'arn:aws:iam::000123456789:role/redshiftrole',
'arn:aws:iam::000123456789:role/redshiftrole2',
]
creds = 'aws_iam_role=arn:aws:iam::000123456789:role/redshiftrole,' \
'arn:aws:iam::000123456789:role/redshiftrole2'

unload = dialect.UnloadFromSelect(
select=sa.select([sa.func.count(table.c.id)]),
unload_location='s3://bucket/key',
iam_role_arns=iam_role_arns,
)

expected_result = """
UNLOAD ('SELECT count(t1.id) AS count_1 FROM t1')
TO 's3://bucket/key'
CREDENTIALS '{creds}'
""".format(creds=creds)

assert clean(compile_query(unload)) == clean(expected_result)


def test_iam_role_arns_single():
"""Tests the use of a single iam role arn instead of access keys."""

iam_role_arns = 'arn:aws:iam::000123456789:role/redshiftrole'
creds = 'aws_iam_role=arn:aws:iam::000123456789:role/redshiftrole'

unload = dialect.UnloadFromSelect(
select=sa.select([sa.func.count(table.c.id)]),
unload_location='s3://bucket/key',
iam_role_arns=iam_role_arns,
)

expected_result = """
UNLOAD ('SELECT count(t1.id) AS count_1 FROM t1')
TO 's3://bucket/key'
CREDENTIALS '{creds}'
""".format(creds=creds)

assert clean(compile_query(unload)) == clean(expected_result)


def test_all_redshift_options():
"""Tests that UnloadFromSelect handles all options correctly."""

Expand Down

0 comments on commit 4f59b4b

Please sign in to comment.