Skip to content

Add Alembic support sqlalchemy 2.0 #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
from ydb_sqlalchemy.sqlalchemy import types

if sa.__version__ >= "2.":
from sqlalchemy import NullPool
from sqlalchemy import QueuePool
from sqlalchemy import NullPool, QueuePool
else:
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import QueuePool
from sqlalchemy.pool import NullPool, QueuePool


def clear_sql(stm):
Expand Down
1 change: 0 additions & 1 deletion test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest

from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
from sqlalchemy.testing.suite.test_types import (
Expand Down
5 changes: 3 additions & 2 deletions ydb_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._version import VERSION # noqa: F401
import ydb_dbapi as dbapi
from ydb_dbapi import IsolationLevel # noqa: F401

from ._version import VERSION # noqa: F401
from .sqlalchemy import Upsert, types, upsert # noqa: F401
import ydb_dbapi as dbapi
12 changes: 7 additions & 5 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@

import sqlalchemy as sa
import ydb
import ydb_dbapi
from sqlalchemy import util
from sqlalchemy.engine import characteristics, reflection
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import functions

from sqlalchemy.sql.elements import ClauseList

import ydb_dbapi
from ydb_sqlalchemy.sqlalchemy.compiler import (
YqlCompiler,
YqlDDLCompiler,
YqlIdentifierPreparer,
YqlTypeCompiler,
)
from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection
from ydb_sqlalchemy.sqlalchemy.dml import Upsert

from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler

from . import types


OLD_SA = sa.__version__ < "2."


Expand Down
20 changes: 12 additions & 8 deletions ydb_sqlalchemy/sqlalchemy/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
sa_version = sa.__version__

if sa_version.startswith("2."):
from .sa20 import YqlCompiler
from .sa20 import YqlDDLCompiler
from .sa20 import YqlTypeCompiler
from .sa20 import YqlIdentifierPreparer
from .sa20 import (
YqlCompiler,
YqlDDLCompiler,
YqlIdentifierPreparer,
YqlTypeCompiler,
)
elif sa_version.startswith("1.4."):
from .sa14 import YqlCompiler
from .sa14 import YqlDDLCompiler
from .sa14 import YqlTypeCompiler
from .sa14 import YqlIdentifierPreparer
from .sa14 import (
YqlCompiler,
YqlDDLCompiler,
YqlIdentifierPreparer,
YqlTypeCompiler,
)
else:
raise RuntimeError("Unsupported SQLAlchemy version.")
18 changes: 3 additions & 15 deletions ydb_sqlalchemy/sqlalchemy/compiler/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import collections
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import sqlalchemy as sa
import ydb
from ydb_dbapi import NotSupportedError

from sqlalchemy.exc import CompileError
from sqlalchemy.sql import ddl
from sqlalchemy.sql.compiler import (
Expand All @@ -12,22 +12,10 @@
StrSQLTypeCompiler,
selectable,
)
from typing import (
Any,
Dict,
List,
Mapping,
Sequence,
Optional,
Tuple,
Type,
Union,
)

from ydb_dbapi import NotSupportedError

from .. import types


OLD_SA = sa.__version__ < "2."
if OLD_SA:
from sqlalchemy import bindparam as _bindparam
Expand Down
1 change: 1 addition & 0 deletions ydb_sqlalchemy/sqlalchemy/compiler/sa14.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union

import sqlalchemy as sa
import ydb

Expand Down
19 changes: 16 additions & 3 deletions ydb_sqlalchemy/sqlalchemy/compiler/sa20.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union

import sqlalchemy as sa
import ydb

from sqlalchemy.exc import CompileError
from sqlalchemy.sql import literal_column
from sqlalchemy.util.compat import inspect_getfullargspec
Expand All @@ -11,7 +12,6 @@
BaseYqlIdentifierPreparer,
BaseYqlTypeCompiler,
)
from typing import Union


class YqlTypeCompiler(BaseYqlTypeCompiler):
Expand Down Expand Up @@ -89,4 +89,17 @@ def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw):


class YqlDDLCompiler(BaseYqlDDLCompiler):
...
def visit_foreign_key_constraint(self, constraint, **kwargs):
return None

def visit_primary_key_constraint(self, constraint, **kwargs):
if len(constraint) == 0:
return ""
text = ""
text += "PRIMARY KEY "
text += "(%s)" % ", ".join(
self.preparer.quote(c.name)
for c in (constraint.columns_autoinc_first if constraint._implicit_generated else constraint.columns)
)
text += self.define_constraint_deferrability(constraint)
return text
3 changes: 1 addition & 2 deletions ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import ydb
from sqlalchemy.engine.interfaces import AdaptedConnection

from sqlalchemy.util.concurrency import await_only
from ydb_dbapi import AsyncConnection, AsyncCursor
import ydb


class AdaptedAsyncConnection(AdaptedConnection):
Expand Down