Skip to content

Commit aacd996

Browse files
authored
ENH: Add support for detecting write drivers using GDAL (#270)
1 parent 3ad88d0 commit aacd996

File tree

12 files changed

+191
-68
lines changed

12 files changed

+191
-68
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
performance impacts for some data sources that would otherwise return an
1212
unknown count (count is used in `read_info`, `read`, `read_dataframe`) (#271).
1313

14+
- Automatically detect supported driver by extension for all available
15+
write drivers and addition of `detect_write_driver` (#270)
16+
1417
### Bug fixes
1518

1619
- Fix int32 overflow when reading int64 columns (#260)

docs/source/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Core
55
----
66

77
.. automodule:: pyogrio
8-
:members: list_drivers, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__
8+
:members: list_drivers, detect_write_driver, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__
99

1010
GeoPandas integration
1111
---------------------

pyogrio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pyogrio.core import (
1010
list_drivers,
11+
detect_write_driver,
1112
list_layers,
1213
read_bounds,
1314
read_info,
@@ -27,6 +28,7 @@
2728

2829
__all__ = [
2930
"list_drivers",
31+
"detect_write_driver",
3032
"list_layers",
3133
"read_bounds",
3234
"read_info",

pyogrio/_ogr.pyx

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,46 @@ def _get_driver_metadata_item(driver, metadata_item):
317317
metadata = None
318318

319319
return metadata
320+
321+
322+
def _get_drivers_for_path(path):
323+
cdef OGRSFDriverH driver = NULL
324+
cdef int i
325+
cdef char *name_c
326+
327+
path = str(path).lower()
328+
329+
parts = os.path.splitext(path)
330+
if len(parts) == 2 and len(parts[1]) > 1:
331+
ext = parts[1][1:]
332+
else:
333+
ext = None
334+
335+
336+
# allow specific drivers to have a .zip extension to match GDAL behavior
337+
if ext == 'zip':
338+
if path.endswith('.shp.zip'):
339+
ext = 'shp.zip'
340+
elif path.endswith('.gpkg.zip'):
341+
ext = 'gpkg.zip'
342+
343+
drivers = []
344+
for i in range(OGRGetDriverCount()):
345+
driver = OGRGetDriver(i)
346+
name_c = <char *>OGR_Dr_GetName(driver)
347+
name = get_string(name_c)
348+
349+
if not ogr_driver_supports_write(name):
350+
continue
351+
352+
# extensions is a space-delimited list of supported extensions
353+
# for driver
354+
extensions = _get_driver_metadata_item(name, "DMD_EXTENSIONS")
355+
if ext is not None and extensions is not None and ext in extensions.lower().split(' '):
356+
drivers.append(name)
357+
else:
358+
prefix = _get_driver_metadata_item(name, "DMD_CONNECTION_PREFIX")
359+
if prefix is not None and path.startswith(prefix.lower()):
360+
drivers.append(name)
361+
362+
return drivers

pyogrio/core.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from pyogrio._env import GDALEnv
2-
from pyogrio.raw import _preprocess_options_key_value
3-
from pyogrio.util import get_vsi_path
2+
from pyogrio.util import get_vsi_path, _preprocess_options_key_value
43

54

65
with GDALEnv():
@@ -16,6 +15,7 @@
1615
init_proj_data as _init_proj_data,
1716
remove_virtual_file,
1817
_register_drivers,
18+
_get_drivers_for_path,
1919
)
2020
from pyogrio._err import _register_error_handler
2121
from pyogrio._io import ogr_list_layers, ogr_read_bounds, ogr_read_info
@@ -58,6 +58,42 @@ def list_drivers(read=False, write=False):
5858
return drivers
5959

6060

61+
def detect_write_driver(path):
62+
"""Attempt to infer the driver for a path by extension or prefix. Only
63+
drivers that support write capabilities will be detected.
64+
65+
If the path cannot be resolved to a single driver, a ValueError will be
66+
raised.
67+
68+
Parameters
69+
----------
70+
path : str
71+
72+
Returns
73+
-------
74+
str
75+
name of the driver, if detected
76+
"""
77+
# try to infer driver from path
78+
drivers = _get_drivers_for_path(path)
79+
80+
if len(drivers) == 0:
81+
raise ValueError(
82+
f"Could not infer driver from path: {path}; please specify driver "
83+
"explicitly"
84+
)
85+
86+
# if there are multiple drivers detected, user needs to specify the correct
87+
# one manually
88+
elif len(drivers) > 1:
89+
raise ValueError(
90+
f"Could not infer driver from path: {path}; multiple drivers are "
91+
"available for that extension. Please specify driver explicitly"
92+
)
93+
94+
return drivers[0]
95+
96+
6197
def list_layers(path_or_buffer, /):
6298
"""List layers available in an OGR data source.
6399

pyogrio/geopandas.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import numpy as np
2-
from pyogrio.raw import DRIVERS_NO_MIXED_SINGLE_MULTI, DRIVERS_NO_MIXED_DIMENSIONS
3-
from pyogrio.raw import detect_driver, read, read_arrow, write
2+
from pyogrio.raw import (
3+
DRIVERS_NO_MIXED_SINGLE_MULTI,
4+
DRIVERS_NO_MIXED_DIMENSIONS,
5+
detect_write_driver,
6+
read,
7+
read_arrow,
8+
write,
9+
)
410
from pyogrio.errors import DataSourceError
511

612

@@ -312,7 +318,7 @@ def write_dataframe(
312318
raise ValueError("'df' must be a DataFrame or GeoDataFrame")
313319

314320
if driver is None:
315-
driver = detect_driver(path)
321+
driver = detect_write_driver(path)
316322

317323
geometry_columns = df.columns[df.dtypes == "geometry"]
318324
if len(geometry_columns) > 1:

pyogrio/raw.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import warnings
2-
import os
32

43
from pyogrio._env import GDALEnv
4+
from pyogrio.core import detect_write_driver
55
from pyogrio.errors import DataSourceError
6-
from pyogrio.util import get_vsi_path, vsi_path
6+
from pyogrio.util import get_vsi_path, vsi_path, _preprocess_options_key_value
77

88
with GDALEnv():
99
from pyogrio._io import ogr_open_arrow, ogr_read, ogr_write
@@ -16,17 +16,6 @@
1616
)
1717

1818

19-
DRIVERS = {
20-
".fgb": "FlatGeobuf",
21-
".geojson": "GeoJSON",
22-
".geojsonl": "GeoJSONSeq",
23-
".geojsons": "GeoJSONSeq",
24-
".gpkg": "GPKG",
25-
".json": "GeoJSON",
26-
".shp": "ESRI Shapefile",
27-
}
28-
29-
3019
DRIVERS_NO_MIXED_SINGLE_MULTI = {
3120
"FlatGeobuf",
3221
"GPKG",
@@ -310,26 +299,6 @@ def open_arrow(
310299
remove_virtual_file(path)
311300

312301

313-
def detect_driver(path):
314-
# try to infer driver from path
315-
parts = os.path.splitext(path)
316-
if len(parts) != 2:
317-
raise ValueError(
318-
f"Could not infer driver from path: {path}; please specify driver "
319-
"explicitly"
320-
)
321-
322-
ext = parts[1].lower()
323-
driver = DRIVERS.get(ext, None)
324-
if driver is None:
325-
raise ValueError(
326-
f"Could not infer driver from path: {path}; please specify driver "
327-
"explicitly"
328-
)
329-
330-
return driver
331-
332-
333302
def _parse_options_names(xml):
334303
"""Convert metadata xml to list of names"""
335304
# Based on Fiona's meta.py
@@ -347,27 +316,6 @@ def _parse_options_names(xml):
347316
return options
348317

349318

350-
def _preprocess_options_key_value(options):
351-
"""
352-
Preprocess options, eg `spatial_index=True` gets converted
353-
to `SPATIAL_INDEX="YES"`.
354-
"""
355-
if not isinstance(options, dict):
356-
raise TypeError(f"Expected options to be a dict, got {type(options)}")
357-
358-
result = {}
359-
for k, v in options.items():
360-
if v is None:
361-
continue
362-
k = k.upper()
363-
if isinstance(v, bool):
364-
v = "ON" if v else "OFF"
365-
else:
366-
v = str(v)
367-
result[k] = v
368-
return result
369-
370-
371319
def write(
372320
path,
373321
geometry,
@@ -393,7 +341,7 @@ def write(
393341
path = vsi_path(str(path))
394342

395343
if driver is None:
396-
driver = detect_driver(path)
344+
driver = detect_write_driver(path)
397345

398346
# verify that driver supports writing
399347
if not ogr_driver_supports_write(driver):

pyogrio/tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88

99

1010
_data_dir = Path(__file__).parent.resolve() / "fixtures"
11+
12+
# mapping of driver extension to driver name for well-supported drivers
13+
DRIVERS = {
14+
".fgb": "FlatGeobuf",
15+
".geojson": "GeoJSON",
16+
".geojsonl": "GeoJSONSeq",
17+
".geojsons": "GeoJSONSeq",
18+
".gpkg": "GPKG",
19+
".json": "GeoJSON",
20+
".shp": "ESRI Shapefile",
21+
}
22+
23+
# mapping of driver name to extension
24+
DRIVER_EXT = {driver: ext for ext, driver in DRIVERS.items()}
25+
1126
ALL_EXTS = [".fgb", ".geojson", ".geojsonl", ".gpkg", ".shp"]
1227

1328

pyogrio/tests/test_core.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
get_gdal_config_option,
1313
get_gdal_data_path,
1414
)
15+
from pyogrio.core import detect_write_driver
1516
from pyogrio.errors import DataSourceError, DataLayerError
1617

1718
from pyogrio._env import GDALEnv
@@ -44,6 +45,58 @@ def test_gdal_geos_version():
4445
assert __gdal_geos_version__ is None or isinstance(__gdal_geos_version__, tuple)
4546

4647

48+
@pytest.mark.parametrize(
49+
"path,expected",
50+
[
51+
("test.shp", "ESRI Shapefile"),
52+
("test.shp.zip", "ESRI Shapefile"),
53+
("test.geojson", "GeoJSON"),
54+
("test.geojsonl", "GeoJSONSeq"),
55+
("test.gpkg", "GPKG"),
56+
pytest.param(
57+
"test.gpkg.zip",
58+
"GPKG",
59+
marks=pytest.mark.skipif(
60+
__gdal_version__ < (3, 7, 0),
61+
reason="writing *.gpkg.zip requires GDAL >= 3.7.0",
62+
),
63+
),
64+
# postgres can be detected by prefix instead of extension
65+
pytest.param(
66+
"PG:dbname=test",
67+
"PostgreSQL",
68+
marks=pytest.mark.skipif(
69+
"PostgreSQL" not in list_drivers(),
70+
reason="PostgreSQL path test requires PostgreSQL driver",
71+
),
72+
),
73+
],
74+
)
75+
def test_detect_write_driver(path, expected):
76+
assert detect_write_driver(path) == expected
77+
78+
79+
@pytest.mark.parametrize(
80+
"path",
81+
[
82+
"test.svg", # only supports read
83+
"test.", # not a valid extension
84+
"test", # no extension or prefix
85+
"test.foo", # not a valid extension
86+
"FOO:test", # not a valid prefix
87+
],
88+
)
89+
def test_detect_write_driver_unsupported(path):
90+
with pytest.raises(ValueError, match="Could not infer driver from path"):
91+
detect_write_driver(path)
92+
93+
94+
@pytest.mark.parametrize("path", ["test.xml", "test.txt"])
95+
def test_detect_write_driver_multiple_unsupported(path):
96+
with pytest.raises(ValueError, match="multiple drivers are available"):
97+
detect_write_driver(path)
98+
99+
47100
@pytest.mark.parametrize(
48101
"driver,expected",
49102
[

pyogrio/tests/test_geopandas_io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
from pyogrio.errors import DataLayerError, DataSourceError, FeatureError, GeometryError
1111
from pyogrio.geopandas import read_dataframe, write_dataframe
1212
from pyogrio.raw import (
13-
DRIVERS,
1413
DRIVERS_NO_MIXED_DIMENSIONS,
1514
DRIVERS_NO_MIXED_SINGLE_MULTI,
1615
)
17-
from pyogrio.tests.conftest import ALL_EXTS
16+
from pyogrio.tests.conftest import ALL_EXTS, DRIVERS
1817

1918
try:
2019
import pandas as pd

0 commit comments

Comments
 (0)