Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions airflow/providers/mysql/hooks/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Union

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mysql.connector.abstracts import MySQLConnectionAbstract
try:
from mysql.connector.abstracts import MySQLConnectionAbstract
except ModuleNotFoundError:
logger.warning("The package 'mysql-connector-python' is not installed. Import skipped")
from MySQLdb.connections import Connection as MySQLdbConnection

MySQLConnectionTypes = Union["MySQLdbConnection", "MySQLConnectionAbstract"]
Expand Down Expand Up @@ -181,7 +188,14 @@ def get_conn(self) -> MySQLConnectionTypes:
return MySQLdb.connect(**conn_config)

if client_name == "mysql-connector-python":
import mysql.connector
try:
import mysql.connector
except ModuleNotFoundError:
raise AirflowOptionalProviderFeatureException(
"The pip package 'mysql-connector-python' is not installed, therefore the connection "
"wasn't established. Please, consider using default driver or pip install the package "
"'mysql-connector-python'. Warning! It might cause dependency conflicts."
)

conn_config = self._get_conn_config_mysql_connector_python(conn)
return mysql.connector.connect(**conn_config)
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/mysql/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ versions:
dependencies:
- apache-airflow>=2.3.0
- apache-airflow-providers-common-sql>=1.3.1
- mysql-connector-python>=8.0.11
- mysqlclient>=1.3.6

integrations:
Expand Down Expand Up @@ -87,3 +86,8 @@ transfers:
connection-types:
- hook-class-name: airflow.providers.mysql.hooks.mysql.MySqlHook
connection-type: mysql

additional-extras:
- name: mysql-connector-python
dependencies:
- mysql-connector-python>=8.0.11
4 changes: 3 additions & 1 deletion docker_tests/test_prod_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import subprocess
import tempfile
from importlib.util import find_spec
from pathlib import Path

import pytest
Expand Down Expand Up @@ -161,7 +162,6 @@ def test_pip_dependencies_conflict(self):
"grpc": ["grpc", "google.auth", "google_auth_httplib2"],
"hashicorp": ["hvac"],
"ldap": ["ldap"],
"mysql": ["mysql"],
"postgres": ["psycopg2"],
"pyodbc": ["pyodbc"],
"redis": ["redis"],
Expand All @@ -171,6 +171,8 @@ def test_pip_dependencies_conflict(self):
"statsd": ["statsd"],
"virtualenv": ["virtualenv"],
}
if bool(find_spec("mysql")):
PACKAGE_IMPORTS["mysql"] = ["mysql"]

@pytest.mark.skipif(os.environ.get("TEST_SLIM_IMAGE") == "true", reason="Skipped with slim image")
@pytest.mark.parametrize("package_name,import_names", PACKAGE_IMPORTS.items())
Expand Down
8 changes: 7 additions & 1 deletion docs/apache-airflow/howto/set-up-database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ We recommend using the ``mysqlclient`` driver and specifying it in your SqlAlche
mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>

We also support the ``mysql-connector-python`` driver, which lets you connect through SSL
without any cert options provided.
without any cert options provided. If you wish to use ``mysql-connector-python`` driver, please install it with extras.

.. code-block:: text

$ pip install mysql-connector-python

The connection string in this case should look like:

.. code-block:: text

Expand Down
1 change: 0 additions & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@
"deps": [
"apache-airflow-providers-common-sql>=1.3.1",
"apache-airflow>=2.3.0",
"mysql-connector-python>=8.0.11",
"mysqlclient>=1.3.6"
],
"cross-providers-deps": [
Expand Down
4 changes: 2 additions & 2 deletions scripts/in_container/verify_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ class ProviderPackageDetails(NamedTuple):
"Implementing implicit namespace packages (as specified in PEP 420) is "
"preferred to `pkg_resources.declare_namespace`",
"This module is deprecated. Please use `airflow.providers.cncf.kubernetes.operators.pod` instead.",
"urllib3 (1.26.6) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!",
"urllib3 (1.26.9) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!",
"This operator is deprecated. Please use `GoogleDisplayVideo360CreateQueryOperator`",
"This operator is deprecated. Please use `GoogleDisplayVideo360RunQueryOperator`",
"This operator is deprecated. Please use `GoogleDisplayVideo360RunQuerySensor`",
"This operator is deprecated. Please use `GoogleDisplayVideo360DownloadReportV2Operator`",
"urllib3 (1.26.6) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!",
"urllib3 (1.26.9) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!",
}

# The set of warning messages generated by direct importing of some deprecated modules. We should only
Expand Down
57 changes: 0 additions & 57 deletions tests/providers/mysql/hooks/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,63 +182,6 @@ def test_get_conn_rds_iam(self, mock_client, mock_connect):
)


class TestMySqlHookConnMySqlConnectorPython:
def setup_method(self):
self.connection = Connection(
login="login",
password="password",
host="host",
schema="schema",
extra='{"client": "mysql-connector-python"}',
)

self.db_hook = MySqlHook()
self.db_hook.get_connection = mock.Mock()
self.db_hook.get_connection.return_value = self.connection

@mock.patch("mysql.connector.connect")
def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"

@mock.patch("mysql.connector.connect")
def test_get_conn_port(self, mock_connect):
self.connection.port = 3307
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["port"] == 3307

@mock.patch("mysql.connector.connect")
def test_get_conn_allow_local_infile(self, mock_connect):
extra_dict = self.connection.extra_dejson
self.connection.extra = json.dumps(extra_dict)
self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["allow_local_infile"] == 1

@mock.patch("mysql.connector.connect")
def test_get_ssl_mode(self, mock_connect):
extra_dict = self.connection.extra_dejson
extra_dict.update(ssl_disabled=True)
self.connection.extra = json.dumps(extra_dict)
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["ssl_disabled"] == 1


class MockMySQLConnectorConnection:
DEFAULT_AUTOCOMMIT = "default"

Expand Down
86 changes: 86 additions & 0 deletions tests/providers/mysql/hooks/test_mysql_connector_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from unittest import mock

import pytest

from airflow.models import Connection
from airflow.providers.mysql.hooks.mysql import MySqlHook

# Make sure that the optional package 'mysql-connector-python' is installed (which is not by default)
pytest.importorskip("mysql")


class TestMySqlHookConnMySqlConnectorPython:
def setup_method(self):
self.connection = Connection(
login="login",
password="password",
host="host",
schema="schema",
extra='{"client": "mysql-connector-python"}',
)

self.db_hook = MySqlHook()
self.db_hook.get_connection = mock.Mock()
self.db_hook.get_connection.return_value = self.connection

@mock.patch("mysql.connector.connect")
def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["user"] == "login"
assert kwargs["password"] == "password"
assert kwargs["host"] == "host"
assert kwargs["database"] == "schema"

@mock.patch("mysql.connector.connect")
def test_get_conn_port(self, mock_connect):
self.connection.port = 3307
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["port"] == 3307

@mock.patch("mysql.connector.connect")
def test_get_conn_allow_local_infile(self, mock_connect):
extra_dict = self.connection.extra_dejson
self.connection.extra = json.dumps(extra_dict)
self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["allow_local_infile"] == 1

@mock.patch("mysql.connector.connect")
def test_get_ssl_mode(self, mock_connect):
extra_dict = self.connection.extra_dejson
extra_dict.update(ssl_disabled=True)
self.connection.extra = json.dumps(extra_dict)
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
assert args == ()
assert kwargs["ssl_disabled"] == 1