Skip to content

Commit

Permalink
Improve monkeypatch robustness and related tests (#425)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet authored Feb 3, 2023
1 parent 308ac48 commit 5973ea3
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repos:
rev: 6.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
additional_dependencies: ["tomli"]
exclude: "tests"
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
Expand Down
13 changes: 8 additions & 5 deletions geoalchemy2/alembic_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ def _monkey_patch_get_indexes_for_sqlite():
def spatial_behavior(self, connection, table_name, schema=None, **kw):
indexes = self._get_indexes_normal_behavior(connection, table_name, schema=None, **kw)

# Check that SpatiaLite was loaded into the DB
is_spatial_db = connection.exec_driver_sql(
"""PRAGMA main.table_info(geometry_columns)"""
).fetchall()
if not is_spatial_db:
try:
# Check that SpatiaLite was loaded into the DB
is_spatial_db = connection.exec_driver_sql(
"""PRAGMA main.table_info(geometry_columns)"""
).fetchall()
if not is_spatial_db:
return indexes
except AttributeError:
return indexes

# Get spatial indexes
Expand Down
1 change: 0 additions & 1 deletion geoalchemy2/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class BaseComparator(UserDefinedType.Comparator):
key = None

def __getattr__(self, name):

# Function names that don't start with "ST_" are rejected.
# This is not to mess up with SQLAlchemy's use of
# hasattr/getattr on Column objects.
Expand Down
1 change: 0 additions & 1 deletion geoalchemy2/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def _register_geo_function(cls, clsname, clsdict):


class TableRowElement(ColumnElement):

inherit_cache = False
"""The cache is disabled for this class."""

Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker

from geoalchemy2.alembic_helpers import _monkey_patch_get_indexes_for_sqlite

from . import copy_and_connect_sqlite_db
from . import get_postgis_version
from . import get_postgres_major_version
Expand Down Expand Up @@ -175,6 +178,23 @@ def postgres_major_version(conn):
return get_postgres_major_version(conn)


@pytest.fixture(autouse=True)
def reset_sqlite_monkeypatch():
"""Disable Alembic monkeypatching by default."""
try:
normal_behavior = SQLiteDialect._get_indexes_normal_behavior
SQLiteDialect.get_indexes = normal_behavior
SQLiteDialect._get_indexes_normal_behavior = normal_behavior
except AttributeError:
pass


@pytest.fixture(autouse=True)
def use_sqlite_monkeypatch():
"""Enable Alembic monkeypatching ."""
_monkey_patch_get_indexes_for_sqlite()


@pytest.fixture
def setup_tables(session, metadata):
conn = session.connection()
Expand Down
8 changes: 5 additions & 3 deletions tests/test_alembic_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def filter_tables(name, type_, parent_names):


class TestAutogenerate:
def test_no_diff(self, conn, Lake, setup_tables):
def test_no_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):
"""Check that the autogeneration detects spatial types properly."""
metadata = MetaData()

Expand Down Expand Up @@ -56,7 +56,7 @@ def test_no_diff(self, conn, Lake, setup_tables):

assert diff == []

def test_diff(self, conn, Lake, setup_tables):
def test_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):
"""Check that the autogeneration detects spatial types properly."""
metadata = MetaData()

Expand Down Expand Up @@ -252,7 +252,9 @@ class = StreamHandler


@test_only_with_dialects("postgresql", "sqlite-spatialite4")
def test_migration_revision(conn, metadata, alembic_config, alembic_env_path, test_script_path):
def test_migration_revision(
conn, metadata, alembic_config, alembic_env_path, test_script_path, use_sqlite_monkeypatch
):
initial_rev = command.revision(
alembic_config,
"Initial state",
Expand Down
4 changes: 0 additions & 4 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_hash(self):


class TestExtendedWKTElement:

_srid = 3857 # expected srid
_wkt = "POINT (1 2 3)" # expected wkt
_ewkt = "SRID=3857;POINT (1 2 3)" # expected ewkt
Expand Down Expand Up @@ -208,7 +207,6 @@ def test_ST_Equal_Column_WKTElement(self, geometry_table):


class TestExtendedWKBElement:

# _bin/_hex computed by following query:
# SELECT ST_GeomFromEWKT('SRID=3;POINT(1 2)');
_bin = memoryview(
Expand Down Expand Up @@ -328,7 +326,6 @@ def test_hash(self):


class TestNotEqualSpatialElement:

# _bin/_hex computed by following query:
# SELECT ST_GeomFromEWKT('SRID=3;POINT(1 2)');
_ewkb = memoryview(
Expand Down Expand Up @@ -368,7 +365,6 @@ def test_neq_other_types(self):


class TestRasterElement:

rast_data = (
b"\x01\x00\x00\x01\x00\x9a\x99\x99\x99\x99\x99\xc9?\x9a\x99\x99\x99\x99\x99"
b"\xc9\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00"
Expand Down
1 change: 0 additions & 1 deletion tests/test_functional_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,6 @@ def test_unknown_func(self):

class TestSTSummaryStatsAgg:
def test_st_summary_stats_agg(self, session, Ocean, setup_tables):

# Create a new raster
polygon = WKTElement("POLYGON((0 0,1 1,0 1,0 0))", srid=4326)
o = Ocean(polygon.ST_AsRaster(5, 6))
Expand Down
1 change: 0 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def test_function_call(self, raster_table):
eq_sql(s, 'SELECT ST_Height("table".rast) ' 'AS "ST_Height_1" FROM "table"')

def test_non_ST_function_call(self, raster_table):

with pytest.raises(AttributeError):
raster_table.c.geom.Height()

Expand Down

0 comments on commit 5973ea3

Please sign in to comment.