Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for sqlalchemy default parameters #455 #456

Merged
merged 1 commit into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions aiomysql/sa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,29 @@

try:
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
from sqlalchemy.dialects.mysql.mysqldb import MySQLCompiler_mysqldb
except ImportError: # pragma: no cover
raise ImportError('aiomysql.sa requires sqlalchemy')


class MySQLCompiler_pymysql(MySQLCompiler_mysqldb):
def construct_params(self, params=None, _group_number=None, _check=True):
pd = super().construct_params(params, _group_number, _check)

for column in self.prefetch:
pd[column.key] = self._exec_default(column.default)

return pd

def _exec_default(self, default):
if default.is_callable:
return default.arg(self.dialect)
else:
return default.arg


_dialect = MySQLDialect_pymysql(paramstyle='pyformat')
_dialect.statement_compiler = MySQLCompiler_pymysql
_dialect.default_paramstyle = 'pyformat'


Expand Down
105 changes: 105 additions & 0 deletions tests/sa/test_sa_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import datetime

import pytest
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy import func, DateTime, Boolean

from aiomysql import sa

meta = MetaData()
table = Table('sa_tbl_default_test', meta,
Column('id', Integer, nullable=False, primary_key=True),
Column('string_length', Integer,
default=func.length('qwerty')),
Column('number', Integer, default=100, nullable=False),
Column('description', String(255), nullable=False,
default='default test'),
Column('created_at', DateTime,
default=datetime.datetime.now),
Column('enabled', Boolean, default=True))


@pytest.fixture()
def make_engine(mysql_params, connection):
async def _make_engine(**kwargs):
return (await sa.create_engine(db=mysql_params['db'],
user=mysql_params['user'],
password=mysql_params['password'],
host=mysql_params['host'],
port=mysql_params['port'],
minsize=10,
**kwargs))

return _make_engine


async def start(engine):
async with engine.acquire() as conn:
await conn.execute("DROP TABLE IF EXISTS sa_tbl_default_test")
await conn.execute("CREATE TABLE sa_tbl_default_test "
"(id integer,"
" string_length integer, "
"number integer,"
" description VARCHAR(255), "
"created_at DATETIME(6), "
"enabled TINYINT)")


@pytest.mark.run_loop
async def test_default_fields(make_engine):
engine = await make_engine()
await start(engine)
async with engine.acquire() as conn:
await conn.execute(table.insert().values())
res = await conn.execute(table.select())
row = await res.fetchone()
assert row.string_length == 6
assert row.number == 100
assert row.description == 'default test'
assert row.enabled is True
assert type(row.created_at) == datetime.datetime


@pytest.mark.run_loop
async def test_default_fields_isnull(make_engine):
engine = await make_engine()
await start(engine)
async with engine.acquire() as conn:
created_at = None
enabled = False
await conn.execute(table.insert().values(
enabled=enabled,
created_at=created_at,
))

res = await conn.execute(table.select())
row = await res.fetchone()
assert row.number == 100
assert row.string_length == 6
assert row.description == 'default test'
assert row.enabled == enabled
assert row.created_at == created_at


async def test_default_fields_edit(make_engine):
engine = await make_engine()
await start(engine)
async with engine.acquire() as conn:
created_at = datetime.datetime.now()
description = 'new descr'
enabled = False
number = 111
await conn.execute(table.insert().values(
description=description,
enabled=enabled,
created_at=created_at,
number=number,
))

res = await conn.execute(table.select())
row = await res.fetchone()
assert row.number == number
assert row.string_length == 6
assert row.description == description
assert row.enabled == enabled
assert row.created_at == created_at