Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release.
- Added mixed case support for pg composite (#584, pull request courtesy of bamartin125)
- Support Python 3.10.
- Remove the dependency on the six package. (#605)
- Introduce sqlalchemy 2.0 compatibility. (#513)


0.38.2 (2021-12-29)
Expand Down
27 changes: 20 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import pytest
import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.exc
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import close_all_sessions
Expand All @@ -15,6 +16,11 @@
i18n,
InstrumentedList
)
from sqlalchemy_utils.compat import (
_declarative_base,
_select_args,
_synonym_for
)
from sqlalchemy_utils.functions.orm import _get_class_registry
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners

Expand Down Expand Up @@ -148,7 +154,7 @@ def connection(engine):

@pytest.fixture
def Base():
return declarative_base()
return _declarative_base()


@pytest.fixture
Expand Down Expand Up @@ -185,7 +191,7 @@ def articles_count(self):
def articles_count(cls):
Article = _get_class_registry(Base)['Article']
return (
sa.select([sa.func.count(Article.id)])
sa.select(*_select_args(sa.func.count(Article.id)))
.where(Article.category_id == cls.id)
.correlate(Article.__table__)
.label('article_count')
Expand All @@ -195,7 +201,7 @@ def articles_count(cls):
def name_alias(self):
return self.name

@synonym_for('name')
@_synonym_for('name')
@property
def name_synonym(self):
return self.name
Expand Down Expand Up @@ -229,15 +235,22 @@ def init_models(User, Category, Article):
@pytest.fixture
def session(request, engine, connection, Base, init_models):
sa.orm.configure_mappers()
Base.metadata.create_all(connection)
with connection.begin():
Base.metadata.create_all(connection)
Session = sessionmaker(bind=connection)
session = Session()
try:
# Enable sqlalchemy 2.0 behavior.
session = Session(future=True)
except TypeError:
# sqlalchemy 1.3
session = Session()
i18n.get_locale = get_locale

def teardown():
aggregates.manager.reset()
close_all_sessions()
Base.metadata.drop_all(connection)
with connection.begin():
Base.metadata.drop_all(connection)
remove_composite_listeners()
connection.close()
engine.dispose()
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_version():
platforms='any',
install_requires=[
'SQLAlchemy>=1.0',
"importlib_metadata ; python_version<'3.8'",
],
extras_require=extras_require,
python_requires='~=3.6',
Expand Down
11 changes: 6 additions & 5 deletions sqlalchemy_utils/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,12 @@ class Rating(Base):
from weakref import WeakKeyDictionary

import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator

from .compat import get_scalar_subquery
from .compat import _select_args, get_scalar_subquery
from .functions.orm import get_column_key
from .relationships import (
chained_join,
Expand Down Expand Up @@ -488,10 +490,9 @@ def update_query(self, objects):
return query.where(
local.in_(
sa.select(
[remote],
from_obj=[
chained_join(*reversed(self.relationships))
]
*_select_args(remote)
).select_from(
chained_join(*reversed(self.relationships))
).where(
condition
)
Expand Down
77 changes: 74 additions & 3 deletions sqlalchemy_utils/compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,76 @@
def get_scalar_subquery(query):
try:
import sys

if sys.version_info >= (3, 8):
from importlib.metadata import metadata
else:
from importlib_metadata import metadata


_sqlalchemy_version = tuple(
[int(i) for i in metadata("sqlalchemy")["Version"].split(".")[:2]]
)


# In sqlalchemy 2.0, some functions moved to sqlalchemy.orm.
# In sqlalchemy 1.3, they are only available in .ext.declarative.
# In sqlalchemy 1.4, they are available in both places.
#
# WARNING
# -------
#
# These imports are for internal, private compatibility.
# They are not supported and may change or move at any time.
# Do not import these in your own code.
#

if _sqlalchemy_version >= (1, 4):
from sqlalchemy.orm import declarative_base as _declarative_base
from sqlalchemy.orm import synonym_for as _synonym_for
else:
from sqlalchemy.ext.declarative import \
declarative_base as _declarative_base
from sqlalchemy.ext.declarative import synonym_for as _synonym_for


# scalar subqueries
if _sqlalchemy_version >= (1, 4):
def get_scalar_subquery(query):
return query.scalar_subquery()
except AttributeError: # SQLAlchemy <1.4
else:
def get_scalar_subquery(query):
return query.as_scalar()


# In sqlalchemy 2.0, select() columns are positional.
# In sqlalchemy 1.3, select() columns must be wrapped in a list.
#
# _select_args() is designed so its return value can be unpacked:
#
# select(*_select_args(1, 2))
#
# When sqlalchemy 1.3 support is dropped, remove the call to _select_args()
# and keep the arguments the same:
#
# select(1, 2)
#
# WARNING
# -------
#
# _select_args() is a private, internal function.
# It is not supported and may change or move at any time.
# Do not import this in your own code.
#
if _sqlalchemy_version >= (1, 4):
def _select_args(*args):
return args
else:
def _select_args(*args):
return [args]


__all__ = (
"_declarative_base",
"get_scalar_subquery",
"_select_args",
"_synonym_for",
)
38 changes: 19 additions & 19 deletions sqlalchemy_utils/functions/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def _set_url_database(url: sa.engine.url.URL, database):


def _get_scalar_result(engine, sql):
with engine.connect() as conn:
with engine.begin() as conn:
return conn.scalar(sql)


Expand Down Expand Up @@ -485,7 +485,7 @@ def database_exists(url):
url = _set_url_database(url, database=db)
engine = sa.create_engine(url)
try:
return bool(_get_scalar_result(engine, text))
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
pass
return False
Expand All @@ -495,22 +495,22 @@ def database_exists(url):
engine = sa.create_engine(url)
text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME = '%s'" % database)
return bool(_get_scalar_result(engine, text))
return bool(_get_scalar_result(engine, sa.text(text)))

elif dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
engine = sa.create_engine(url)
if database:
return database == ':memory:' or _sqlite_file_exists(database)
else:
# The default SQLAlchemy database is in memory, and :memory is
# The default SQLAlchemy database is in memory, and :memory: is
# not required, thus we should support that use case.
return True
else:
text = 'SELECT 1'
try:
engine = sa.create_engine(url)
return bool(_get_scalar_result(engine, text))
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
return False
finally:
Expand Down Expand Up @@ -571,27 +571,27 @@ def create_database(url, encoding='utf8', template=None):
quote(engine, template)
)

with engine.connect() as connection:
connection.execute(text)
with engine.begin() as connection:
connection.execute(sa.text(text))

elif dialect_name == 'mysql':
text = "CREATE DATABASE {0} CHARACTER SET = '{1}'".format(
quote(engine, database),
encoding
)
with engine.connect() as connection:
connection.execute(text)
with engine.begin() as connection:
connection.execute(sa.text(text))

elif dialect_name == 'sqlite' and database != ':memory:':
if database:
with engine.connect() as connection:
connection.execute("CREATE TABLE DB(id int);")
connection.execute("DROP TABLE DB;")
with engine.begin() as connection:
connection.execute(sa.text("CREATE TABLE DB(id int);"))
connection.execute(sa.text("DROP TABLE DB;"))

else:
text = 'CREATE DATABASE {0}'.format(quote(engine, database))
with engine.connect() as connection:
connection.execute(text)
with engine.begin() as connection:
connection.execute(sa.text(text))

engine.dispose()

Expand Down Expand Up @@ -635,7 +635,7 @@ def drop_database(url):
if database:
os.remove(database)
elif dialect_name == 'postgresql':
with engine.connect() as connection:
with engine.begin() as connection:
# Disconnect all users from the database we are dropping.
version = connection.dialect.server_version_info
pid_column = (
Expand All @@ -647,14 +647,14 @@ def drop_database(url):
WHERE pg_stat_activity.datname = '%(database)s'
AND %(pid_column)s <> pg_backend_pid();
''' % {'pid_column': pid_column, 'database': database}
connection.execute(text)
connection.execute(sa.text(text))

# Drop the database.
text = 'DROP DATABASE {0}'.format(quote(connection, database))
connection.execute(text)
connection.execute(sa.text(text))
else:
text = 'DROP DATABASE {0}'.format(quote(engine, database))
with engine.connect() as connection:
connection.execute(text)
with engine.begin() as connection:
connection.execute(sa.text(text))

engine.dispose()
1 change: 0 additions & 1 deletion sqlalchemy_utils/functions/foreign_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def non_indexed_foreign_keys(metadata, engine=None):
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)

Expand Down
6 changes: 5 additions & 1 deletion sqlalchemy_utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def get(self, state, dict_, passive=attributes.PASSIVE_OFF):

id = self.get_state_id(state)

target = session.query(target_class).get(id)
try:
target = session.get(target_class, id)
except AttributeError:
# sqlalchemy 1.3
target = session.query(target_class).get(id)

# Return found (or not found) target.
return target
Expand Down
3 changes: 2 additions & 1 deletion sqlalchemy_utils/listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def auto_delete_orphans(attr):
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.declarative import declarative_base
# Necessary in sqlalchemy 1.3:
# from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import event


Expand Down
4 changes: 3 additions & 1 deletion sqlalchemy_utils/relationships/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sqlalchemy as sa
import sqlalchemy.orm
from sqlalchemy.sql.util import ClauseAdapter

from ..compat import _select_args
from .chained_join import chained_join # noqa


Expand Down Expand Up @@ -94,7 +96,7 @@ def select_correlated_expression(
):
relationships = list(reversed(path_to_relationships(path, root_model)))

query = sa.select([expr])
query = sa.select(*_select_args(expr))

join_expr, aliases = chained_inverse_join(relationships, leaf_model)

Expand Down
6 changes: 5 additions & 1 deletion sqlalchemy_utils/types/encrypted/encrypted_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,11 @@ class StringEncryptedType(TypeDecorator, ScalarCoercible):

import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
try:
from sqlalchemy.orm import declarative_base
except ImportError:
# sqlalchemy 1.3
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from sqlalchemy_utils import EncryptedType
Expand Down
8 changes: 5 additions & 3 deletions sqlalchemy_utils/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def create_view(
Column('premium_user', Boolean, default=False),
)

premium_members = select([users]).where(users.c.premium_user == True)
premium_members = select(users).where(users.c.premium_user == True)
# sqlalchemy 1.3:
# premium_members = select([users]).where(users.c.premium_user == True)
create_view('premium_users', premium_members, metadata)

metadata.create_all(engine) # View is created at this point
Expand Down Expand Up @@ -189,8 +191,8 @@ def refresh_materialized_view(session, name, concurrently=False):
# order to include newly-created/modified objects in the refresh.
session.flush()
session.execute(
'REFRESH MATERIALIZED VIEW {}{}'.format(
sa.text('REFRESH MATERIALIZED VIEW {}{}'.format(
'CONCURRENTLY ' if concurrently else '',
session.bind.engine.dialect.identifier_preparer.quote(name)
)
))
)
Loading