Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Team Based Configuration (AIP-67) #45016

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ core:
type: string
example: ~
default: "SequentialExecutor"
multi_team: true
auth_manager:
description: |
The auth manager class that airflow should use. Full import path to the auth manager class.
Expand Down Expand Up @@ -525,6 +526,13 @@ core:
type: integer
example: ~
default: "4096"
multi_team_configurations:
description: |
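A comma delimited list of team configuration file paths and their respective team names, where each item is written as ``path/to/config:team_name`` (path first, then the team name, separated by a colon).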
version_added: 3.1.0
type: string
example: "path/to/team_a/config:team_a,different/path/team_b/configuration:team_b"
default: ""
database:
description: ~
options:
Expand Down
41 changes: 41 additions & 0 deletions airflow/config_templates/team_unit_tests.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# NOTE FOR CONTRIBUTORS:
#
# Note this is an example TEAM based config.
#
# Values specified in this file are automatically used in our unit tests
# run by pytest and override default airflow configuration values provided by config.yml.
#
# These configuration settings should provide a consistent environment to run tests -
# no matter if you are in Breeze env or use local venv or even run tests in the CI environment.
#
# If you want all unit tests to get some default configuration value, you should set it here.
#
# You cannot use ``conf_vars`` context manager to override the configuration for this config at this moment.
#
# The test configuration is loaded via setting AIRFLOW__CORE__UNIT_TEST_MODE=True
# in a pytest fixture in tests/conftest.py. This in turn triggers reloading of the configuration
# from this file after cleaning the respective configuration retrieved during initialization
# of configuration. See ``load_test_config`` function in ``airflow/config.py`` for details.
#

# TODO: update this description to include team

[core]
executor = SequentialExecutor
108 changes: 93 additions & 15 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def __init__(
self.is_validated = False
self._suppress_future_warnings = False
self._providers_configuration_loaded = False
self._team_configs = {} # type: ignore[var-annotated]

def _update_logging_deprecated_template_to_one_from_defaults(self):
default = self.get_default_value("logging", "log_filename_template")
Expand Down Expand Up @@ -790,12 +791,14 @@ def mask_secrets(self):
continue
mask_secret(value)

def _env_var_name(self, section: str, key: str) -> str:
return f"{ENV_VAR_PREFIX}{section.replace('.', '_').upper()}__{key.upper()}"
def _env_var_name(self, section: str, key: str, team_id: str | None = None) -> str:
    """
    Build the name of the environment variable that can hold a configuration option.

    The name is ``ENV_VAR_PREFIX`` followed by the upper-cased components joined with double
    underscores: ``{SECTION}__{KEY}``, or ``{TEAM_ID}__{SECTION}__{KEY}`` when a truthy
    ``team_id`` is given. Dots in the section name are replaced with single underscores.

    :param section: configuration section name
    :param key: option name within the section
    :param team_id: optional team identifier; omitted from the name when falsy
    :return: the environment variable name
    """
    components = [section.replace(".", "_").upper(), key.upper()]
    if team_id:
        # The team component always comes first, directly after the prefix.
        components.insert(0, team_id.upper())
    return f"{ENV_VAR_PREFIX}{'__'.join(components)}"

def _get_env_var_option(self, section: str, key: str):
# must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
env_var = self._env_var_name(section, key)
def _get_env_var_option(self, section: str, key: str, team_id: str | None = None):
# must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) OR for team based
# configuration must have the format AIRFLOW__{TEAM_ID}__{SECTION}__{KEY}
env_var = self._env_var_name(section, key, team_id=team_id)
if env_var in os.environ:
return expand_env_var(os.environ[env_var])
# alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
Expand Down Expand Up @@ -885,6 +888,7 @@ def get( # type: ignore[override,misc]
key: str,
suppress_warnings: bool = False,
_extra_stacklevel: int = 0,
team_id: str | None = None,
**kwargs,
) -> str | None:
section = section.lower()
Expand Down Expand Up @@ -942,6 +946,7 @@ def get( # type: ignore[override,misc]
deprecated_section,
key,
section,
team_id=team_id,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
)
Expand All @@ -955,6 +960,7 @@ def get( # type: ignore[override,misc]
key,
kwargs,
section,
team_id=team_id,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
)
Expand All @@ -967,6 +973,7 @@ def get( # type: ignore[override,misc]
deprecated_section,
key,
section,
team_id=team_id,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
)
Expand All @@ -979,6 +986,7 @@ def get( # type: ignore[override,misc]
deprecated_section,
key,
section,
team_id=team_id,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
)
Expand All @@ -1004,9 +1012,20 @@ def _get_option_from_secrets(
deprecated_section: str | None,
key: str,
section: str,
team_id: str | None = None,
issue_warning: bool = True,
extra_stacklevel: int = 0,
) -> str | None:
if team_id:
# The number of team configs that are going to be needed will be small, so supporting obscure
# ways to pass configuration just complicates things unnecessarily. This can always be added in
# the future if there is enough user demand.
log.debug(
"Secrets are not supported for team configs. "
"Please use environment variables or the team configuration file instead."
)
return None

option = self._get_secret_option(section, key)
if option:
return option
Expand All @@ -1025,9 +1044,20 @@ def _get_option_from_commands(
deprecated_section: str | None,
key: str,
section: str,
team_id: str | None = None,
issue_warning: bool = True,
extra_stacklevel: int = 0,
) -> str | None:
if team_id:
# The number of team configs that are going to be needed will be small, so supporting obscure
# ways to pass configuration just complicates things unnecessarily. This can always be added in
# the future if there is enough user demand.
log.debug(
"Commands are not supported for team configs. "
"Please use environment variables or the team configuration file instead."
)
return None

option = self._get_cmd_option(section, key)
if option:
return option
Expand All @@ -1047,19 +1077,23 @@ def _get_option_from_config_file(
key: str,
kwargs: dict[str, Any],
section: str,
team_id: str | None = None,
issue_warning: bool = True,
extra_stacklevel: int = 0,
) -> str | None:
if super().has_option(section, key):
# Get a specific team config if team_id is set otherwise use the standard conf parser
config_parser = self._team_configs.get(team_id, super())

if config_parser.has_option(section, key):
# Use the parent's methods to get the actual config here to be able to
# separate the config from default config.
return expand_env_var(super().get(section, key, **kwargs))
return expand_env_var(config_parser.get(section, key, **kwargs))
if deprecated_section and deprecated_key:
if super().has_option(deprecated_section, deprecated_key):
if config_parser.has_option(deprecated_section, deprecated_key):
if issue_warning:
self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
with self.suppress_future_warnings():
return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs))
return expand_env_var(config_parser.get(deprecated_section, deprecated_key, **kwargs))
return None

def _get_environment_variables(
Expand All @@ -1068,15 +1102,16 @@ def _get_environment_variables(
deprecated_section: str | None,
key: str,
section: str,
team_id: str | None = None,
issue_warning: bool = True,
extra_stacklevel: int = 0,
) -> str | None:
option = self._get_env_var_option(section, key)
option = self._get_env_var_option(section, key, team_id=team_id)
if option is not None:
return option
if deprecated_section and deprecated_key:
with self.suppress_future_warnings():
option = self._get_env_var_option(deprecated_section, deprecated_key)
option = self._get_env_var_option(deprecated_section, deprecated_key, team_id=team_id)
if option is not None:
if issue_warning:
self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
Expand Down Expand Up @@ -1240,7 +1275,7 @@ def read_dict( # type: ignore[override]
"""
super().read_dict(dictionary=dictionary, source=source)

def has_option(self, section: str, option: str) -> bool:
def has_option(self, section: str, option: str, team_id=None) -> bool:
"""
Check if option is defined.

Expand All @@ -1249,10 +1284,13 @@ def has_option(self, section: str, option: str) -> bool:

:param section: section to get option from
:param option: option to get
:param team_id: team_id to get option from
:return:
"""
try:
value = self.get(section, option, fallback=None, _extra_stacklevel=1, suppress_warnings=True)
value = self.get(
section, option, fallback=None, _extra_stacklevel=1, suppress_warnings=True, team_id=team_id
)
if value is None:
return False
return True
Expand Down Expand Up @@ -1763,6 +1801,29 @@ def providers_configuration_loaded(self) -> bool:
"""Checks if providers have been loaded."""
return self._providers_configuration_loaded

def setup_team_configs(self, from_dict: dict | None = None):
    """
    Load the team configurations if any are specified.

    The ``[core] multi_team_configurations`` option is a comma delimited list of items where each
    item is a path to a team configuration file and a team id, separated by a colon
    (``path/to/config:team_id``). Whitespace around items, paths and team ids is ignored, and
    empty items (e.g. from a trailing comma) are skipped.

    The item is split on its *last* colon, so paths that themselves contain a colon (for example
    a Windows drive letter) are handled.

    :param from_dict: if truthy, every team parser is populated from this dictionary instead of
        from the file named in its item
    :raises ValueError: if an item has no colon separator or has an empty team id
    """
    raw_value = self.get("core", "multi_team_configurations")
    if not raw_value:
        return

    # Validate every item up front so that a malformed value fails before any team is registered.
    parsed_items: list[tuple[str, str]] = []
    for raw_item in raw_value.split(","):
        item = raw_item.strip()
        if not item:
            continue
        team_config_path, separator, team_id = item.rpartition(":")
        team_config_path = team_config_path.strip()
        team_id = team_id.strip()
        if not separator or not team_id:
            raise ValueError(
                f"Invalid multi_team_configurations item {item!r}: "
                "expected the form 'path/to/config:team_id'."
            )
        parsed_items.append((team_config_path, team_id))

    for team_config_path, team_id in parsed_items:
        # Create a parser for each team, teams will implement the same configurations so they get their
        # own parsers such that they do not overwrite each other
        team_parser = create_default_config_parser(self.configuration_description, multi_team=True)
        if from_dict:
            team_parser.read_dict(from_dict)
        elif not team_parser.read(team_config_path):
            # ConfigParser.read() skips files it cannot open and returns the list of files it did
            # read, so an empty result means the team only has the default values.
            log.warning(
                "Could not read configuration file %r for team %r; only default values will be used.",
                team_config_path,
                team_id,
            )

        self._team_configs[team_id] = team_parser

def load_providers_configuration(self):
"""
Load configuration for providers.
Expand Down Expand Up @@ -1874,9 +1935,11 @@ def _generate_fernet_key() -> str:
return Fernet.generate_key().decode()


def create_default_config_parser(configuration_description: dict[str, dict[str, Any]]) -> ConfigParser:
def create_default_config_parser(
configuration_description: dict[str, dict[str, Any]], multi_team: bool = False
) -> ConfigParser:
"""
Create default config parser based on configuration description.
Create default Airflow or multi-team config parser based on configuration description.

It creates ConfigParser with all default values retrieved from the configuration description and
expands all the variables from the global and local variables defined in this module.
Expand All @@ -1888,9 +1951,14 @@ def create_default_config_parser(configuration_description: dict[str, dict[str,
parser = ConfigParser()
all_vars = get_all_expansion_variables()
for section, section_desc in configuration_description.items():
# TODO: Should we avoid adding sections that don't have any team specific configs? Will that cause
# any issues?
parser.add_section(section)
options = section_desc["options"]
for key in options:
if multi_team and not options[key].get("multi_team", False):
# We are building a multi-team config and this option is not multi-team, skip it.
continue
default_value = options[key]["default"]
is_template = options[key].get("is_template", False)
if default_value is not None:
Expand Down Expand Up @@ -2025,6 +2093,16 @@ def initialize_config() -> AirflowConfigParser:
# Set the WEBSERVER_CONFIG variable
global WEBSERVER_CONFIG
WEBSERVER_CONFIG = airflow_config_parser.get("webserver", "config_file")
# Set up the configs for any teams
if airflow_config_parser.getboolean("core", "unit_test_mode"):
# Set the path to a test team config file
team_unit_test_config_file = (
pathlib.Path(__file__).parent / "config_templates" / "team_unit_tests.cfg"
)
multi_team_config = f"{team_unit_test_config_file}:unit_test_team"
airflow_config_parser.set("core", "multi_team_configurations", multi_team_config)
airflow_config_parser.setup_team_configs()

return airflow_config_parser


Expand Down
20 changes: 20 additions & 0 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ def can_try_again(self):
return True


class ExecutorConf:
    """
    Team-scoped accessor for Airflow configuration, intended for use by executors.

    An instance remembers a ``team_id`` and forwards ``get`` / ``getboolean`` lookups to ``conf``,
    adding that ``team_id`` as a keyword argument on every call. Child classes (i.e. concrete
    executors) can therefore read their configuration through this object without having to pass
    the ``team_id`` through themselves.
    """

    def __init__(self, team_id: str | None = None):
        # ``None`` is forwarded unchanged as the ``team_id`` of every lookup.
        self.team_id = team_id

    def _forward(self, accessor_name: str, args: tuple, kwargs: dict):
        """Call the named ``conf`` accessor with the caller's arguments plus this team's id."""
        accessor = getattr(conf, accessor_name)
        return accessor(*args, **kwargs, team_id=self.team_id)

    def get(self, *args, **kwargs):
        """Return ``conf.get(...)`` for the given arguments, scoped to ``self.team_id``."""
        return self._forward("get", args, kwargs)

    def getboolean(self, *args, **kwargs):
        """Return ``conf.getboolean(...)`` for the given arguments, scoped to ``self.team_id``."""
        return self._forward("getboolean", args, kwargs)


class BaseExecutor(LoggingMixin):
"""
Base class to inherit for concrete executors such as Celery, Kubernetes, Local, Sequential, etc.
Expand All @@ -132,6 +151,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_id: str | None = None):
super().__init__()
self.parallelism: int = parallelism
self.team_id: str | None = team_id
self.conf = ExecutorConf(team_id)
self.queued_tasks: dict[TaskInstanceKey, QueuedTaskInstanceType] = {}
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
Expand Down
Loading