Merged
Changes from all commits
27 commits
917c415
add node_overrides parameter to batch operator to support multinode j…
camilleanne Dec 2, 2022
57d8675
use trim_none_values to pass only truthy parameters to boto
camilleanne Dec 9, 2022
e9cf26e
add test
camilleanne Dec 9, 2022
5713c48
access logstreamname for multinode jobs; add batch_client test
camilleanne Dec 9, 2022
4c226fa
better conditionals on attempts array length
camilleanne Dec 9, 2022
d59729f
lint
camilleanne Dec 12, 2022
59e9a87
fix line length; extend test for multiple attempts
camilleanne Dec 12, 2022
5d16298
Merge branch 'main' of https://github.com/apache/airflow into ct/aws-…
camilleanne Dec 12, 2022
c2182e3
fix bad tab
camilleanne Dec 16, 2022
830bbe9
update logstream tests
camilleanne Dec 16, 2022
b6750cf
update tests for new expectations around arrayProperties
camilleanne Dec 16, 2022
6169610
Merge branch 'main' of https://github.com/apache/airflow into ct/aws-…
camilleanne Dec 16, 2022
79e488c
rename overrides param
vandonr-amz Dec 22, 2022
40453ca
Merge pull request #2 from aws-mwaa/vandonr/multinode
camilleanne Dec 23, 2022
179fc21
raise exception instead of a warning on unrecognized job type
camilleanne Dec 24, 2022
528bb99
Merge branch 'main' of https://github.com/apache/airflow into ct/aws-…
camilleanne Jan 26, 2023
8fada5c
Merge remote-tracking branch 'origin/main' into vandonr/batch
vandonr-amz Feb 14, 2023
8cb9529
add check against job_node_range_properties being empty
vandonr-amz Feb 14, 2023
318e91a
string formating suggestion
vandonr-amz Feb 14, 2023
4744db2
static check fixes
vandonr-amz Feb 15, 2023
87bc61a
Merge remote-tracking branch 'origin/main' into vandonr/batch
vandonr-amz Feb 16, 2023
e4444bb
take the opportunity to remove hook creation from operator ctor
vandonr-amz Feb 17, 2023
95f3c95
Merge remote-tracking branch 'origin/main' into vandonr/batch
vandonr-amz Feb 17, 2023
86cd265
Merge remote-tracking branch 'origin/main' into vandonr/batch
vandonr-amz Mar 13, 2023
871cd96
rework a bit deprecation warning
vandonr-amz Mar 14, 2023
bc458ce
Merge remote-tracking branch 'origin/main' into vandonr/batch
vandonr-amz Mar 27, 2023
8bc2f95
log links to all log streams
vandonr-amz Mar 28, 2023
87 changes: 60 additions & 27 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -414,43 +414,76 @@ def parse_job_description(job_id: str, response: dict) -> dict:
return matching_jobs[0]

def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
all_info = self.get_job_all_awslogs_info(job_id)
if not all_info:
return None
if len(all_info) > 1:
self.log.warning(
f"AWS Batch job ({job_id}) has more than one log stream, " f"only returning the first one."
)
return all_info[0]

def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
"""
Parse job description to extract AWS CloudWatch information.

:param job_id: AWS Batch Job ID
"""
job_container_desc = self.get_job_description(job_id=job_id).get("container", {})
log_configuration = job_container_desc.get("logConfiguration", {})

# In case if user select other "logDriver" rather than "awslogs"
# than CloudWatch logging should be disabled.
# If user not specify anything than expected that "awslogs" will use
# with default settings:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
log_driver = log_configuration.get("logDriver", "awslogs")
if log_driver != "awslogs":
job_desc = self.get_job_description(job_id=job_id)

job_node_properties = job_desc.get("nodeProperties", {})
job_container_desc = job_desc.get("container", {})

if job_node_properties:
# one log config per node
log_configs = [
p.get("container", {}).get("logConfiguration", {})
for p in job_node_properties.get("nodeRangeProperties", {})
]
# one stream name per attempt
stream_names = [a.get("container", {}).get("logStreamName") for a in job_desc.get("attempts", [])]
elif job_container_desc:
log_configs = [job_container_desc.get("logConfiguration", {})]
stream_name = job_container_desc.get("logStreamName")
stream_names = [stream_name] if stream_name is not None else []
else:
raise AirflowException(
f"AWS Batch job ({job_id}) is not a supported job type. "
"Supported job types: container, array, multinode."
)

# If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
self.log.warning(
"AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver
f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
)
return None
return []

awslogs_stream_name = job_container_desc.get("logStreamName")
if not awslogs_stream_name:
# In case of call this method on very early stage of running AWS Batch
# there is possibility than AWS CloudWatch Stream Name not exists yet.
# AWS CloudWatch Stream Name also not created in case of misconfiguration.
self.log.warning("AWS Batch job (%s) doesn't create AWS CloudWatch Stream.", job_id)
return None
if not stream_names:
# If this method is called very early after starting the AWS Batch job,
# there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
# This can also happen in case of misconfiguration.
self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
return []

# Try to get user-defined log configuration options
log_options = log_configuration.get("options", {})

return {
"awslogs_stream_name": awslogs_stream_name,
"awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": log_options.get("awslogs-region", self.conn_region_name),
}
log_options = [c.get("options", {}) for c in log_configs]

# cross stream names with options (i.e. attempts X nodes) to generate all log infos
result = []
for stream in stream_names:
for option in log_options:
result.append(
{
"awslogs_stream_name": stream,
# If the user did not specify anything, the default settings are:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
"awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": option.get("awslogs-region", self.conn_region_name),
}
)
return result

@staticmethod
def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
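The core behavioural change in this hook: get_job_all_awslogs_info now returns one entry per (attempt, node range) combination instead of a single dict. A minimal standalone sketch of that cross product, using illustrative stream names and log options rather than real Batch output:

# Sketch only: mirrors the nested loop at the end of get_job_all_awslogs_info.
# Stream names come from the job's attempts, log options from each node range.
stream_names = ["test/stream/attempt0", "test/stream/attempt1"]
log_options = [
    {"awslogs-group": "/test/batch/job-a", "awslogs-region": "us-east-1"},
    {"awslogs-group": "/test/batch/job-b", "awslogs-region": "us-east-1"},
]

all_logs = [
    {
        "awslogs_stream_name": stream,
        "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
        "awslogs_region": option.get("awslogs-region", "us-east-1"),
    }
    for stream in stream_names
    for option in log_options
]
assert len(all_logs) == 4  # 2 attempts x 2 node ranges
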
102 changes: 78 additions & 24 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
"""
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Sequence

from airflow.compat.functools import cached_property
@@ -54,7 +55,9 @@ class BatchOperator(BaseOperator):
:param job_name: the name for the job that will run on AWS Batch (templated)
:param job_definition: the job definition name on AWS Batch
:param job_queue: the queue name on AWS Batch
:param overrides: the `containerOverrides` parameter for boto3 (templated)
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param array_properties: the `arrayProperties` parameter for boto3
:param parameters: the `parameters` for boto3 (templated)
:param job_id: the job ID, usually unknown (None) until the
@@ -88,14 +91,19 @@ class BatchOperator(BaseOperator):
"job_name",
"job_definition",
"job_queue",
"overrides",
"container_overrides",
"array_properties",
"node_overrides",
"parameters",
"waiters",
"tags",
"wait_for_completion",
)
template_fields_renderers = {"overrides": "json", "parameters": "json"}
template_fields_renderers = {
"container_overrides": "json",
"parameters": "json",
"node_overrides": "json",
}

@property
def operator_extra_links(self):
@@ -114,8 +122,10 @@ def __init__(
job_name: str,
job_definition: str,
job_queue: str,
overrides: dict,
overrides: dict | None = None, # deprecated
container_overrides: dict | None = None,
array_properties: dict | None = None,
node_overrides: dict | None = None,
parameters: dict | None = None,
job_id: str | None = None,
waiters: Any | None = None,
@@ -133,17 +143,43 @@ def __init__(
self.job_name = job_name
self.job_definition = job_definition
self.job_queue = job_queue
self.overrides = overrides or {}
self.array_properties = array_properties or {}

self.container_overrides = container_overrides
# handle `overrides` deprecation in favor of `container_overrides`
if overrides:
if container_overrides:
# disallow setting both old and new params
raise AirflowException(
"'container_overrides' replaces the 'overrides' parameter. "
"You cannot specify both. Please remove assignation to the deprecated 'overrides'."
)
self.container_overrides = overrides
warnings.warn(
"Parameter `overrides` is deprecated, Please use `container_overrides` instead.",
DeprecationWarning,
stacklevel=2,
)

self.node_overrides = node_overrides
self.array_properties = array_properties
self.parameters = parameters or {}
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
aws_conn_id=aws_conn_id,
region_name=region_name,

# params for hook
self.max_retries = max_retries
self.status_retries = status_retries
self.aws_conn_id = aws_conn_id
self.region_name = region_name

@cached_property
def hook(self) -> BatchClientHook:
return BatchClientHook(
max_retries=self.max_retries,
status_retries=self.status_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
)

def execute(self, context: Context):
@@ -174,18 +210,27 @@ def submit_job(self, context: Context):
self.job_definition,
self.job_queue,
)
self.log.info("AWS Batch job - container overrides: %s", self.overrides)

if self.container_overrides:
self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
if self.array_properties:
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)

args = {
"jobName": self.job_name,
"jobQueue": self.job_queue,
"jobDefinition": self.job_definition,
"arrayProperties": self.array_properties,
"parameters": self.parameters,
"tags": self.tags,
"containerOverrides": self.container_overrides,
"nodeOverrides": self.node_overrides,
}

try:
response = self.hook.client.submit_job(
jobName=self.job_name,
jobQueue=self.job_queue,
jobDefinition=self.job_definition,
arrayProperties=self.array_properties,
parameters=self.parameters,
containerOverrides=self.overrides,
tags=self.tags,
)
response = self.hook.client.submit_job(**trim_none_values(args))
except Exception as e:
self.log.error(
"AWS Batch job failed submission - job definition: %s - on queue %s",
@@ -249,15 +294,24 @@ def monitor_job(self, context: Context):
else:
self.hook.wait_for_job(self.job_id)

awslogs = self.hook.get_job_awslogs_info(self.job_id)
awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
if awslogs:
self.log.info("AWS Batch job (%s) CloudWatch Events details found: %s", self.job_id, awslogs)
self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
link_builder = CloudWatchEventsLink()
for log in awslogs:
self.log.info(link_builder.format_link(**log))
if len(awslogs) > 1:
# there can be several log streams on multi-node jobs
self.log.warning(
"out of all those logs, we can only link to one in the UI. " "Using the first one."
)

CloudWatchEventsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
**awslogs,
**awslogs[0],
)

self.hook.check_job_success(self.job_id)
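For context, a hedged usage sketch of the reworked operator with the new node_overrides parameter. The DAG id, queue, job definition, and override values below are illustrative and not taken from this PR; the nodeOverrides shape follows the Batch SubmitJob API:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchOperator

with DAG("example_batch_multinode", start_date=datetime(2023, 1, 1), schedule=None):
    submit_multinode_job = BatchOperator(
        task_id="submit_multinode_job",
        job_name="example-multinode-job",  # illustrative names
        job_definition="example-job-definition",
        job_queue="example-job-queue",
        node_overrides={
            # one containerOverrides entry per target node range
            "nodePropertyOverrides": [
                {
                    "targetNodes": "0:",
                    "containerOverrides": {"command": ["python3", "main.py"]},
                },
            ],
        },
        wait_for_completion=True,
    )

Single-container jobs keep working as before, now via container_overrides (the renamed overrides parameter).
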
74 changes: 71 additions & 3 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -280,15 +280,24 @@ def test_job_no_awslogs_stream(self, caplog):
"jobs": [
{
"jobId": JOB_ID,
"container": {},
"container": {"logConfiguration": {}},
}
]
}

with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]
assert "doesn't have any AWS CloudWatch Stream" in caplog.messages[0]

def test_job_not_recognized_job(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID}]}
with pytest.raises(AirflowException) as ctx:
self.batch_client.get_job_awslogs_info(JOB_ID)
# It should not retry when this client error occurs
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = "is not a supported job type"
assert msg in str(ctx.value)

def test_job_splunk_logs(self, caplog):
self.client_mock.describe_jobs.return_value = {
@@ -307,7 +316,66 @@ def test_job_splunk_logs(self, caplog):
with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]
assert "uses non-aws log drivers. AWS CloudWatch logging disabled." in caplog.messages[0]

def test_job_awslogs_multinode_job(self):
self.client_mock.describe_jobs.return_value = {
"jobs": [
{
"jobId": JOB_ID,
"attempts": [
{"container": {"exitCode": 0, "logStreamName": "test/stream/attempt0"}},
{"container": {"exitCode": 0, "logStreamName": "test/stream/attempt1"}},
],
"nodeProperties": {
"mainNode": 0,
"nodeRangeProperties": [
{
"targetNodes": "0:",
"container": {
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/test/batch/job-a",
"awslogs-region": AWS_REGION,
},
}
},
},
{
"targetNodes": "1:",
"container": {
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/test/batch/job-b",
"awslogs-region": AWS_REGION,
},
}
},
},
],
},
}
]
}
awslogs = self.batch_client.get_job_all_awslogs_info(JOB_ID)
assert len(awslogs) == 4
assert all([log["awslogs_region"] == AWS_REGION for log in awslogs])

combinations = {
("test/stream/attempt0", "/test/batch/job-a"): False,
("test/stream/attempt0", "/test/batch/job-b"): False,
("test/stream/attempt1", "/test/batch/job-a"): False,
("test/stream/attempt1", "/test/batch/job-b"): False,
}
for log_info in awslogs:
# mark combinations that we see
combinations[(log_info["awslogs_stream_name"], log_info["awslogs_group"])] = True

assert len(combinations) == 4
# all combinations listed above should have been seen
assert all(combinations.values())


class TestBatchClientDelays:
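Not covered in this diff, but the overrides deprecation added to the operator could be exercised along the same lines; a hedged sketch with illustrative job names:

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.batch import BatchOperator


def test_deprecated_overrides_parameter():
    # The legacy `overrides` kwarg should warn and be mapped onto container_overrides.
    with pytest.warns(DeprecationWarning):
        op = BatchOperator(
            task_id="submit",
            job_name="example-job",
            job_definition="example-def",
            job_queue="example-queue",
            overrides={"command": ["echo", "hello"]},
        )
    assert op.container_overrides == {"command": ["echo", "hello"]}

    # Supplying both the deprecated and the new parameter is rejected outright.
    with pytest.raises(AirflowException):
        BatchOperator(
            task_id="submit_both",
            job_name="example-job",
            job_definition="example-def",
            job_queue="example-queue",
            overrides={"command": ["echo", "hello"]},
            container_overrides={"command": ["echo", "hello"]},
        )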