Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
# Time to sleep between active checks of the operation results
TIME_TO_SLEEP_IN_SECONDS = 20

CLOUD_SQL_PROXY_VERSION_REGEX = re.compile(r"^v?(\d+\.\d+\.\d+)(-\w*.?\d?)?$")


class CloudSqlOperationStatus:
"""Helper class with operation statuses."""
Expand Down Expand Up @@ -449,16 +451,7 @@ def _download_sql_proxy_if_needed(self) -> None:
if os.path.isfile(self.sql_proxy_path):
self.log.info("cloud-sql-proxy is already present")
return
system = platform.system().lower()
processor = os.uname().machine
if processor == "x86_64":
processor = "amd64"
if not self.sql_proxy_version:
download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)
else:
download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
self.sql_proxy_version, system, processor
)
download_url = self._get_sql_proxy_download_url()
proxy_path_tmp = self.sql_proxy_path + ".tmp"
self.log.info("Downloading cloud_sql_proxy from %s to %s", download_url, proxy_path_tmp)
# httpx has a breaking API change (follow_redirects vs allow_redirects)
Expand All @@ -482,6 +475,24 @@ def _download_sql_proxy_if_needed(self) -> None:
os.chmod(self.sql_proxy_path, 0o744) # Set executable bit
self.sql_proxy_was_downloaded = True

def _get_sql_proxy_download_url(self):
system = platform.system().lower()
processor = os.uname().machine
if processor == "x86_64":
processor = "amd64"
if not self.sql_proxy_version:
download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)
else:
if not CLOUD_SQL_PROXY_VERSION_REGEX.match(self.sql_proxy_version):
raise ValueError(
"The sql_proxy_version should match the regular expression "
f"{CLOUD_SQL_PROXY_VERSION_REGEX.pattern}"
)
download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
self.sql_proxy_version, system, processor
)
return download_url

def _get_credential_parameters(self) -> list[str]:
extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson
key_path = get_field(extras, "key_path")
Expand Down
75 changes: 70 additions & 5 deletions tests/providers/google/cloud/hooks/test_cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from __future__ import annotations

import json
import os
import platform
import tempfile
from unittest import mock
from unittest.mock import PropertyMock

Expand All @@ -27,7 +30,11 @@

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.hooks.cloud_sql import (
CloudSQLDatabaseHook,
CloudSQLHook,
CloudSqlProxyRunner,
)
from tests.providers.google.cloud.utils.base_gcp_mock import (
mock_base_gcp_hook_default_project_id,
mock_base_gcp_hook_no_default_project_id,
Expand Down Expand Up @@ -847,8 +854,12 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
err = ctx.value
assert "must be a readable file" in str(err)

@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_connection):
def test_cloudsql_database_hook_validate_socket_path_length_too_long(
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
connection = Connection()
connection.set_extra(
json.dumps(
Expand All @@ -870,8 +881,12 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
err = ctx.value
assert "The UNIX socket path length cannot exceed" in str(err)

@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(self, get_connection):
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
connection = Connection()
connection.set_extra(
json.dumps(
Expand Down Expand Up @@ -1093,7 +1108,7 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert "postgres" == connection.conn_type
assert "/tmp" in connection.host
assert tempfile.gettempdir() in connection.host
assert "example-project:europe-west1:testdb" in connection.host
assert connection.port is None
assert "testdb" == connection.schema
Expand Down Expand Up @@ -1166,7 +1181,7 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
connection = hook.create_connection()
assert "mysql" == connection.conn_type
assert "localhost" == connection.host
assert "/tmp" in connection.extra_dejson["unix_socket"]
assert tempfile.gettempdir() in connection.extra_dejson["unix_socket"]
assert "example-project:europe-west1:testdb" in connection.extra_dejson["unix_socket"]
assert connection.port is None
assert "testdb" == connection.schema
Expand All @@ -1185,3 +1200,53 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
assert "127.0.0.1" == connection.host
assert 3200 != connection.port
assert "testdb" == connection.schema


def get_processor():
processor = os.uname().machine
if processor == "x86_64":
processor = "amd64"
return processor


class TestCloudSqlProxyRunner:
@pytest.mark.parametrize(
["version", "download_url"],
[
(
"v1.23.0",
"https://storage.googleapis.com/cloudsql-proxy/v1.23.0/cloud_sql_proxy."
f"{platform.system().lower()}.{get_processor()}",
),
(
"v1.23.0-preview.1",
"https://storage.googleapis.com/cloudsql-proxy/v1.23.0-preview.1/cloud_sql_proxy."
f"{platform.system().lower()}.{get_processor()}",
),
],
)
def test_cloud_sql_proxy_runner_version_ok(self, version, download_url):
runner = CloudSqlProxyRunner(
path_prefix="12345678",
instance_specification="project:us-east-1:instance",
sql_proxy_version=version,
)
assert runner._get_sql_proxy_download_url() == download_url

@pytest.mark.parametrize(
"version",
[
"v1.23.",
"v1.23.0..",
"v1.23.0\\",
"\\",
],
)
def test_cloud_sql_proxy_runner_version_nok(self, version):
runner = CloudSqlProxyRunner(
path_prefix="12345678",
instance_specification="project:us-east-1:instance",
sql_proxy_version=version,
)
with pytest.raises(ValueError, match="The sql_proxy_version should match the regular expression"):
runner._get_sql_proxy_download_url()