Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]:
# XCom.get_many, # Not supported because it returns query
XCom.clear,
XCom.set,
Variable.set,
Variable.update,
Variable.delete,
Variable._set,
Variable._update,
Variable._delete,
DAG.fetch_callback,
DAG.fetch_dagrun,
DagRun.fetch_task_instances,
Expand Down Expand Up @@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
response = json.dumps(output_json) if output_json is not None else None
log.info("Sending response: %s", response)
return Response(response=response, headers={"Content-Type": "application/json"})
except AirflowException as e: # In case of AirflowException transport the exception class back to caller
# In case of AirflowException or other selective known types, transport the exception class back to caller
except (KeyError, AttributeError, AirflowException) as e:
exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
response = json.dumps(exception_json)
log.info("Sending exception response: %s", response)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_internal/internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def wrapper(*args, **kwargs):
if result is None or result == b"":
return None
result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
if isinstance(result, AirflowException):
if isinstance(result, (KeyError, AttributeError, AirflowException)):
raise result
return result

Expand Down
66 changes: 63 additions & 3 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def get(

@staticmethod
@provide_session
@internal_api_call
def set(
key: str,
value: Any,
Expand All @@ -167,6 +166,35 @@ def set(

This operation overwrites an existing variable.

:param key: Variable Key
:param value: Value to set for the Variable
:param description: Description of the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
Variable._set(
key=key, value=value, description=description, serialize_json=serialize_json, session=session
)
# invalidate key in cache for faster propagation
# we cannot save the value set because it's possible that it's shadowed by a custom backend
# (see call to check_for_write_conflict above)
SecretCache.invalidate_variable(key)

@staticmethod
@provide_session
@internal_api_call
def _set(
key: str,
value: Any,
description: str | None = None,
serialize_json: bool = False,
session: Session = None,
) -> None:
"""
Set a value for an Airflow Variable with a given Key.

This operation overwrites an existing variable.

:param key: Variable Key
:param value: Value to set for the Variable
:param description: Description of the Variable
Expand All @@ -190,7 +218,6 @@ def set(

@staticmethod
@provide_session
@internal_api_call
def update(
key: str,
value: Any,
Expand All @@ -200,6 +227,27 @@ def update(
"""
Update a given Airflow Variable with the Provided value.

:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
Variable._update(key=key, value=value, serialize_json=serialize_json, session=session)
# We need to invalidate the cache for internal API cases on the client side
SecretCache.invalidate_variable(key)

@staticmethod
@provide_session
@internal_api_call
def _update(
key: str,
value: Any,
serialize_json: bool = False,
session: Session = None,
) -> None:
"""
Update a given Airflow Variable with the Provided value.

:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
Expand All @@ -219,11 +267,23 @@ def update(

@staticmethod
@provide_session
@internal_api_call
def delete(key: str, session: Session = None) -> int:
"""
Delete an Airflow Variable for a given key.

:param key: Variable Keys
"""
rows = Variable._delete(key=key, session=session)
SecretCache.invalidate_variable(key)
return rows

@staticmethod
@provide_session
@internal_api_call
def _delete(key: str, session: Session = None) -> int:
"""
Delete an Airflow Variable for a given key.

:param key: Variable Keys
"""
rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum):
RELATIVEDELTA = "relativedelta"
BASE_TRIGGER = "base_trigger"
AIRFLOW_EXC_SER = "airflow_exc_ser"
BASE_EXC_SER = "base_exc_ser"
DICT = "dict"
SET = "set"
TUPLE = "tuple"
Expand Down
16 changes: 14 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,15 @@ def serialize(
),
type_=DAT.AIRFLOW_EXC_SER,
)
elif isinstance(var, (KeyError, AttributeError)):
return cls._encode(
cls.serialize(
{"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_EXC_SER,
)
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
Expand Down Expand Up @@ -834,13 +843,16 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return decode_timezone(var)
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
elif type_ == DAT.AIRFLOW_EXC_SER:
elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
exc_cls_name = deser["exc_cls_name"]
args = deser["args"]
kwargs = deser["kwargs"]
del deser
exc_cls = import_string(exc_cls_name)
if type_ == DAT.AIRFLOW_EXC_SER:
exc_cls = import_string(exc_cls_name)
else:
exc_cls = import_string(f"builtins.{exc_cls_name}")
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def setup_test_cases(self):
db.clear_db_variables()
crypto._fernet = None

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
def test_variable_no_encryption(self, session):
"""
Expand All @@ -60,6 +61,7 @@ def test_variable_no_encryption(self, session):
# should mask anything. That logic is tested in test_secrets_masker.py
self.mask_secret.assert_called_once_with("value", "key")

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_variable_with_encryption(self, session):
"""
Expand All @@ -70,6 +72,7 @@ def test_variable_with_encryption(self, session):
assert test_var.is_encrypted
assert test_var.val == "value"

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@pytest.mark.parametrize("test_value", ["value", ""])
def test_var_with_encryption_rotate_fernet_key(self, test_value, session):
"""
Expand Down Expand Up @@ -152,6 +155,7 @@ def test_variable_update(self, session):
Variable.update(key="test_key", value="value2", session=session)
assert "value2" == Variable.get("test_key")

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, API server has other ENV
def test_variable_update_fails_on_non_metastore_variable(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
with pytest.raises(AttributeError):
Expand Down Expand Up @@ -281,6 +285,7 @@ def test_caching_caches(self, mock_ensure_secrets: mock.Mock):
mock_backend.get_variable.assert_called_once() # second call was not made because of cache
assert first == second

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other env
def test_cache_invalidation_on_set(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
a = Variable.get("key") # value is saved in cache
Expand Down Expand Up @@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
val=variable_value,
)
session.add(var)
session.flush()
session.commit()
# Make sure we re-load it, not just get the cached object back
session.expunge(var)
_secrets_masker().patterns = set()
Expand All @@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
for expected_masked_value in expected_masked_values:
assert expected_masked_value in _secrets_masker().patterns
finally:
session.rollback()
db.clear_db_variables()