Skip to content

Commit

Permalink
Handle special SQLite URIs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielballan committed Sep 11, 2024
1 parent 9714e86 commit 8f676e8
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
61 changes: 61 additions & 0 deletions tiled/_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from ..utils import ensure_specified_sql_driver


def test_ensure_specified_sql_driver():
# Postgres
# Default driver is added if missing.
assert (
ensure_specified_sql_driver(
"postgresql://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver(
"postgresql+asyncpg://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver(
"postgresql+custom://user:password@localhost:5432/database"
)
== "postgresql+custom://user:password@localhost:5432/database"
)

# SQLite
# Default driver is added if missing.
assert (
ensure_specified_sql_driver("sqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver("sqlite+custom:////test.db")
== "sqlite+custom:////test.db"
)
# Handle SQLite :memory: URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
assert (
ensure_specified_sql_driver("sqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
# Handle SQLite relative URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
assert (
ensure_specified_sql_driver("sqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
6 changes: 2 additions & 4 deletions tiled/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,11 +739,9 @@ def ensure_specified_sql_driver(uri: str) -> str:
'postgresql+asyncpg://...' -> 'postgresql+asynpg://...'
'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...'
"""
parsed_uri = urlparse(uri)
scheme = parsed_uri.scheme
scheme, rest = uri.split(":", 1)
new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme)
updated_uri = urlunparse(parsed_uri._replace(scheme=new_scheme))
return updated_uri
return ":".join([new_scheme, rest])


class catch_warning_msg(warnings.catch_warnings):
Expand Down

0 comments on commit 8f676e8

Please sign in to comment.