Skip to content

Commit

Permalink
Merge pull request #2 from adrien-berchet/fix_sqlite
Browse files Browse the repository at this point in the history
Fix sqlite to handle M dimension
  • Loading branch information
sdp5 authored Apr 19, 2024
2 parents 96af7c6 + e584682 commit 8590a9e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 57 deletions.
26 changes: 16 additions & 10 deletions geoalchemy2/types/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from geoalchemy2.shape import to_shape


def format_geom_type(wkt, forced_srid=None):
def format_geom_type(wkt, default_srid=None):
"""Format the Geometry type for SQLite."""
match = re.match(WKTElement.SPLIT_WKT_PATTERN, wkt)
if match is None:
return wkt
_, srid, geom_type, coords = match.groups()
geom_type = geom_type.replace(" ", "")
if geom_type.endswith("M"):
geom_type = geom_type[:-1]
if geom_type.endswith("ZM"):
geom_type = geom_type[:-2]
if geom_type.endswith("Z"):
geom_type = geom_type[:-1]
if forced_srid is not None:
srid = f"SRID={forced_srid}"
if srid is None and default_srid is not None:
srid = f"SRID={default_srid}"
if srid is not None:
return "%s;%s%s" % (srid, geom_type, coords)
else:
Expand All @@ -29,16 +29,22 @@ def format_geom_type(wkt, forced_srid=None):

def bind_processor_process(spatial_type, bindvalue):
if isinstance(bindvalue, WKTElement):
return format_geom_type(bindvalue.data, forced_srid=bindvalue.srid)
return format_geom_type(
bindvalue.data,
default_srid=bindvalue.srid if bindvalue.srid >= 0 else spatial_type.srid,
)
elif isinstance(bindvalue, WKBElement):
if bindvalue.srid == -1:
bindvalue.srid = spatial_type.srid
# With SpatiaLite we use Shapely to convert the WKBElement to an EWKT string
shape = to_shape(bindvalue)
# shapely.wkb.loads returns geom_type with a 'Z', for example, 'LINESTRING Z'
# which is a limitation with SpatiaLite. Hence, a temporary fix.
return format_geom_type(shape.wkt, forced_srid=bindvalue.srid)
res = format_geom_type(
shape.wkt, default_srid=bindvalue.srid if bindvalue.srid >= 0 else spatial_type.srid
)
return res
elif isinstance(bindvalue, RasterElement):
return "%s" % (bindvalue.data)
elif isinstance(bindvalue, str):
return format_geom_type(bindvalue, default_srid=spatial_type.srid)
else:
return format_geom_type(bindvalue)
return bindvalue
72 changes: 60 additions & 12 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,28 +280,59 @@ def test_insert(self, conn, Lake, setup_tables):
[
pytest.param("POINT", "(1 2)", id="Point"),
pytest.param("POINTZ", "(1 2 3)", id="Point Z"),
pytest.param("POINTM", "(1 2 3)", id="Point M"),
pytest.param("POINTZM", "(1 2 3 4)", id="Point ZM"),
pytest.param("LINESTRING", "(1 2, 3 4)", id="LineString"),
pytest.param("LINESTRINGZ", "(1 2 3, 4 5 6)", id="LineString Z"),
pytest.param("LINESTRINGM", "(1 2 3, 4 5 6)", id="LineString M"),
pytest.param("LINESTRINGZM", "(1 2 3 4, 5 6 7 8)", id="LineString ZM"),
pytest.param("POLYGON", "((1 2, 3 4, 5 6, 1 2))", id="Polygon"),
pytest.param("POLYGONZ", "((1 2 3, 4 5 6, 7 8 9, 1 2 3))", id="Polygon Z"),
pytest.param("POLYGONM", "((1 2 3, 4 5 6, 7 8 9, 1 2 3))", id="Polygon M"),
pytest.param(
"POLYGONZM", "((1 2 3 4, 5 6 7 8, 9 10 11 12, 1 2 3 4))", id="Polygon ZM"
),
pytest.param("MULTIPOINT", "(1 2, 3 4)", id="Multi Point"),
pytest.param("MULTIPOINTZ", "(1 2 3, 4 5 6)", id="Multi Point Z"),
pytest.param("MULTIPOINTM", "(1 2 3, 4 5 6)", id="Multi Point M"),
pytest.param("MULTIPOINTZM", "(1 2 3 4, 5 6 7 8)", id="Multi Point ZM"),
pytest.param("MULTILINESTRING", "((1 2, 3 4), (10 20, 30 40))", id="Multi LineString"),
pytest.param(
"MULTILINESTRINGZ",
"((1 2 3, 4 5 6), (10 20 30, 40 50 60))",
id="Multi LineString Z",
),
pytest.param(
"MULTILINESTRINGM",
"((1 2 3, 4 5 6), (10 20 30, 40 50 60))",
id="Multi LineString M",
),
pytest.param(
"MULTILINESTRINGZM",
"((1 2 3 4, 5 6 7 8), (10 20 30 40, 50 60 70 80))",
id="Multi LineString ZM",
),
pytest.param(
"MULTIPOLYGON",
"(((1 2, 3 4, 5 6, 1 2), (10 20, 30 40, 50 60, 10 20)))",
"(((1 2, 3 4, 5 6, 1 2)), ((10 20, 30 40, 50 60, 10 20)))",
id="Multi Polygon",
),
pytest.param(
"MULTIPOLYGONZ",
"(((1 2 3, 4 5 6, 7 8 9, 1 2 3), (10 20 30, 40 50 60, 70 80 90, 10 20 30)))",
"(((1 2 3, 4 5 6, 7 8 9, 1 2 3)), ((10 20 30, 40 50 60, 70 80 90, 10 20 30)))",
id="Multi Polygon Z",
),
pytest.param(
"MULTIPOLYGONM",
"(((1 2 3, 4 5 6, 7 8 9, 1 2 3)), ((10 20 30, 40 50 60, 70 80 90, 10 20 30)))",
id="Multi Polygon M",
),
pytest.param(
"MULTIPOLYGONZM",
"(((1 2 3 4, 5 6 7 8, 9 10 11 12, 1 2 3 4)),"
" ((10 20 30 40, 50 60 70 80, 90 100 100 120, 10 20 30 40)))",
id="Multi Polygon ZM",
),
],
)
def test_insert_all_geom_types(self, dialect_name, base, conn, metadata, geom_type, wkt):
Expand All @@ -311,10 +342,13 @@ def test_insert_all_geom_types(self, dialect_name, base, conn, metadata, geom_ty
ndims += 1
if geom_type.endswith("M"):
ndims += 1
has_m = True
else:
has_m = False

if ndims > 2 and dialect_name == "mysql":
# Explicitly skip MySQL dialect to show that it can only work with 2D geometries
pytest.skip(reason="MySQL only supports 2D geometry types")
pytest.xfail(reason="MySQL only supports 2D geometry types")

class GeomTypeTable(base):
__tablename__ = "test_geom_types"
Expand All @@ -331,15 +365,17 @@ class GeomTypeTable(base):
text("SELECT ST_AsBinary(ST_GeomFromText('{}', 4326))".format(inserted_wkt))
).scalar()

wkb_elem = WKBElement(raw_wkb, srid=4326)
inserted_elements = [
{"geom": inserted_wkt},
{"geom": f"SRID=4326;{inserted_wkt}"},
{"geom": WKTElement(inserted_wkt, srid=4326)},
{"geom": WKTElement(f"SRID=4326;{inserted_wkt}")},
]
if dialect_name not in ["sqlite", "geopackage"]:
inserted_elements.append({"geom": inserted_wkt})
if dialect_name not in ["sqlite", "geopackage"] or ndims == 2:
inserted_elements.append({"geom": WKBElement(raw_wkb, srid=4326)})
if dialect_name not in ["postgresql", "sqlite"] or not has_m:
# Currently Shapely does not support geometry types with M dimension
inserted_elements.append({"geom": wkb_elem})
inserted_elements.append({"geom": wkb_elem.as_ewkb()})

# Insert the elements
conn.execute(
Expand All @@ -348,17 +384,29 @@ class GeomTypeTable(base):
)

# Select the elements
query = select([GeomTypeTable.__table__.c.geom.ST_AsText()])
query = select(
[
GeomTypeTable.__table__.c.id,
GeomTypeTable.__table__.c.geom.ST_AsText(),
GeomTypeTable.__table__.c.geom.ST_SRID(),
],
)
results = conn.execute(query)
rows = results.scalars().all()
rows = results.all()

# Check that the selected elements are the same as the inputs
for row in rows:
for row_id, row, srid in rows:
checked_wkt = row.upper().replace(" ", "")
expected_wkt = inserted_wkt.upper().replace(" ", "")
if dialect_name == "mysql" and geom_type == "MULTIPOINT":
checked_wkt = re.sub(r"\((\d+)\)", "\\1", checked_wkt)
assert checked_wkt == expected_wkt
print(row_id, row, srid)
if row_id >= 5 and dialect_name in ["geopackage"] and has_m:
# Currently Shapely does not support geometry types with M dimension
assert checked_wkt != expected_wkt
else:
assert checked_wkt == expected_wkt
assert srid == 4326

@test_only_with_dialects("postgresql", "sqlite")
def test_insert_geom_poi(self, conn, Poi, setup_tables):
Expand Down Expand Up @@ -465,7 +513,7 @@ def test_WKT(self, session, Lake, setup_tables, dialect_name, postgis_version):
lake = Lake("LINESTRING(0 0,1 1)")
session.add(lake)

if (dialect_name == "postgresql" and postgis_version < 3) or dialect_name == "sqlite":
if dialect_name == "postgresql" and postgis_version < 3:
with pytest.raises((DataError, IntegrityError)):
session.flush()
else:
Expand Down
35 changes: 0 additions & 35 deletions tests/test_functional_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,41 +129,6 @@ def test_explicit_schema(self, conn):
# Drop the table
t.drop(bind=conn)

def test_3d_geometry(self, conn, metadata):
# Define the table
col = Column(
"geom",
Geometry(geometry_type=None, srid=4326, spatial_index=False),
nullable=False,
)
t = Table(
"3d_geom_type",
metadata,
Column("id", Integer, primary_key=True),
col,
)

# Create the table
t.create(bind=conn)

# Should be 'LINESTRING Z (0 0 0, 1 1 1)'
# Read comments at geoalchemy2/types/dialects/sqlite.py#L22
elements = {"geom": "SRID=4326;LINESTRING (0 0 0, 1 1 1)"}
conn.execute(t.insert(), elements)

with pytest.raises((IntegrityError, OperationalError)):
with conn.begin_nested():
# This returns a NULL for the geom field.
conn.execute(t.insert(), [{"geom": "SRID=4326;LINESTRING Z (0 0 0, 1 1 1)"}])

results = conn.execute(t.select())
rows = results.fetchall()

assert len(rows) == 1

# Drop the table
t.drop(bind=conn)


class TestIndex:
@pytest.fixture
Expand Down

0 comments on commit 8590a9e

Please sign in to comment.