Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,24 +472,23 @@ def _attempt_to_retrieve_result(self):
This uses Docker's ``get_archive`` function. If the file is not yet
ready, *None* is returned.
"""

def copy_from_docker(container_id, src):
archived_result, stat = self.cli.get_archive(container_id, src)
if stat["size"] == 0:
# 0 byte file, it can't be anything else than None
return None
# no need to port to a file since we intend to deserialize
with BytesIO(b"".join(archived_result)) as f:
tar = tarfile.open(fileobj=f)
file = tar.extractfile(stat["name"])
lib = getattr(self, "pickling_library", pickle)
return lib.load(file)

try:
return copy_from_docker(self.container["Id"], self.retrieve_output_path)
return self._copy_from_docker(self.container["Id"], self.retrieve_output_path)
except APIError:
return None

def _copy_from_docker(self, container_id, src):
    """
    Fetch a file from a container and deserialize its contents.

    Uses Docker's ``get_archive`` function to download ``src`` as a tar
    stream, then loads the archive member named by the returned stat info
    with ``self.pickling_library`` (falling back to :mod:`pickle` when that
    attribute is not set).

    :param container_id: ID of the container to copy the file from.
    :param src: path of the file inside the container.
    :return: the deserialized object, or *None* when the file is empty.
    """
    archived_result, stat = self.cli.get_archive(container_id, src)
    if stat["size"] == 0:
        # 0 byte file, it can't be anything else than None
        return None
    lib = getattr(self, "pickling_library", pickle)
    # No need to write the archive to disk since we only intend to deserialize
    # it; both the buffer and the tar handle are closed when the block exits.
    with BytesIO(b"".join(archived_result)) as f, tarfile.open(fileobj=f) as tar:
        # NOTE(review): this deserializes data produced inside the container;
        # pickle-style loaders must only be used with trusted images.
        return lib.load(tar.extractfile(stat["name"]))

def execute(self, context: Context) -> list[str] | str | None:
# Pull the docker image if `force_pull` is set or image does not exist locally
if self.force_pull or not self.cli.images(name=self.image):
Expand Down
38 changes: 38 additions & 0 deletions airflow/providers/docker/operators/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import TYPE_CHECKING

from docker import types
from docker.errors import APIError

from airflow.exceptions import AirflowException
from airflow.providers.docker.operators.docker import DockerOperator
Expand Down Expand Up @@ -87,6 +88,10 @@ class DockerSwarmOperator(DockerOperator):
:param enable_logging: Show the application's logs in operator's logs.
Supported only if the Docker engine is using json-file or journald logging drivers.
The `tty` parameter should be set to use this with Python applications.
:param retrieve_output: Should this docker image consistently attempt to pull from an output
file before manually shutting down the image. Useful for cases where users want a pickle-serialized
output that is not posted to logs.
:param retrieve_output_path: path for output file that will be retrieved and passed to xcom
:param configs: List of docker configs to be exposed to the containers of the swarm service.
The configs are ConfigReference objects as per the docker api
[https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
Expand Down Expand Up @@ -122,6 +127,8 @@ def __init__(
self.args = args
self.enable_logging = enable_logging
self.service = None
self.tasks: list[dict] = []
self.containers: list[dict] = []
self.configs = configs
self.secrets = secrets
self.mode = mode
Expand Down Expand Up @@ -173,6 +180,18 @@ def _run_service(self) -> None:
self.log.info("Service status before exiting: %s", self._service_status())
break

if self.service and self._service_status() == "complete":
self.tasks = self.cli.tasks(filters={"service": self.service["ID"]})
for task in self.tasks:
container_id = task["Status"]["ContainerStatus"]["ContainerID"]
container = self.cli.inspect_container(container_id)
self.containers.append(container)
else:
raise AirflowException(f"Service did not complete: {self.service!r}")

if self.retrieve_output:
return self._attempt_to_retrieve_results()

self.log.info("auto_removeauto_removeauto_removeauto_removeauto_remove : %s", str(self.auto_remove))
if self.service and self._service_status() != "complete":
if self.auto_remove == "success":
Expand Down Expand Up @@ -230,6 +249,25 @@ def stream_new_logs(last_line_logged, since=0):
sleep(2)
last_line_logged, last_timestamp = stream_new_logs(last_line_logged, since=last_timestamp)

def _attempt_to_retrieve_results(self):
    """
    Attempt to pull the result from the expected file for each container.

    This uses Docker's ``get_archive`` function (via ``_copy_from_docker``)
    on every container in ``self.containers``.

    :return: the single result when exactly one container was inspected,
        otherwise a list with one result per container (empty when there
        are no containers). *None* is returned if any retrieval raises
        ``APIError``, e.g. because the file is not yet ready.
    """
    try:
        file_contents = [
            self._copy_from_docker(container["Id"], self.retrieve_output_path)
            for container in self.containers
        ]
    except APIError:
        # A failure on any container discards all results collected so far.
        return None
    if len(file_contents) == 1:
        # Unwrap so a single-container service yields the bare value.
        return file_contents[0]
    return file_contents

@staticmethod
def format_args(args: list[str] | str | None) -> list[str] | None:
"""
Expand Down
24 changes: 17 additions & 7 deletions tests/providers/docker/operators/test_docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _client_tasks_side_effect():
for _ in range(2):
yield [{"Status": {"State": "pending"}}]
while True:
yield [{"Status": {"State": "complete"}}]
yield [{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}]

def _client_service_logs_effect():
service_logs = [
Expand Down Expand Up @@ -123,7 +123,7 @@ def _client_service_logs_effect():
assert cskwargs["labels"] == {"name": "airflow__adhoc_airflow__unittest"}
assert cskwargs["name"].startswith("airflow-")
assert cskwargs["mode"] == types.ServiceMode(mode="replicated", replicas=3)
assert client_mock.tasks.call_count == 6
assert client_mock.tasks.call_count == 8
client_mock.remove_service.assert_called_once_with("some_id")

@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
Expand All @@ -134,7 +134,9 @@ def test_auto_remove(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand All @@ -157,7 +159,9 @@ def test_no_auto_remove(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -233,7 +237,9 @@ def test_container_resources(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -278,7 +284,9 @@ def test_service_args_str(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -316,7 +324,9 @@ def test_service_args_list(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down