Skip to content

Commit

Permalink
Remove global variable with API auth backend (#9833)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 7b23d4d01f35b693848fb3b2c482db66d91b658b
  • Loading branch information
mik-laj authored and Cloud Composer Team committed Sep 12, 2024
1 parent 1a1fd24 commit a6c37de
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 33 deletions.
11 changes: 1 addition & 10 deletions airflow/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@
log = logging.getLogger(__name__)


class ApiAuth: # pylint: disable=too-few-public-methods
"""Class to keep module of Authentication API """
def __init__(self):
self.api_auth = None


API_AUTH = ApiAuth()


def load_auth():
"""Loads authentication backend"""
auth_backend = 'airflow.api.auth.backend.default'
Expand All @@ -43,7 +34,7 @@ def load_auth():
pass

try:
API_AUTH.api_auth = import_module(auth_backend)
return import_module(auth_backend)
except ImportError as err:
log.critical(
"Cannot import %s for API authentication due to: %s",
Expand Down
5 changes: 1 addition & 4 deletions airflow/api/auth/backend/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
"""Default authentication backend - everything is allowed"""
from functools import wraps
from typing import Callable, Optional, TypeVar, cast
from typing import Callable, TypeVar, cast

from airflow.typing_compat import Protocol

Expand All @@ -33,9 +33,6 @@ def handle_response(self, _):
...


CLIENT_AUTH = None # type: Optional[ClientAuthProtocol]


def init_app(_):
"""Initializes authentication backend"""

Expand Down
2 changes: 1 addition & 1 deletion airflow/api/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def get_current_api_client() -> Client:
api_module = import_module(conf.get('cli', 'api_client')) # type: Any
api_client = api_module.Client(
api_base_url=conf.get('cli', 'endpoint_url'),
auth=api.API_AUTH.api_auth.CLIENT_AUTH
auth=api.load_auth()
)
return api_client
3 changes: 0 additions & 3 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from sqlalchemy.orm.session import Session as SASession
from sqlalchemy.pool import NullPool

# noinspection PyUnresolvedReferences
from airflow import api
# pylint: disable=unused-import
from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf # NOQA F401
from airflow.logging_config import configure_logging
Expand Down Expand Up @@ -330,7 +328,6 @@ def initialize():
# The webservers import this file from models.py with the default settings.
configure_orm()
configure_action_logging()
api.load_auth()

# Ensure we close DB connections at scheduler and gunicon worker terminations
atexit.register(dispose_orm)
Expand Down
17 changes: 14 additions & 3 deletions airflow/www/api/experimental/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import wraps
from typing import Callable, TypeVar, cast

from flask import Blueprint, g, jsonify, request, url_for
from flask import Blueprint, current_app, g, jsonify, request, url_for

import airflow.api
from airflow import models
from airflow.api.common.experimental import delete_dag as delete, pool as pool_api, trigger_dag as trigger
from airflow.api.common.experimental.get_code import get_code
Expand All @@ -35,7 +36,17 @@

log = logging.getLogger(__name__)

requires_authentication = airflow.api.API_AUTH.api_auth.requires_authentication
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
@wraps(function)
def decorated(*args, **kwargs):
return current_app.api_auth.requires_authentication(function)(*args, **kwargs)

return cast(T, decorated)


api_experimental = Blueprint('api_experimental', __name__)

Expand Down
26 changes: 21 additions & 5 deletions airflow/www/extensions/init_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from importlib import import_module

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException

log = logging.getLogger(__name__)


def init_xframe_protection(app):
Expand All @@ -37,8 +42,19 @@ def apply_caching(response):


def init_api_experimental_auth(app):
"""Initialize authorization in Experimental API"""
from airflow import api

api.load_auth()
api.API_AUTH.api_auth.init_app(app)
"""Loads authentication backend"""
auth_backend = 'airflow.api.auth.backend.default'
try:
auth_backend = conf.get("api", "auth_backend")
except AirflowConfigException:
pass

try:
app.api_auth = import_module(auth_backend)
app.api_auth.init_app(app)
except ImportError as err:
log.critical(
"Cannot import %s for API authentication due to: %s",
auth_backend, err
)
raise AirflowException(err)
7 changes: 0 additions & 7 deletions airflow/www/extensions/init_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,5 @@ def init_api_experimental(app):
"""Initialize Experimental API"""
from airflow.www.api.experimental import endpoints

# required for testing purposes otherwise the module retains
# a link to the default_auth
if app.config['TESTING']:
import importlib

importlib.reload(endpoints)

app.register_blueprint(endpoints.api_experimental, url_prefix='/api/experimental')
app.extensions['csrf'].exempt(endpoints.api_experimental)

0 comments on commit a6c37de

Please sign in to comment.