Skip to content

Commit

Permalink
Add implicit_tls connect arg to support non-standard implicit TLS c…
Browse files Browse the repository at this point in the history
…onnections, such as Google Cloud SQL

fixes aio-libs#757
  • Loading branch information
Nothing4You committed Aug 27, 2022
1 parent ab13f94 commit 173eb36
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
codecov:
notify:
after_n_builds: 40
after_n_builds: 6
63 changes: 48 additions & 15 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,18 @@ jobs:
- ubuntu-latest
py:
- '3.7'
- '3.8'
- '3.9'
- '3.10'
# - '3.8'
# - '3.9'
# - '3.10'
- '3.11-dev'
db:
- [mysql, '5.7']
- [mysql, '8.0']
- [mariadb, '10.3']
- [mariadb, '10.4']
- [mariadb, '10.5']
- [mariadb, '10.6']
- [mariadb, '10.7']
# - [mariadb, '10.3']
# - [mariadb, '10.4']
# - [mariadb, '10.5']
# - [mariadb, '10.6']
# - [mariadb, '10.7']
- [mariadb, '10.8']

fail-fast: false
Expand Down Expand Up @@ -449,6 +449,13 @@ jobs:
options: '--name=mysqld'
env:
MYSQL_ROOT_PASSWORD: rootpw
haproxy:
image: haproxytech/haproxy-alpine:2.6
ports:
- 13306:13306
volumes:
- "/tmp/run-${{ join(matrix.db, '-') }}/:/var/lib/haproxy/socket-mount/"
options: '--name=haproxy'

steps:
- name: Setup Python ${{ matrix.py }}
Expand Down Expand Up @@ -569,6 +576,18 @@ jobs:
# unfortunately we need this hacky workaround as GitHub Actions service containers can't reference data from our repo.
- name: Prepare mysql
run: |
# we need to ensure that the socket path is readable from haproxy and
# writable for the user running the DB process
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
# inject HAproxy configuration
docker container stop haproxy
docker container cp "${{ github.workspace }}/tests/ssl_resources/haproxy.cfg" haproxy:/usr/local/etc/haproxy/haproxy.cfg
docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl/server-combined.pem" haproxy:/usr/local/etc/haproxy/haproxy.pem
docker container start haproxy
# ensure server is started up
while :
do
Expand All @@ -582,9 +601,6 @@ jobs:
docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf
# use custom socket path
# we need to ensure that the socket path is writable for the user running the DB process in the container
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf
docker container start mysqld
Expand All @@ -598,11 +614,28 @@ jobs:
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "SET GLOBAL local_infile=on"
# This should get removed before merging
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "select user()"
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "select host, user, hex(authentication_string) from mysql.user"
- name: Run tests
run: |
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
timeout --preserve-status --signal=INT --verbose 570s \
pytest --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql --cov tests ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
run: >-
timeout
--preserve-status
--signal=INT
--verbose 570s
pytest
--capture=no
--verbosity 2
--cov-report term
--cov-report xml
--cov aiomysql
--cov tests
./tests
--mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock"
--mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
--mysql-address-tls "tls-${{ join(matrix.db, '') }}=127.0.0.1:13306"
env:
PYTHONUNBUFFERED: 1
timeout-minutes: 10
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ next (unreleased)
| aiomysql now reraises the original exception during connect() if it's not `IOError`, `OSError` or `asyncio.TimeoutError`.
| This was previously always raised as `OperationalError`.

* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757

0.1.1 (2022-05-08)
^^^^^^^^^^^^^^^^^^

Expand Down
32 changes: 24 additions & 8 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
import configparser
import getpass
import ssl as ssllib
from functools import partial

from pymysql.charset import charset_by_name, charset_by_id
Expand Down Expand Up @@ -53,7 +54,7 @@ def connect(host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand All @@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
read_default_group=read_default_group,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
auth_plugin=auth_plugin, program_name=program_name,
implicit_tls=implicit_tls)
return _ConnectionContextManager(coro)


Expand Down Expand Up @@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -184,6 +186,9 @@ def __init__(self, host="localhost", user=None, password="",
handshaking with MySQL. (omitted by default)
:param server_public_key: SHA256 authentication plugin public
key value.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)
:param loop: asyncio loop
"""
self._loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -218,6 +223,7 @@ def __init__(self, host="localhost", user=None, password="",
self._auth_plugin_used = ""
self._secure = False
self.server_public_key = server_public_key
self._implicit_tls = implicit_tls
self.salt = None

from . import __version__
Expand All @@ -241,7 +247,10 @@ def __init__(self, host="localhost", user=None, password="",
self.use_unicode = use_unicode

self._ssl_context = ssl
if ssl:
# TLS is required when implicit_tls is True
if implicit_tls and not self._ssl_context:
self._ssl_context = ssllib.create_default_context()
if ssl and not implicit_tls:
client_flag |= CLIENT.SSL

self._encoding = charset_by_name(self._charset).encoding
Expand Down Expand Up @@ -536,7 +545,8 @@ async def _connect(self):

self._next_seq_id = 0

await self._get_server_information()
if not self._implicit_tls:
await self._get_server_information()
await self._request_authentication()

self.connected_time = self._loop.time()
Expand Down Expand Up @@ -738,7 +748,8 @@ async def _execute_command(self, command, sql):

async def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
if int(self.server_version.split('.', 1)[0]) >= 5:
# FIXME: change this before merge
if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= CLIENT.MULTI_RESULTS

if self.user is None:
Expand All @@ -748,8 +759,10 @@ async def _request_authentication(self):
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
charset_id, b'')

if self._ssl_context and self.server_capabilities & CLIENT.SSL:
self.write_packet(data_init)
if self._ssl_context and \
(self._implicit_tls or self.server_capabilities & CLIENT.SSL):
if not self._implicit_tls:
self.write_packet(data_init)

# Stop sending events to data_received
self._writer.transport.pause_reading()
Expand All @@ -771,6 +784,9 @@ async def _request_authentication(self):
server_hostname=self._host
)

if self._implicit_tls:
await self._get_server_information()

self._secure = True

if isinstance(self.user, str):
Expand Down
7 changes: 6 additions & 1 deletion docs/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Example::
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False
ssl=None, auth_plugin='', program_name='',
server_public_key=None, loop=None)
server_public_key=None, loop=None, implicit_tls=False)

A :ref:`coroutine <coroutine>` that connects to MySQL.

Expand Down Expand Up @@ -93,6 +93,11 @@ Example::
``sys.argv[0]`` is no longer passed by default
:param server_public_key: SHA256 authenticaiton plugin public key value.
:param loop: asyncio event loop instance or ``None`` for default one.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)

.. versionadded:: 0.2
:returns: :class:`Connection` instance.


Expand Down
63 changes: 54 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
import os
import re
import socket
import ssl
import sys

Expand Down Expand Up @@ -63,13 +64,26 @@ def pytest_generate_tests(metafunc):

if ":" in addr:
addr = addr.split(":", 1)
mysql_addresses.append((addr[0], int(addr[1])))
mysql_addresses.append((addr[0], int(addr[1]), False))
else:
mysql_addresses.append((addr, 3306))
mysql_addresses.append((addr, 3306, False))

opt_mysql_address_tls =\
list(metafunc.config.getoption("mysql_address_tls"))
for i in range(len(opt_mysql_address_tls)):
if "=" in opt_mysql_address_tls[i]:
label, addr = opt_mysql_address_tls[i].split("=", 1)
ids.append(label)
else:
addr = opt_mysql_address_tls[i]
ids.append("tls{}".format(i))

addr = addr.split(":", 1)
mysql_addresses.append((addr[0], int(addr[1]), True))

# default to connecting to localhost
if len(mysql_addresses) == 0:
mysql_addresses = [("127.0.0.1", 3306)]
mysql_addresses = [("127.0.0.1", 3306, False)]
ids = ["tcp-local"]

assert len(mysql_addresses) == len(set(mysql_addresses)), \
Expand Down Expand Up @@ -153,6 +167,12 @@ def pytest_addoption(parser):
default=[],
help="list of addresses to connect to: [name=]host[:port]",
)
parser.addoption(
"--mysql-address-tls",
action="append",
default=[],
help="list of addresses to connect to using implicit TLS: [name=]host:port",
)
parser.addoption(
"--mysql-unix-socket",
action="append",
Expand Down Expand Up @@ -249,6 +269,7 @@ def _register_table(table_name):
@pytest.fixture(scope='session')
def mysql_server(mysql_address):
unix_socket = type(mysql_address) is str
implicit_tls = not unix_socket and mysql_address[2]

if not unix_socket:
ssl_directory = os.path.join(os.path.dirname(__file__),
Expand All @@ -270,14 +291,34 @@ def mysql_server(mysql_address):
else:
server_params["host"] = mysql_address[0]
server_params["port"] = mysql_address[1]

if not unix_socket and not implicit_tls:
server_params["ssl"] = ctx

try:
connection = pymysql.connect(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params)
if implicit_tls:
sock = ctx.wrap_socket(
socket.create_connection(
(server_params["host"], server_params["port"]),
),
server_hostname=server_params["host"],
)
connection = pymysql.Connection(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params,
defer_connect=True,
)
connection.connect(sock)

else:
connection = pymysql.connect(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params,
)

with connection.cursor() as cursor:
cursor.execute("SELECT VERSION() AS version")
Expand All @@ -297,7 +338,7 @@ def mysql_server(mysql_address):
pytest.fail("Unable to determine database type from {!r}"
.format(server_version_tuple))

if not unix_socket:
if not unix_socket and not implicit_tls:
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")

result = cursor.fetchall()
Expand Down Expand Up @@ -353,6 +394,10 @@ def mysql_server(mysql_address):
except Exception:
pytest.fail("Cannot initialize MySQL environment")

if implicit_tls:
server_params["ssl"] = ctx
server_params["implicit_tls"] = implicit_tls

return {
"conn_params": server_params,
"server_version": server_version,
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
Loading

0 comments on commit 173eb36

Please sign in to comment.