29 changes: 15 additions & 14 deletions providers/src/airflow/providers/docker/operators/docker.py
@@ -472,23 +472,24 @@ def _attempt_to_retrieve_result(self):
        This uses Docker's ``get_archive`` function. If the file is not yet
        ready, *None* is returned.
        """

        def copy_from_docker(container_id, src):
            archived_result, stat = self.cli.get_archive(container_id, src)
            if stat["size"] == 0:
                # 0-byte file; it can't be anything other than None
                return None
            # no need to write out to a file, since we deserialize directly from memory
            with BytesIO(b"".join(archived_result)) as f:
                tar = tarfile.open(fileobj=f)
                file = tar.extractfile(stat["name"])
                lib = getattr(self, "pickling_library", pickle)
                return lib.load(file)

        try:
            return self._copy_from_docker(self.container["Id"], self.retrieve_output_path)
            return copy_from_docker(self.container["Id"], self.retrieve_output_path)
        except APIError:
            return None

    def _copy_from_docker(self, container_id, src):
        archived_result, stat = self.cli.get_archive(container_id, src)
        if stat["size"] == 0:
            # 0-byte file; it can't be anything other than None
            return None
        # no need to write out to a file, since we deserialize directly from memory
        with BytesIO(b"".join(archived_result)) as f:
            tar = tarfile.open(fileobj=f)
            file = tar.extractfile(stat["name"])
            lib = getattr(self, "pickling_library", pickle)
            return lib.load(file)

    def execute(self, context: Context) -> list[str] | str | None:
        # Pull the docker image if `force_pull` is set or image does not exist locally
        if self.force_pull or not self.cli.images(name=self.image):
@@ -544,4 +545,4 @@ def unpack_environment_variables(env_str: str) -> dict:
        separated by a ``\n`` (newline)
    :return: dictionary containing parsed environment variables
    """
    return dotenv_values(stream=StringIO(env_str))
    return dotenv_values(stream=StringIO(env_str))
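A note on the retrieval path this hunk refactors: with `retrieve_output` enabled, the operator archives `retrieve_output_path` out of the stopped container via `get_archive` and unpickles it for XCom. A minimal sketch of a task exercising that path follows; the image, path, and payload are illustrative assumptions, not taken from this PR.

from airflow.providers.docker.operators.docker import DockerOperator

# Hypothetical task: the container pickles a value to the exact path that the
# operator later archives out of the stopped container and pushes to XCom.
produce_result = DockerOperator(
    task_id="produce_result",
    image="python:3.11-slim",  # assumption: any image with a Python interpreter
    command=[
        "python",
        "-c",
        "import pickle; pickle.dump({'rows': 128}, open('/tmp/script.out', 'wb'))",
    ],
    retrieve_output=True,
    retrieve_output_path="/tmp/script.out",  # must match the path written above
    do_xcom_push=True,
)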
88 changes: 23 additions & 65 deletions providers/src/airflow/providers/docker/operators/docker_swarm.py
@@ -22,10 +22,9 @@
import shlex
from datetime import datetime
from time import sleep
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

from docker import types
from docker.errors import APIError

from airflow.exceptions import AirflowException
from airflow.providers.docker.operators.docker import DockerOperator
@@ -88,10 +87,6 @@ class DockerSwarmOperator(DockerOperator):
    :param enable_logging: Show the application's logs in operator's logs.
        Supported only if the Docker engine is using json-file or journald logging drivers.
        The `tty` parameter should be set to use this with Python applications.
    :param retrieve_output: Whether the operator should attempt to pull the output
        file from the container before shutting it down. Useful for cases where users want a
        pickle-serialized output that is not posted to logs
    :param retrieve_output_path: path of the output file to be retrieved and pushed to XCom
    :param configs: List of docker configs to be exposed to the containers of the swarm service.
        The configs are ConfigReference objects as per the docker api
        [https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
@@ -107,16 +102,6 @@ class DockerSwarmOperator(DockerOperator):
        The resources are Resources as per the docker api
        [https://docker-py.readthedocs.io/en/stable/api.html#docker.types.Resources]_
        This parameter takes precedence over the mem_limit parameter.
    :param logging_driver: The logging driver to use for container logs. Docker by default uses 'json-file'.
        For more information on Docker logging drivers: https://docs.docker.com/engine/logging/configure/
        NOTE: Only the 'json-file' and 'gelf' drivers are currently supported. If left empty, 'json-file' is used.
    :param logging_driver_opts: Dictionary of logging options to use with the chosen logging driver.
        Depending on the logging driver, some options are required.
        Failure to include them will cause the operator to fail.
        All option values must be strings and wrapped in double quotes.
        For information on 'json-file' options: https://docs.docker.com/engine/logging/drivers/json-file/
        For information on 'gelf' options: https://docs.docker.com/engine/logging/drivers/gelf/
        NOTE: The 'gelf' driver requires the 'gelf-address' option to be set.
    """

    def __init__(
@@ -131,34 +116,20 @@ def __init__(
        networks: list[str | types.NetworkAttachmentConfig] | None = None,
        placement: types.Placement | list[types.Placement] | None = None,
        container_resources: types.Resources | None = None,
        logging_driver: Literal["json-file", "gelf"] | None = None,
        logging_driver_opts: dict | None = None,
        hosts: dict[str, str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(image=image, **kwargs)
        self.args = args
        self.enable_logging = enable_logging
        self.service = None
        self.tasks: list[dict] = []
        self.containers: list[dict] = []
        self.configs = configs
        self.secrets = secrets
        self.mode = mode
        self.networks = networks
        self.placement = placement
        self.container_resources = container_resources or types.Resources(mem_limit=self.mem_limit)
        self.logging_driver = logging_driver
        self.logging_driver_opts = logging_driver_opts

        if self.logging_driver:
            supported_logging_drivers = ("json-file", "gelf")
            if self.logging_driver not in supported_logging_drivers:
                raise AirflowException(
                    f"Invalid logging driver provided: {self.logging_driver}. Must be one of: [{', '.join(supported_logging_drivers)}]"
                )
            self.log_driver_config = types.DriverConfig(self.logging_driver, self.logging_driver_opts)
        else:
            self.log_driver_config = None
        self.hosts = hosts or {}

    def execute(self, context: Context) -> None:
        self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir
@@ -176,14 +147,14 @@ def _run_service(self) -> None:
                    env=self.environment,
                    user=self.user,
                    tty=self.tty,
                    hosts=self.hosts,
                    configs=self.configs,
                    secrets=self.secrets,
                ),
                restart_policy=types.RestartPolicy(condition="none"),
                resources=self.container_resources,
                networks=self.networks,
                placement=self.placement,
                log_driver=self.log_driver_config,
            ),
            name=f"airflow-{get_random_string()}",
            labels={"name": f"airflow__{self.dag_id}__{self.task_id}"},
@@ -204,19 +175,16 @@ def _run_service(self) -> None:
            if self._has_service_terminated():
                self.log.info("Service status before exiting: %s", self._service_status())
                break

        if self.service and self._service_status() == "complete":
            self.tasks = self.cli.tasks(filters={"service": self.service["ID"]})
            for task in self.tasks:
                container_id = task["Status"]["ContainerStatus"]["ContainerID"]
                container = self.cli.inspect_container(container_id)
                self.containers.append(container)
        else:
            raise AirflowException(f"Service did not complete: {self.service!r}")

        if self.retrieve_output:
            return self._attempt_to_retrieve_results()

        logs = None
        if self.do_xcom_push:
            all_logs = self.get_logs()
            if self.xcom_all:
                # push all log lines, newline-joined
                logs = "\n".join(all_logs)
            elif all_logs:
                # push only the last log line
                logs = all_logs[-1]
        self.log.info("auto_remove: %s", str(self.auto_remove))
        if self.service and self._service_status() != "complete":
            if self.auto_remove == "success":
@@ -226,6 +194,7 @@ def _run_service(self) -> None:
        if not self.service:
            raise RuntimeError("The 'service' should be initialized before!")
        self.cli.remove_service(self.service["ID"])
        return logs

    def _service_status(self) -> str | None:
        if not self.service:
@@ -236,6 +205,14 @@ def _has_service_terminated(self) -> bool:
        status = self._service_status()
        return status in ["complete", "failed", "shutdown", "rejected", "orphaned", "remove"]

    def get_logs(self) -> list[str]:
        # service_logs returns a stream of raw bytes; decode each line for XCom
        logs = self.cli.service_logs(
            self.service["ID"],
            stdout=True,
            stderr=True,
        )
        return [line.decode("utf-8") for line in logs]

    def _stream_logs_to_output(self) -> None:
        if not self.service:
            raise RuntimeError("The 'service' should be initialized before!")
@@ -274,25 +251,6 @@ def stream_new_logs(last_line_logged, since=0):
            sleep(2)
            last_line_logged, last_timestamp = stream_new_logs(last_line_logged, since=last_timestamp)

    def _attempt_to_retrieve_results(self):
        """
        Attempt to pull the result from the expected file for each container.

        This uses Docker's ``get_archive`` function. If the file is not yet
        ready, *None* is returned.
        """
        try:
            file_contents = []
            for container in self.containers:
                file_content = self._copy_from_docker(container["Id"], self.retrieve_output_path)
                file_contents.append(file_content)
            if len(file_contents) == 1:
                return file_contents[0]
            else:
                return file_contents
        except APIError:
            return None

    @staticmethod
    def format_args(args: list[str] | str | None) -> list[str] | None:
        """
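To make the swarm-side changes above concrete: the `hosts` mapping is handed to `types.ContainerSpec(hosts=...)`, which docker-py writes as extra `/etc/hosts` entries in the service's containers, and `get_logs` feeds the XCom push in `_run_service`. A rough usage sketch under those assumptions follows; the image, hostname, and IP are invented for illustration.

from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator

# Hypothetical task: resolve an internal hostname inside the service's
# containers and push the last log line to XCom.
check_internal_host = DockerSwarmOperator(
    task_id="check_internal_host",
    image="alpine:3.20",  # assumption: any image that ships `ping`
    command=["ping", "-c", "1", "internal-api"],
    hosts={"internal-api": "10.0.0.15"},  # extra /etc/hosts entry: hostname -> IP
    enable_logging=True,
    do_xcom_push=True,  # with xcom_all=False (the default), only the last log line is pushed
)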
52 changes: 1 addition & 51 deletions providers/tests/docker/operators/test_docker_swarm.py
@@ -76,8 +76,6 @@ def _client_service_logs_effect():
            mode=types.ServiceMode(mode="replicated", replicas=3),
            networks=["dummy_network"],
            placement=types.Placement(constraints=["node.labels.region==east"]),
            logging_driver=None,
            logging_driver_opts=None,
        )
        caplog.clear()
        operator.execute(None)
@@ -88,7 +86,6 @@
            resources=mock_obj,
            networks=["dummy_network"],
            placement=types.Placement(constraints=["node.labels.region==east"]),
            log_driver=None,
        )
        types_mock.ContainerSpec.assert_called_once_with(
            image="ubuntu:latest",
@@ -262,8 +259,6 @@ def test_container_resources(self, types_mock, docker_api_client_patcher):
                cpu_reservation=100000000,
                mem_reservation=67108864,
            ),
            logging_driver=None,
            logging_driver_opts=None,
        )
        operator.execute(None)

@@ -278,7 +273,6 @@
            ),
            networks=None,
            placement=None,
            log_driver=None,
        )
        types_mock.Resources.assert_not_called()

@@ -360,48 +354,4 @@ def test_service_args_list(self, types_mock, docker_api_client_patcher):
            env={"AIRFLOW_TMP_DIR": "/tmp/airflow"},
            configs=None,
            secrets=None,
        )

    @mock.patch("airflow.providers.docker.operators.docker_swarm.types")
    def test_logging_driver(self, types_mock, docker_api_client_patcher):
        mock_obj = mock.Mock()

        client_mock = mock.Mock(spec=APIClient)
        client_mock.create_service.return_value = {"ID": "some_id"}
        client_mock.images.return_value = []
        client_mock.pull.return_value = [b'{"status":"pull log"}']
        client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
        types_mock.TaskTemplate.return_value = mock_obj
        types_mock.ContainerSpec.return_value = mock_obj
        types_mock.RestartPolicy.return_value = mock_obj
        types_mock.Resources.return_value = mock_obj

        docker_api_client_patcher.return_value = client_mock

        operator = DockerSwarmOperator(
            image="", logging_driver="json-file", task_id="unittest", enable_logging=False
        )

        assert operator.logging_driver == "json-file"

    @mock.patch("airflow.providers.docker.operators.docker_swarm.types")
    def test_invalid_logging_driver(self, types_mock, docker_api_client_patcher):
        mock_obj = mock.Mock()

        client_mock = mock.Mock(spec=APIClient)
        client_mock.create_service.return_value = {"ID": "some_id"}
        client_mock.images.return_value = []
        client_mock.pull.return_value = [b'{"status":"pull log"}']
        client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
        types_mock.TaskTemplate.return_value = mock_obj
        types_mock.ContainerSpec.return_value = mock_obj
        types_mock.RestartPolicy.return_value = mock_obj
        types_mock.Resources.return_value = mock_obj

        docker_api_client_patcher.return_value = client_mock

        msg = "Invalid logging driver provided: json. Must be one of: [json-file, gelf]"
        with pytest.raises(AirflowException) as e:
            # Exception is raised in __init__()
            DockerSwarmOperator(image="", logging_driver="json", task_id="unittest", enable_logging=False)
        assert str(e.value) == msg
        )
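For reference while reviewing the tests removed above: the `logging_driver` validation they exercised accepts only 'json-file' and 'gelf', and the 'gelf' driver requires a 'gelf-address' option (see the operator docstring in this diff). A hypothetical configuration that would have passed that check; the endpoint address is invented for illustration.

from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator

# Hypothetical: route the service's container logs to a GELF endpoint.
log_to_gelf = DockerSwarmOperator(
    task_id="log_to_gelf",
    image="alpine:3.20",
    command=["echo", "hello"],
    logging_driver="gelf",
    logging_driver_opts={"gelf-address": "udp://graylog.example.internal:12201"},
)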