Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from urllib.parse import urlsplit
from uuid import uuid4

from airflow.providers.common.compat.connection import get_async_connection

if TYPE_CHECKING:
from aiobotocore.client import AioBaseClient
from mypy_boto3_s3.service_resource import (
Expand All @@ -52,7 +54,6 @@
from airflow.providers.amazon.version_compat import ArgNotSet


from asgiref.sync import sync_to_async
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -90,7 +91,7 @@ async def maybe_add_bucket_name(*args, **kwargs):
if not bound_args.arguments.get("bucket_name"):
self = args[0]
if self.aws_conn_id:
connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
connection = await get_async_connection(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema
return bound_args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import aiohttp
import requests
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook

Expand Down Expand Up @@ -526,7 +526,7 @@ async def _do_api_call_async(
auth = None

if self.http_conn_id:
conn = await sync_to_async(self.get_connection)(self.http_conn_id)
conn = await get_async_connection(self.http_conn_id)

self.base_url = self._generate_base_url(conn) # type: ignore[arg-type]
if conn.login:
Expand Down
12 changes: 6 additions & 6 deletions providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ async def test_run_method_error(self, mock_do_api_call_async):

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for success response for POST method."""

Expand All @@ -634,7 +634,7 @@ async def mock_fun(arg1, arg2, arg3, arg4):

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for GET method."""

Expand All @@ -659,7 +659,7 @@ async def mock_fun(arg1, arg2, arg3, arg4):

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for PATCH method."""

Expand All @@ -684,7 +684,7 @@ async def mock_fun(arg1, arg2, arg3, arg4):

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for unexpected method error"""
GET_RUN_ENDPOINT = "api/jobs/runs/get"
Expand All @@ -700,7 +700,7 @@ async def test_do_api_call_async_unexpected_method_error(self, mock_get_connecti

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for TypeError."""

Expand All @@ -719,7 +719,7 @@ async def mock_fun(arg1, arg2, arg3, arg4):

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for Client Response Error."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import aiofiles
import requests
from asgiref.sync import sync_to_async
from kubernetes import client, config, utils, watch
from kubernetes.client.models import V1Deployment
from kubernetes.config import ConfigException
Expand All @@ -46,6 +45,7 @@
container_is_completed,
container_is_running,
)
from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook
from airflow.utils import yaml

Expand Down Expand Up @@ -885,7 +885,7 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None):
async def get_conn_extras(self) -> dict:
if self._extras is None:
if self.conn_id:
connection = await sync_to_async(self.get_connection)(self.conn_id)
connection = await get_async_connection(self.conn_id)
self._extras = connection.extra_dejson
else:
self._extras = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from requests.sessions import Session
from tenacity import AsyncRetrying, RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpHook

Expand Down Expand Up @@ -161,7 +162,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
if bound_args.arguments.get("account_id") is None:
self = args[0]
if self.dbt_cloud_conn_id:
connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
connection = await get_async_connection(self.dbt_cloud_conn_id)
default_account_id = connection.login
if not default_account_id:
raise AirflowException("Could not determine the dbt Cloud account.")
Expand Down
4 changes: 2 additions & 2 deletions providers/http/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import aiohttp
import tenacity
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async
from requests import PreparedRequest, Request, Response, Session
from requests.auth import HTTPBasicAuth
from requests.exceptions import ConnectionError, HTTPError
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException

Expand Down Expand Up @@ -461,7 +461,7 @@ async def run(
auth = None

if self.http_conn_id:
conn = await sync_to_async(self.get_connection)(self.http_conn_id)
conn = await get_async_connection(self.http_conn_id)

if conn.host and "://" in conn.host:
self.base_url = conn.host
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from functools import wraps
from typing import IO, TYPE_CHECKING, Any, TypeVar, cast

from asgiref.sync import sync_to_async
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
Expand All @@ -48,6 +47,7 @@
from azure.mgmt.datafactory import DataFactoryManagementClient
from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
Expand Down Expand Up @@ -1089,7 +1089,7 @@ async def bind_argument(arg: Any, default_key: str) -> None:
# Check if arg was not included in the function signature or, if it is, the value is not provided.
if arg not in bound_args.arguments or bound_args.arguments[arg] is None:
self = args[0]
conn = await sync_to_async(self.get_connection)(self.conn_id)
conn = await get_async_connection(self.conn_id)
extras = conn.extra_dejson
default_value = extras.get(default_key) or extras.get(
f"extra__azure_data_factory__{default_key}"
Expand Down Expand Up @@ -1126,7 +1126,7 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
if self._async_conn is not None:
return self._async_conn

conn = await sync_to_async(self.get_connection)(self.conn_id)
conn = await get_async_connection(self.conn_id)
extras = conn.extra_dejson
tenant = get_field(extras, "tenantId")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import os
from typing import TYPE_CHECKING, Any, cast

from asgiref.sync import sync_to_async
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
from azure.identity import ClientSecretCredential
from azure.identity.aio import (
Expand All @@ -44,6 +43,7 @@
ContainerClient as AsyncContainerClient,
)

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
Expand Down Expand Up @@ -620,7 +620,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
self._blob_service_client = cast("AsyncBlobServiceClient", self._blob_service_client)
return self._blob_service_client

conn = await sync_to_async(self.get_connection)(self.conn_id)
conn = await get_async_connection(self.conn_id)
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import aiohttp
import pagerduty
from asgiref.sync import sync_to_async

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.http.hooks.http import HttpAsyncHook

Expand Down Expand Up @@ -285,7 +285,7 @@ async def get_integration_key(self) -> str:
return self.integration_key

if self.pagerduty_events_conn_id is not None:
conn = await sync_to_async(self.get_connection)(self.pagerduty_events_conn_id)
conn = await get_async_connection(self.pagerduty_events_conn_id)
self.integration_key = conn.password
if self.integration_key:
return self.integration_key
Expand Down
4 changes: 2 additions & 2 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
from typing import IO, TYPE_CHECKING, Any, cast

import asyncssh
from asgiref.sync import sync_to_async
from paramiko.config import SSH_PORT

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook, Connection
from airflow.providers.sftp.exceptions import ConnectionNotOpenedException
from airflow.providers.ssh.hooks.ssh import SSHHook
Expand Down Expand Up @@ -756,7 +756,7 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection:
- known_hosts
- passphrase
"""
conn = await sync_to_async(self.get_connection)(self.sftp_conn_id)
conn = await get_async_connection(self.sftp_conn_id)
if conn.extra is not None:
self._parse_extras(conn) # type: ignore[arg-type]

Expand Down
16 changes: 8 additions & 8 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def __init__(self):

class TestSFTPHookAsync:
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_building_known_hosts_none(
self, mock_get_connection, mock_connect, caplog
Expand Down Expand Up @@ -775,7 +775,7 @@ async def test_extra_dejson_fields_for_connection_building_known_hosts_none(
)
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("asyncssh.import_private_key")
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_with_host_key(
self,
Expand All @@ -799,7 +799,7 @@ async def test_extra_dejson_fields_for_connection_with_host_key(
assert hook.known_hosts == f"localhost {mock_host_key}".encode()

@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_raises_valuerror(
self, mock_get_connection, mock_connect
Expand All @@ -820,7 +820,7 @@ async def test_extra_dejson_fields_for_connection_raises_valuerror(
@patch("paramiko.SSHClient.connect")
@patch("asyncssh.import_private_key")
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_no_host_key_check_set_logs_warning(
self, mock_get_connection, mock_connect, mock_import_pkey, mock_ssh_connect, caplog
Expand All @@ -833,7 +833,7 @@ async def test_no_host_key_check_set_logs_warning(
assert "No Host Key Verification. This won't protect against Man-In-The-Middle attacks" in caplog.text

@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
@pytest.mark.asyncio
async def test_extra_dejson_fields_for_connection_building(self, mock_get_connection, mock_connect):
"""
Expand Down Expand Up @@ -861,7 +861,7 @@ async def test_extra_dejson_fields_for_connection_building(self, mock_get_connec
@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("asyncssh.import_private_key")
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_connection_private(self, mock_get_connection, mock_import_private_key, mock_connect):
"""
Assert that connection details with private key passed through the extra field in the Airflow connection
Expand All @@ -888,7 +888,7 @@ async def test_connection_private(self, mock_get_connection, mock_import_private

@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_connection_port_default_to_22(self, mock_get_connection, mock_connect):
from unittest.mock import Mock, call

Expand Down Expand Up @@ -917,7 +917,7 @@ async def test_connection_port_default_to_22(self, mock_get_connection, mock_con

@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
@patch("airflow.providers.sftp.hooks.sftp.get_async_connection")
async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect):
from unittest.mock import Mock, call

Expand Down
4 changes: 2 additions & 2 deletions providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sshtunnel import SSHTunnelForwarder
from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -615,9 +616,8 @@ async def _get_conn(self):
Returns an asyncssh SSHClientConnection that can be used to run commands.
"""
import asyncssh
from asgiref.sync import sync_to_async

conn = await sync_to_async(self.get_connection)(self.ssh_conn_id)
conn = await get_async_connection(self.ssh_conn_id)
if conn.extra is not None:
self._parse_extras(conn)

Expand Down
Loading