Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
GetPreviousDagRun,
GetPrevSuccessfulDagRun,
GetVariable,
MaskSecret,
OKResponse,
PreviousDagRunResult,
PrevSuccessfulDagRunResult,
Expand Down Expand Up @@ -106,6 +107,7 @@ class DagFileParsingResult(BaseModel):
DeleteVariable,
GetPrevSuccessfulDagRun,
GetPreviousDagRun,
MaskSecret,
],
Field(discriminator="type"),
]
Expand Down Expand Up @@ -431,6 +433,10 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
resp = dagrun_result
dump_opts = {"exclude_unset": True}
elif isinstance(msg, MaskSecret):
from airflow.sdk.execution_time.secrets_masker import mask_secret

mask_secret(msg.value, msg.name)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
Expand Down
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
GetTICount,
GetVariable,
GetXCom,
MaskSecret,
OKResponse,
PutVariable,
SetXCom,
Expand Down Expand Up @@ -250,6 +251,7 @@ class TriggerStateSync(BaseModel):
GetTaskStates,
GetDagRunState,
GetDRCount,
MaskSecret,
],
Field(discriminator="type"),
]
Expand Down Expand Up @@ -472,6 +474,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp = TaskStatesResult.from_api_response(run_id_task_state_map)
else:
resp = run_id_task_state_map
elif isinstance(msg, MaskSecret):
from airflow.sdk.execution_time.secrets_masker import mask_secret

mask_secret(msg.value, msg.name)
else:
raise ValueError(f"Unknown message type {type(msg)}")

Expand Down
12 changes: 9 additions & 3 deletions airflow-core/tests/unit/hooks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
# under the License.
from __future__ import annotations

from unittest.mock import call

import pytest

from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection, MaskSecret

from tests_common.test_utils.config import conf_vars

Expand Down Expand Up @@ -56,8 +58,12 @@ def test_get_connection(self, mock_supervisor_comms):

hook = BaseHook(logger_name="")
hook.get_connection(conn_id="test_conn")
mock_supervisor_comms.send.assert_called_once_with(
msg=GetConnection(conn_id="test_conn"),
mock_supervisor_comms.send.assert_has_calls(
[
call(GetConnection(conn_id="test_conn", type="GetConnection")),
call(MaskSecret(value="password", name=None, type="MaskSecret")),
call(MaskSecret(value='{"extra_key": "extra_value"}', name=None, type="MaskSecret")),
]
)

def test_get_connection_not_found(self, mock_supervisor_comms):
Expand Down
2 changes: 2 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,13 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]:

from airflow.sdk import Variable
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.log import mask_secret

conn = await sync_to_async(BaseHook.get_connection)("test_connection")
self.log.info("Loaded conn %s", conn.conn_id)

get_variable_value = await sync_to_async(Variable.get)("test_get_variable")
await sync_to_async(mask_secret)(get_variable_value)
self.log.info("Loaded variable %s", get_variable_value)

get_xcom_value = await sync_to_async(XCom.get_one)(
Expand Down
7 changes: 4 additions & 3 deletions task-sdk/docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ I/O Helpers
Execution Time Components
-------------------------
.. rubric:: Context

.. autoapiclass:: airflow.sdk.Context
.. autoapimodule:: airflow.sdk.execution_time.context
:members:
:undoc-members:

.. rubric:: Logging

.. autofunction:: airflow.sdk.log.mask_secret

Everything else
---------------
Expand Down
6 changes: 5 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,15 @@ def get(cls, conn_id: str) -> Any:
@property
def extra_dejson(self) -> dict:
"""Deserialize `extra` property to JSON."""
from airflow.sdk.execution_time.secrets_masker import mask_secret

extra = {}
if self.extra:
try:
extra = json.loads(self.extra)
except JSONDecodeError:
log.exception("Failed to deserialize extra property `extra`, returning empty dictionary")
# TODO: Mask sensitive keys from this list or revisit if it will be done in server
else:
mask_secret(extra)

return extra
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import attrs

from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.log import mask_secret

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,6 +54,7 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
return _get_variable(key, deserialize_json=deserialize_json)
except AirflowRuntimeError as e:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
mask_secret(default, name=key)
return default
raise

Expand Down
14 changes: 13 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from __future__ import annotations

import itertools
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from datetime import datetime
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -858,6 +858,17 @@ class GetDRCount(BaseModel):
type: Literal["GetDRCount"] = "GetDRCount"


class MaskSecret(BaseModel):
"""Add a new value to be redacted in task logs."""

# This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value
# to the child process, but the redaction happens in the parent.

value: str | dict | Iterable
name: str | None = None
type: Literal["MaskSecret"] = "MaskSecret"


ToSupervisor = Annotated[
Union[
DeferTask,
Expand Down Expand Up @@ -891,6 +902,7 @@ class GetDRCount(BaseModel):
TriggerDagRun,
DeleteVariable,
ResendLoggingFD,
MaskSecret,
],
Field(discriminator="type"),
]
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def _get_connection(conn_id: str) -> Connection:
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
# TODO: this should probably be in get conn
if conn.password:
mask_secret(conn.password)
if conn.extra:
mask_secret(conn.extra)
return conn
except Exception:
log.exception(
Expand Down
23 changes: 16 additions & 7 deletions task-sdk/src/airflow/sdk/execution_time/secrets_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,28 @@ def should_hide_value_for_key(name):

def mask_secret(secret: str | dict | Iterable, name: str | None = None) -> None:
"""
Mask a secret from appearing in the task logs.
Mask a secret from appearing in the logs.

If ``name`` is provided, then it will only be masked if the name matches
one of the configured "sensitive" names.
If ``name`` is provided, then it will only be masked if the name matches one of the configured "sensitive"
names.

If ``secret`` is a dict or a iterable (excluding str) then it will be
recursively walked and keys with sensitive names will be hidden.
If ``secret`` is a dict or a iterable (excluding str) then it will be recursively walked and keys with
sensitive names will be hidden.

If the secret value is too short (by default 5 characters or fewer, configurable via the
:ref:`[logging] min_length_masked_secret <config:logging__min_length_masked_secret>` setting) it will not
be masked
"""
# Filtering all log messages is not a free process, so we only do it when
# running tasks
if not secret:
return

from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import MaskSecret

if comms := getattr(task_runner, "SUPERVISOR_COMMS", None):
# Tell the parent, the process which handles all logs writing and output, about the values to mask
comms.send(MaskSecret(value=secret, name=name))

_secrets_masker().add_mask(secret, name)


Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
GetXComSequenceItem,
GetXComSequenceSlice,
InactiveAssetsResult,
MaskSecret,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
Expand Down Expand Up @@ -1064,7 +1065,10 @@ def final_state(self):
return TaskInstanceState.FAILED

def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int):
log.debug("Received message from task runner", msg=msg)
if isinstance(msg, MaskSecret):
log.debug("Received message from task runner (body omitted)", msg=type(msg))
else:
log.debug("Received message from task runner", msg=msg)
resp: BaseModel | None = None
dump_opts = {}
if isinstance(msg, TaskState):
Expand Down Expand Up @@ -1253,6 +1257,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
self._send_new_log_fd(req_id)
# Since we've sent the message, return. Nothing else in this ifelse/switch should return directly
return
elif isinstance(msg, MaskSecret):
mask_secret(msg.value, msg.name)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
Expand Down
13 changes: 13 additions & 0 deletions task-sdk/src/airflow/sdk/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@
import structlog

if TYPE_CHECKING:
from collections.abc import Callable

from structlog.typing import EventDict, ExcInfo, FilteringBoundLogger, Processor

from airflow.logging_config import RemoteLogIO
from airflow.sdk.execution_time.secrets_masker import mask_secret as mask_secret
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI


__all__ = [
"configure_logging",
"reset_logging",
"mask_secret",
]


Expand Down Expand Up @@ -568,3 +572,12 @@ def upload_to_remote(logger: FilteringBoundLogger, ti: RuntimeTI):

log_relative_path = relative_path.as_posix()
handler.upload(log_relative_path, ti)


def __getattr__(name: str):
if name == "mask_secret":
from airflow.sdk.execution_time.secrets_masker import mask_secret

globals()["mask_secret"] = mask_secret
return mask_secret
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Loading