From 5973ea3ab105f37e78e06002faad1c499985ac24 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Fri, 3 Feb 2023 23:02:53 +0100 Subject: [PATCH] Improve monkeypatch robustness and related tests (#425) --- .pre-commit-config.yaml | 2 +- geoalchemy2/alembic_helpers.py | 13 ++++++++----- geoalchemy2/comparator.py | 1 - geoalchemy2/functions.py | 1 - tests/conftest.py | 20 ++++++++++++++++++++ tests/test_alembic_migrations.py | 8 +++++--- tests/test_elements.py | 4 ---- tests/test_functional_postgresql.py | 1 - tests/test_types.py | 1 - 9 files changed, 34 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 065baff2..d6737165 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: rev: 6.3.0 hooks: - id: pydocstyle - additional_dependencies: ["toml"] + additional_dependencies: ["tomli"] exclude: "tests" - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 diff --git a/geoalchemy2/alembic_helpers.py b/geoalchemy2/alembic_helpers.py index c5aade54..33e3e66a 100644 --- a/geoalchemy2/alembic_helpers.py +++ b/geoalchemy2/alembic_helpers.py @@ -41,11 +41,14 @@ def _monkey_patch_get_indexes_for_sqlite(): def spatial_behavior(self, connection, table_name, schema=None, **kw): indexes = self._get_indexes_normal_behavior(connection, table_name, schema=None, **kw) - # Check that SpatiaLite was loaded into the DB - is_spatial_db = connection.exec_driver_sql( - """PRAGMA main.table_info(geometry_columns)""" - ).fetchall() - if not is_spatial_db: + try: + # Check that SpatiaLite was loaded into the DB + is_spatial_db = connection.exec_driver_sql( + """PRAGMA main.table_info(geometry_columns)""" + ).fetchall() + if not is_spatial_db: + return indexes + except AttributeError: return indexes # Get spatial indexes diff --git a/geoalchemy2/comparator.py b/geoalchemy2/comparator.py index 366a4440..542c3475 100644 --- a/geoalchemy2/comparator.py +++ b/geoalchemy2/comparator.py @@ -77,7 +77,6 @@ class BaseComparator(UserDefinedType.Comparator): key = None def __getattr__(self, name): - # Function names that don't start with "ST_" are rejected. # This is not to mess up with SQLAlchemy's use of # hasattr/getattr on Column objects. diff --git a/geoalchemy2/functions.py b/geoalchemy2/functions.py index 7a7d3c1b..6cc9a462 100644 --- a/geoalchemy2/functions.py +++ b/geoalchemy2/functions.py @@ -125,7 +125,6 @@ def _register_geo_function(cls, clsname, clsdict): class TableRowElement(ColumnElement): - inherit_cache = False """The cache is disabled for this class.""" diff --git a/tests/conftest.py b/tests/conftest.py index 03cb4f44..b86e7323 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,12 @@ from sqlalchemy import MetaData from sqlalchemy import create_engine from sqlalchemy import text +from sqlalchemy.dialects.sqlite.base import SQLiteDialect from sqlalchemy.orm import declarative_base from sqlalchemy.orm import sessionmaker +from geoalchemy2.alembic_helpers import _monkey_patch_get_indexes_for_sqlite + from . import copy_and_connect_sqlite_db from . import get_postgis_version from . import get_postgres_major_version @@ -175,6 +178,23 @@ def postgres_major_version(conn): return get_postgres_major_version(conn) +@pytest.fixture(autouse=True) +def reset_sqlite_monkeypatch(): + """Disable Alembic monkeypatching by default.""" + try: + normal_behavior = SQLiteDialect._get_indexes_normal_behavior + SQLiteDialect.get_indexes = normal_behavior + SQLiteDialect._get_indexes_normal_behavior = normal_behavior + except AttributeError: + pass + + +@pytest.fixture(autouse=True) +def use_sqlite_monkeypatch(): + """Enable Alembic monkeypatching .""" + _monkey_patch_get_indexes_for_sqlite() + + @pytest.fixture def setup_tables(session, metadata): conn = session.connection() diff --git a/tests/test_alembic_migrations.py b/tests/test_alembic_migrations.py index 60be527c..83a4b8d1 100644 --- a/tests/test_alembic_migrations.py +++ b/tests/test_alembic_migrations.py @@ -25,7 +25,7 @@ def filter_tables(name, type_, parent_names): class TestAutogenerate: - def test_no_diff(self, conn, Lake, setup_tables): + def test_no_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch): """Check that the autogeneration detects spatial types properly.""" metadata = MetaData() @@ -56,7 +56,7 @@ def test_no_diff(self, conn, Lake, setup_tables): assert diff == [] - def test_diff(self, conn, Lake, setup_tables): + def test_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch): """Check that the autogeneration detects spatial types properly.""" metadata = MetaData() @@ -252,7 +252,9 @@ class = StreamHandler @test_only_with_dialects("postgresql", "sqlite-spatialite4") -def test_migration_revision(conn, metadata, alembic_config, alembic_env_path, test_script_path): +def test_migration_revision( + conn, metadata, alembic_config, alembic_env_path, test_script_path, use_sqlite_monkeypatch +): initial_rev = command.revision( alembic_config, "Initial state", diff --git a/tests/test_elements.py b/tests/test_elements.py index d754ac88..75c7d811 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -82,7 +82,6 @@ def test_hash(self): class TestExtendedWKTElement: - _srid = 3857 # expected srid _wkt = "POINT (1 2 3)" # expected wkt _ewkt = "SRID=3857;POINT (1 2 3)" # expected ewkt @@ -208,7 +207,6 @@ def test_ST_Equal_Column_WKTElement(self, geometry_table): class TestExtendedWKBElement: - # _bin/_hex computed by following query: # SELECT ST_GeomFromEWKT('SRID=3;POINT(1 2)'); _bin = memoryview( @@ -328,7 +326,6 @@ def test_hash(self): class TestNotEqualSpatialElement: - # _bin/_hex computed by following query: # SELECT ST_GeomFromEWKT('SRID=3;POINT(1 2)'); _ewkb = memoryview( @@ -368,7 +365,6 @@ def test_neq_other_types(self): class TestRasterElement: - rast_data = ( b"\x01\x00\x00\x01\x00\x9a\x99\x99\x99\x99\x99\xc9?\x9a\x99\x99\x99\x99\x99" b"\xc9\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00" diff --git a/tests/test_functional_postgresql.py b/tests/test_functional_postgresql.py index 9366a70f..52ab8008 100644 --- a/tests/test_functional_postgresql.py +++ b/tests/test_functional_postgresql.py @@ -612,7 +612,6 @@ def test_unknown_func(self): class TestSTSummaryStatsAgg: def test_st_summary_stats_agg(self, session, Ocean, setup_tables): - # Create a new raster polygon = WKTElement("POLYGON((0 0,1 1,0 1,0 0))", srid=4326) o = Ocean(polygon.ST_AsRaster(5, 6)) diff --git a/tests/test_types.py b/tests/test_types.py index b372e2d1..9b201c36 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -252,7 +252,6 @@ def test_function_call(self, raster_table): eq_sql(s, 'SELECT ST_Height("table".rast) ' 'AS "ST_Height_1" FROM "table"') def test_non_ST_function_call(self, raster_table): - with pytest.raises(AttributeError): raster_table.c.geom.Height()