Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions airflow/providers/amazon/aws/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
92 changes: 92 additions & 0 deletions airflow/providers/amazon/aws/protocols/docker_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from functools import cached_property

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.ecr import EcrHook

try:
from airflow.providers.docker.protocols.docker_registry import (
DockerRegistryCredentials,
RefreshableDockerRegistryAuthProtocol,
)
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Failed to import `airflow.providers.docker.protocols.docker_registry`, "
"required version of Docker Provider not installed, run: "
"pip install 'apache-airflow-providers-amazon[docker]'"
)


class EcrDockerRegistryAuthProtocol(RefreshableDockerRegistryAuthProtocol):
    """Docker registry authentication backed by temporary Amazon ECR credentials.

    Credentials are fetched through :class:`EcrHook`, and an expiry timestamp is
    tracked so that ``need_refresh`` can report when they should be fetched again.

    :param aws_conn_id: Airflow connection ID handed to ``EcrHook``.
    :param region_name: AWS region name handed to ``EcrHook``.
    :param registry_ids: Registry ID(s) to request credentials for. When falsy, the
        ``registry_ids`` entry of the hook's ``service_config`` is used instead.
    """

    def __init__(
        self,
        *,
        aws_conn_id: str | None = "aws_default",
        region_name: str | None = None,
        registry_ids: list[str] | str | None = None,
    ):
        # No expiry is known until credentials have been fetched at least once.
        self._expires_at: datetime | None = None
        self.registry_ids = registry_ids
        self.region_name = region_name
        self.aws_conn_id = aws_conn_id

    @property
    def need_refresh(self) -> bool:
        """Whether the tracked expiry is less than five minutes away.

        Returns ``False`` while no expiry has been recorded.
        """
        expiry = self.expires_at
        if not expiry:
            return False
        remaining = expiry - datetime.now(tz=timezone.utc)
        return remaining < timedelta(minutes=5)

    @property
    def expires_at(self) -> datetime | None:
        """Expiry timestamp currently tracked, or ``None`` if none was recorded."""
        return self._expires_at

    @expires_at.setter
    def expires_at(self, value: datetime):
        stored = self._expires_at
        # Take the new value outright when nothing is stored yet or the stored expiry
        # is already inside the refresh window; otherwise only ever move it earlier.
        if not stored or self.need_refresh or stored > value:
            self._expires_at = value

    @cached_property
    def hook(self) -> EcrHook:
        """``EcrHook`` built from ``aws_conn_id`` and ``region_name`` (cached)."""
        return EcrHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

    def get_credentials(self, *, conn: Connection | None) -> list[DockerRegistryCredentials]:
        """Fetch temporary ECR credentials and convert them to registry credentials.

        The tracked expiry is updated from every credential returned by the hook.

        :param conn: Airflow connection supplied by the caller; not used here.
        :return: One ``DockerRegistryCredentials`` per credential returned by the hook,
            each with ``reauth=True``.
        """
        ecr_hook = self.hook
        ids = self.registry_ids or ecr_hook.service_config.get("registry_ids", None)

        registry_credentials = []
        for token in ecr_hook.get_temporary_credentials(registry_ids=ids):
            self.expires_at = token.expires_at
            entry = DockerRegistryCredentials(
                username=token.username,
                password=token.password,
                registry=token.registry,
                reauth=True,
            )
            registry_credentials.append(entry)
        return registry_credentials

    def refresh_credentials(self, *, conn: Connection | None) -> list[DockerRegistryCredentials]:
        """Obtain a fresh set of credentials; identical to :meth:`get_credentials`."""
        return self.get_credentials(conn=conn)
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,9 @@ additional-extras:
- name: cncf.kubernetes
dependencies:
- apache-airflow-providers-cncf-kubernetes>=7.2.0
- name: docker
dependencies:
- apache-airflow-providers-docker>3.7.3

config:
aws:
Expand Down
9 changes: 1 addition & 8 deletions airflow/providers/docker/decorators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.providers.docker.operators.docker import DockerOperator

try:
from airflow.utils.decorators import remove_task_decorator

# This can be removed after we move to Airflow 2.4+
except ImportError:
from airflow.utils.python_virtualenv import remove_task_decorator

from airflow.utils.decorators import remove_task_decorator
from airflow.utils.python_virtualenv import write_python_script

if TYPE_CHECKING:
Expand Down
115 changes: 73 additions & 42 deletions airflow/providers/docker/hooks/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@
from docker.constants import DEFAULT_TIMEOUT_SECONDS
from docker.errors import APIError

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.docker.protocols.docker_registry import (
AirflowConnectionDockerRegistryAuth,
DockerRegistryAuthProtocol,
DockerRegistryCredentials,
NoDockerRegistryAuth,
RefreshableDockerRegistryAuthProtocol,
)

if TYPE_CHECKING:
from airflow.models import Connection
Expand All @@ -50,6 +57,9 @@ class DockerHook(BaseHook):
:param tls: Is connection required TLS, for enable pass ``True`` for use with default options,
or pass a `docker.tls.TLSConfig` object to use custom configurations.
:param timeout: Default timeout for API calls, in seconds.
:param registry_auth: Object used to authenticate against the Docker Registry; it must implement
:class:`airflow.providers.docker.protocols.docker_registry.DockerRegistryAuthProtocol`.
If ``None``, an implementation is selected automatically based on the value of ``docker_conn_id``.
"""

conn_name_attr = "docker_conn_id"
Expand All @@ -64,6 +74,7 @@ def __init__(
version: str | None = None,
tls: TLSConfig | bool | None = None,
timeout: int = DEFAULT_TIMEOUT_SECONDS,
registry_auth: DockerRegistryAuthProtocol | None = None,
) -> None:
super().__init__()
if not base_url:
Expand All @@ -82,6 +93,21 @@ def __init__(
self.__timeout = timeout
self._client_created = False

if registry_auth is None:
# Set registry_auth based on `docker_conn_id` value
if self.docker_conn_id:
registry_auth = AirflowConnectionDockerRegistryAuth()
else:
registry_auth = NoDockerRegistryAuth()
elif not isinstance(registry_auth, DockerRegistryAuthProtocol):
raise TypeError(
"'registry_auth' expected DockerRegistryAuthProtocol, "
f"but got {type(registry_auth).__name__}."
)

self._registry_auth: DockerRegistryAuthProtocol = registry_auth
self._api_client: APIClient | None = None

@staticmethod
def construct_tls_config(
ca_cert: str | None = None,
Expand Down Expand Up @@ -115,17 +141,52 @@ def construct_tls_config(
return False

@cached_property
def _airflow_connection(self) -> Connection | None:
    """Airflow connection for ``docker_conn_id``, or ``None`` when no ID is set (cached)."""
    conn_id = self.docker_conn_id
    return self.get_connection(conn_id) if conn_id else None

@property
def api_client(self) -> APIClient:
"""Create connection to docker host and return ``docker.APIClient`` (cached)."""
client = APIClient(
base_url=self.__base_url, version=self.__version, tls=self.__tls, timeout=self.__timeout
)
if self.docker_conn_id:
# Obtain connection and try to login to Container Registry only if ``docker_conn_id`` set.
self.__login(client, self.get_connection(self.docker_conn_id))
"""Create connection to docker host and return ``docker.APIClient``."""
if not self._api_client:
# Create client only once
self._api_client = APIClient(
base_url=self.__base_url, version=self.__version, tls=self.__tls, timeout=self.__timeout
)
self._login(
self._api_client,
self._registry_auth.get_credentials(conn=self._airflow_connection),
)
self._client_created = True
elif (
isinstance(self._registry_auth, RefreshableDockerRegistryAuthProtocol)
and self._registry_auth.need_refresh
):
self._login(
self._api_client,
self._registry_auth.refresh_credentials(conn=self._airflow_connection),
reauth=True,
)

self._client_created = True
return client
return self._api_client

def _login(self, client: APIClient, credentials: list[DockerRegistryCredentials], reauth=False) -> None:
    """Log ``client`` in to every registry listed in ``credentials``, in order.

    :param client: Docker API client whose ``login`` method is called once per entry.
    :param credentials: Registry credentials to log in with.
    :param reauth: When truthy, ``reauth`` is forced for every entry regardless of the
        entry's own ``reauth`` flag.
    :raises APIError: Re-raised after an error is logged when a login fails; entries
        after the failing one are not attempted.
    """
    for entry in credentials:
        target = entry.registry
        self.log.info("Login into Docker Registry: %s", target)
        try:
            client.login(
                username=entry.username,
                password=entry.password,
                registry=target,
                email=entry.email,
                # A refresh passes reauth=True so the cached login is always replaced.
                reauth=reauth or entry.reauth,
            )
        except APIError:
            self.log.error("Login failed to registry: %s", target)
            raise
        else:
            self.log.debug("Login successful to registry: %s", target)

@property
def client_created(self) -> bool:
Expand All @@ -136,38 +197,8 @@ def get_conn(self) -> APIClient:
"""Create connection to docker host and return ``docker.APIClient`` (cached)."""
return self.api_client

def __login(self, client, conn: Connection) -> None:
    """Log ``client`` in to the registry described by the Airflow connection ``conn``.

    The registry address is ``host`` or ``host:port``; ``email`` and ``reauth`` are
    read from the connection extras (``reauth`` defaults to ``True``).

    :param client: Docker API client whose ``login`` method is called.
    :param conn: Airflow connection providing host, port, login, password and extras.
    :raises AirflowNotFoundException: If the connection has no host or no login.
    :raises ValueError: If ``reauth`` is a string that is not a recognised boolean.
    :raises APIError: Re-raised after an error is logged when the login fails.
    """
    if not conn.host:
        raise AirflowNotFoundException("No Docker Registry URL provided.")
    if not conn.login:
        raise AirflowNotFoundException("No Docker Registry username provided.")

    registry = conn.host if not conn.port else f"{conn.host}:{conn.port}"

    # Optional settings from the connection extras.
    email = conn.extra_dejson.get("email") or None
    reauth = conn.extra_dejson.get("reauth", True)
    if isinstance(reauth, str):
        # Extras may carry the flag as text; map the usual spellings to a bool.
        lowered = reauth.lower()
        if lowered in ("y", "yes", "t", "true", "on", "1"):
            reauth = True
        elif lowered in ("n", "no", "f", "false", "off", "0"):
            reauth = False
        else:
            raise ValueError(f"Unable parse `reauth` value {lowered!r} to bool.")

    self.log.info("Login into Docker Registry: %s", registry)
    try:
        client.login(
            username=conn.login,
            password=conn.password,
            registry=registry,
            email=email,
            reauth=reauth,
        )
    except APIError:
        self.log.error("Login failed")
        raise
    else:
        self.log.debug("Login successful")

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection form widgets."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand Down
7 changes: 7 additions & 0 deletions airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
DockerContainerFailedSkipException,
)
from airflow.providers.docker.hooks.docker import DockerHook
from airflow.providers.docker.protocols.docker_registry import DockerRegistryAuthProtocol

if TYPE_CHECKING:
from docker import APIClient
Expand Down Expand Up @@ -166,6 +167,9 @@ class DockerOperator(BaseOperator):
dictionary of value where the key indicates the port to open inside the container
and value indicates the host port that binds to the container port.
Incompatible with ``host`` in ``network_mode``.
:param registry_auth: Object used to authenticate against the Docker Registry; it must implement
:class:`airflow.providers.docker.protocols.docker_registry.DockerRegistryAuthProtocol`.
If ``None``, an implementation is selected automatically based on the value of ``docker_conn_id``.
"""

template_fields: Sequence[str] = ("image", "command", "environment", "env_file", "container_name")
Expand Down Expand Up @@ -225,6 +229,7 @@ def __init__(
skip_exit_code: int | None = None,
skip_on_exit_code: int | Container[int] | None = None,
port_bindings: dict | None = None,
registry_auth: DockerRegistryAuthProtocol | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -304,6 +309,7 @@ def __init__(
self.port_bindings = port_bindings or {}
if self.port_bindings and self.network_mode == "host":
raise ValueError("Port bindings is not supported in the host network mode")
self.registry_auth = registry_auth

@cached_property
def hook(self) -> DockerHook:
Expand All @@ -322,6 +328,7 @@ def hook(self) -> DockerHook:
version=self.api_version,
tls=tls_config,
timeout=self.timeout,
registry_auth=self.registry_auth,
)

def get_hook(self) -> DockerHook:
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/docker/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading