Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions asyncmy/connection.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Connection:
:param host: Host where the database server is located.
:param user: Username to log in as.
:param password: Password to use.
:param password_creator:
Optional callable or coroutine that returns a password string
every time a new connection is established.
:param database: Database to use, None to not use a particular one.
:param port: MySQL port to use, default is usually OK. (default: 3306)
:param unix_socket: Use a unix socket rather than TCP/IP.
Expand Down Expand Up @@ -152,6 +155,7 @@ class Connection:
*,
user=None, # The first four arguments is based on DB-API 2.0 recommendation.
password="",
password_creator=None,
host=None,
database=None,
unix_socket=None,
Expand Down Expand Up @@ -240,6 +244,7 @@ class Connection:
raise ValueError("port should be of type int")
self._user = user or DEFAULT_USER
self._password = password or b""
self._password_creator = password_creator
if isinstance(self._password, str):
self._password = self._password.encode("latin1")
self._db = database
Expand Down Expand Up @@ -549,6 +554,12 @@ class Connection:
return self._reader, self._writer
try:

if self._password_creator is not None:
new_pw = self._password_creator()
if asyncio.iscoroutine(new_pw):
new_pw = await new_pw
self._password = new_pw.encode("latin1")

if self._unix_socket:
self._reader, self._writer = await asyncio.wait_for(asyncio.open_unix_connection(self._unix_socket),
timeout=self._connect_timeout, )
Expand Down Expand Up @@ -1281,6 +1292,7 @@ class LoadLocalFile:

def connect(user=None,
password="",
password_creator=None,
host=None,
database=None,
unix_socket=None,
Expand Down Expand Up @@ -1310,6 +1322,7 @@ def connect(user=None,
coro = _connect(
user=user,
password=password,
password_creator=password_creator,
host=host,
database=database,
unix_socket=unix_socket,
Expand Down
34 changes: 24 additions & 10 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,27 @@
from asyncmy import connect
from asyncmy.cursors import DictCursor

connection_kwargs = dict(
host="127.0.0.1",
port=3306,
user="root",
password=os.getenv("MYSQL_PASS") or "123456",
echo=True,
)
def mysql_password_creator():
"""Return the MySQL password dynamically"""
return os.getenv("MYSQL_PASS") or "123456"


@pytest_asyncio.fixture(params=["static", "creator"], scope="session")
def connection_kwargs(request):
"""Provide connection args for both static and dynamic password modes."""
base = dict(
host="127.0.0.1",
port=3306,
user="root",
echo=True,
)

if request.param == "static":
base["password"] = os.getenv("MYSQL_PASS") or "123456"
else:
base["password_creator"] = mysql_password_creator

return base


@pytest_asyncio.fixture(scope="session")
Expand All @@ -30,7 +44,7 @@ def event_loop():


@pytest_asyncio.fixture(scope="session")
async def connection():
async def connection(connection_kwargs):
conn = await connect(**connection_kwargs)
yield conn
await conn.ensure_closed()
Expand Down Expand Up @@ -63,8 +77,8 @@ async def truncate_table(connection):


@pytest_asyncio.fixture(scope="session")
async def pool():
async def pool(connection_kwargs):
pool = await asyncmy.create_pool(**connection_kwargs)
yield pool
pool.close()
await pool.wait_closed()
await pool.wait_closed()
5 changes: 2 additions & 3 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

from asyncmy.connection import Connection
from asyncmy.errors import OperationalError
from conftest import connection_kwargs


@pytest.mark.asyncio
async def test_connect():
async def test_connect(connection_kwargs):
connection = Connection(**connection_kwargs)
await connection.connect()
assert connection._connected
Expand All @@ -22,7 +21,7 @@ async def test_connect():


@pytest.mark.asyncio
async def test_read_timeout():
async def test_read_timeout(connection_kwargs):
with pytest.raises(OperationalError):
connection = Connection(read_timeout=1, **connection_kwargs)
await connection.connect()
Expand Down