Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SQL adapter #779

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Automatically set SQL driver if unset.
  • Loading branch information
danielballan committed Sep 11, 2024
commit cfca6dfffcdb1fc69adb4b56ff7010cd119ca069
3 changes: 2 additions & 1 deletion tiled/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..catalog import from_uri, in_memory
from ..client.base import BaseClient
from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver
from .utils import enter_password as utils_enter_password
from .utils import temp_postgres

Expand Down Expand Up @@ -152,7 +153,7 @@ async def postgresql_with_example_data_adapter(request, tmpdir):
if uri.endswith("/"):
uri = uri[:-1]
uri_with_database_name = f"{uri}/{DATABASE_NAME}"
engine = create_async_engine(uri_with_database_name)
engine = create_async_engine(ensure_specified_sql_driver(uri_with_database_name))
try:
async with engine.connect():
pass
Expand Down
3 changes: 2 additions & 1 deletion tiled/_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..client import context
from ..client.base import BaseClient
from ..utils import ensure_specified_sql_driver

if sys.version_info < (3, 9):
import importlib_resources as resources
Expand All @@ -33,7 +34,7 @@ async def temp_postgres(uri):
if uri.endswith("/"):
uri = uri[:-1]
# Create a fresh database.
engine = create_async_engine(uri)
engine = create_async_engine(ensure_specified_sql_driver(uri))
database_name = f"tiled_test_disposable_{uuid.uuid4().hex}"
async with engine.connect() as connection:
await connection.execute(
Expand Down
5 changes: 4 additions & 1 deletion tiled/authn_database/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver

# A given process probably only has one of these at a time, but we
# key on database_settings just case in some testing context or something
Expand All @@ -16,7 +17,9 @@ def open_database_connection_pool(database_settings):
# kwargs["pool_pre_ping"] = database_settings.pool_pre_ping
# kwargs["max_overflow"] = database_settings.max_overflow
engine = create_async_engine(
database_settings.uri, connect_args=connect_args, **kwargs
ensure_specified_sql_driver(database_settings.uri),
connect_args=connect_args,
**kwargs,
)
_connection_pools[database_settings] = engine
return engine
Expand Down
6 changes: 5 additions & 1 deletion tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
OneShotCachedMap,
UnsupportedQueryType,
ensure_awaitable,
ensure_specified_sql_driver,
ensure_uri,
import_object,
path_from_uri,
Expand Down Expand Up @@ -1381,7 +1382,10 @@ def from_uri(
else:
poolclass = None # defer to sqlalchemy default
engine = create_async_engine(
uri, echo=echo, json_serializer=json_serializer, poolclass=poolclass
ensure_specified_sql_driver(uri),
echo=echo,
json_serializer=json_serializer,
poolclass=poolclass,
)
if engine.dialect.name == "sqlite":
event.listens_for(engine.sync_engine, "connect")(_set_sqlite_pragma)
Expand Down
9 changes: 6 additions & 3 deletions tiled/commandline/_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def initialize_database(database_uri: str):
REQUIRED_REVISION,
initialize_database,
)
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
Expand Down Expand Up @@ -71,9 +72,10 @@ def upgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
await engine.dispose()
Expand Down Expand Up @@ -107,9 +109,10 @@ def downgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
if current_revision is None:
Expand Down
10 changes: 6 additions & 4 deletions tiled/commandline/_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def init(
from ..alembic_utils import UninitializedDatabase, check_database, stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS, REQUIRED_REVISION, initialize_database
from ..utils import SCHEME_PATTERN
from ..utils import SCHEME_PATTERN, ensure_specified_sql_driver

if not SCHEME_PATTERN.match(database):
# Interpret URI as filepath.
database = f"sqlite+aiosqlite:///{database}"

async def do_setup():
engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
redacted_url = engine.url._replace(password="[redacted]")
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
Expand Down Expand Up @@ -94,9 +94,10 @@ def upgrade_database(
from ..alembic_utils import get_current_revision, upgrade
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
await engine.dispose()
Expand Down Expand Up @@ -127,9 +128,10 @@ def downgrade_database(
from ..alembic_utils import downgrade, get_current_revision
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
if current_revision is None:
Expand Down
6 changes: 4 additions & 2 deletions tiled/commandline/_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def serve_directory(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down Expand Up @@ -389,8 +390,9 @@ def serve_catalog(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down
25 changes: 25 additions & 0 deletions tiled/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,31 @@ def ensure_uri(uri_or_path) -> str:
return str(uri_str)


SCHEME_TO_SCHEME_PLUS_DRIVER = {
"postgresql": "postgresql+asyncpg",
"sqlite": "sqlite+aiosqlite",
}


def ensure_specified_sql_driver(uri: str) -> str:
"""
Given a URI without a driver in the scheme, add Tiled's preferred driver.

If a driver is already specified, the specified one will be used; it
will NOT be overriden by this function.

'postgresql://...' -> 'postgresql+asynpg://...'
'sqlite://...' -> 'sqlite+aiosqlite://...'
'postgresql+asyncpg://...' -> 'postgresql+asynpg://...'
'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...'
"""
parsed_uri = urlparse(uri)
scheme = parsed_uri.scheme
new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme)
updated_uri = urlunparse(parsed_uri._replace(scheme=new_scheme))
return updated_uri


class catch_warning_msg(warnings.catch_warnings):
"""Backward compatible version of catch_warnings for python <3.11.

Expand Down