Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
Changes
=======

Version <next>

- init: create connection string from its parts if available
- chore: remove unused imports

Version v2.0.0 (released 2024-11-19)

- uow: possible solution for the rollback problem
Expand Down
4 changes: 0 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@

"""Sphinx configuration."""

import os
import sys

import sphinx.environment

from invenio_db import __version__

Expand Down
2 changes: 0 additions & 2 deletions invenio_db/alembic/96e796392533_create_database_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

"""Create database migrations."""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "96e796392533"
Expand Down
43 changes: 43 additions & 0 deletions invenio_db/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Copyright (C) 2015-2018 CERN.
# Copyright (C) 2022 RERO.
# Copyright (C) 2022 Graz University of Technology.
# Copyright (C) 2025 TU Wien.
# Copyright (C) 2025 KTH Royal Institute of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -66,8 +68,49 @@
app.extensions["invenio-db"] = self
app.cli.add_command(db_cmd)

def build_db_uri(self, app):
"""Return the connection string if configured or build it from its parts.

If set, then ``SQLALCHEMY_DATABASE_URI`` will be returned.
Otherwise, the URI will be pieced together by the configuration items
``DB_{USER,PASSWORD,HOST,PORT,NAME,PROTOCOL}``, where ``DB_PORT`` is
optional.
If that cannot be done (e.g. because required values are missing), then
``None`` will be returned.
"""
if uri := app.config.get("SQLALCHEMY_DATABASE_URI", None):
return uri

params = {}
for config_name in ["USER", "PASSWORD", "HOST", "PORT", "NAME", "PROTOCOL"]:
params[config_name] = app.config.get(f"DB_{config_name}", None)

# The port is expected to be an int, and optional
if port := params.pop("PORT", None):
params["PORT"] = int(port)

if all(params.values()):
uri = sa.URL.create(
params["PROTOCOL"],
username=params["USER"],
password=params["PASSWORD"],
host=params["HOST"],
port=params["PORT"],
database=params["NAME"],
)
return uri
elif any(params.values()):
app.logger.warn(

Check warning on line 103 in invenio_db/ext.py

View workflow job for this annotation

GitHub Actions / Tests / Tests (3.9, postgresql14)

The 'warn' method is deprecated, use 'warning' instead

Check warning on line 103 in invenio_db/ext.py

View workflow job for this annotation

GitHub Actions / Tests / Tests (3.12, postgresql14)

The 'warn' method is deprecated, use 'warning' instead
'Ignoring "DB_*" config values as they are only partially set.'
)

return None

def init_db(self, app, entry_point_group="invenio_db.models", **kwargs):
"""Initialize Flask-SQLAlchemy extension."""
if uri := self.build_db_uri(app):
app.config["SQLALCHEMY_DATABASE_URI"] = uri

# Setup SQLAlchemy
app.config.setdefault(
"SQLALCHEMY_DATABASE_URI",
Expand Down
59 changes: 59 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# Copyright (C) 2015-2018 CERN.
# Copyright (C) 2022 RERO.
# Copyright (C) 2024 Graz University of Technology.
# Copyright (C) 2025 TU Wien.
# Copyright (C) 2025 KTH Royal Institute of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Test database integration layer."""

import os
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -406,3 +409,59 @@ def test_db_create_alembic_upgrade(app, db):
drop_database(str(db.engine.url.render_as_string(hide_password=False)))
remove_versioning(manager=ext.versioning_manager)
create_database(str(db.engine.url.render_as_string(hide_password=False)))


@pytest.mark.parametrize(
"configs, expected_uri",
[
(
{
"DB_USER": "testuser",
"DB_PASSWORD": "testpassword",
"DB_HOST": "testhost",
"DB_PORT": "5432",
"DB_NAME": "testdb",
"DB_PROTOCOL": "postgresql+psycopg2",
},
"postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb",
),
(
{
"SQLALCHEMY_DATABASE_URI": "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb"
},
"postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb",
),
(
{
"DB_USER": "testuser",
"DB_PASSWORD": "testpassword",
"DB_HOST": "testhost",
"DB_PORT": "5432",
"DB_NAME": "testdb",
"DB_PROTOCOL": "postgresql+psycopg2",
"SQLALCHEMY_DATABASE_URI": "sqlite:///testdb.db",
},
"sqlite:///testdb.db",
),
(
{
"DB_USER": "testuser",
"DB_PASSWORD": "testpassword",
"DB_HOST": "testhost",
"DB_PORT": "5432",
"DB_NAME": "testdb",
"DB_PROTOCOL": None,
},
None,
),
],
)
def test_build_db_uri(configs, expected_uri):
"""Test building database URI."""
app = Flask("test_app")
assert "SQLALCHEMY_DATABASE_URI" not in app.config
app.config["DB_VERSIONING"] = False
app.config.update(configs)
InvenioDB(app)
default_uri = "sqlite:///" + os.path.join(app.instance_path, app.name + ".db")
assert app.config["SQLALCHEMY_DATABASE_URI"] == expected_uri or default_uri
Loading