|
12 | 12 | get_gdal_config_option,
|
13 | 13 | get_gdal_data_path,
|
14 | 14 | )
|
| 15 | +from pyogrio.core import detect_write_driver |
15 | 16 | from pyogrio.errors import DataSourceError, DataLayerError
|
16 | 17 |
|
17 | 18 | from pyogrio._env import GDALEnv
|
@@ -44,6 +45,58 @@ def test_gdal_geos_version():
|
44 | 45 | assert __gdal_geos_version__ is None or isinstance(__gdal_geos_version__, tuple)
|
45 | 46 |
|
46 | 47 |
|
| 48 | +@pytest.mark.parametrize( |
| 49 | + "path,expected", |
| 50 | + [ |
| 51 | + ("test.shp", "ESRI Shapefile"), |
| 52 | + ("test.shp.zip", "ESRI Shapefile"), |
| 53 | + ("test.geojson", "GeoJSON"), |
| 54 | + ("test.geojsonl", "GeoJSONSeq"), |
| 55 | + ("test.gpkg", "GPKG"), |
| 56 | + pytest.param( |
| 57 | + "test.gpkg.zip", |
| 58 | + "GPKG", |
| 59 | + marks=pytest.mark.skipif( |
| 60 | + __gdal_version__ < (3, 7, 0), |
| 61 | + reason="writing *.gpkg.zip requires GDAL >= 3.7.0", |
| 62 | + ), |
| 63 | + ), |
| 64 | + # postgres can be detected by prefix instead of extension |
| 65 | + pytest.param( |
| 66 | + "PG:dbname=test", |
| 67 | + "PostgreSQL", |
| 68 | + marks=pytest.mark.skipif( |
| 69 | + "PostgreSQL" not in list_drivers(), |
| 70 | + reason="PostgreSQL path test requires PostgreSQL driver", |
| 71 | + ), |
| 72 | + ), |
| 73 | + ], |
| 74 | +) |
| 75 | +def test_detect_write_driver(path, expected): |
| 76 | + assert detect_write_driver(path) == expected |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize( |
| 80 | + "path", |
| 81 | + [ |
| 82 | + "test.svg", # only supports read |
| 83 | + "test.", # not a valid extension |
| 84 | + "test", # no extension or prefix |
| 85 | + "test.foo", # not a valid extension |
| 86 | + "FOO:test", # not a valid prefix |
| 87 | + ], |
| 88 | +) |
| 89 | +def test_detect_write_driver_unsupported(path): |
| 90 | + with pytest.raises(ValueError, match="Could not infer driver from path"): |
| 91 | + detect_write_driver(path) |
| 92 | + |
| 93 | + |
| 94 | +@pytest.mark.parametrize("path", ["test.xml", "test.txt"]) |
| 95 | +def test_detect_write_driver_multiple_unsupported(path): |
| 96 | + with pytest.raises(ValueError, match="multiple drivers are available"): |
| 97 | + detect_write_driver(path) |
| 98 | + |
| 99 | + |
47 | 100 | @pytest.mark.parametrize(
|
48 | 101 | "driver,expected",
|
49 | 102 | [
|
|
0 commit comments