ENH: allow using open_arrow with PyCapsule protocol (without pyarrow dependency) #349

Merged
7 changes: 7 additions & 0 deletions CHANGES.md
@@ -9,6 +9,13 @@
- `read_arrow` and `open_arrow` now provide
[GeoArrow-compliant extension metadata](https://geoarrow.org/extension-types.html),
including the CRS, when using GDAL 3.8 or higher (#366).
- The `open_arrow` function can now be used without a `pyarrow` dependency. By
default, it will now return a stream object implementing the
[Arrow PyCapsule Protocol](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
(i.e. having an `__arrow_c_stream__` method). This object can then be consumed
by your Arrow implementation of choice that supports this protocol. To keep
the previous behaviour of returning a `pyarrow.RecordBatchReader`, specify
`use_pyarrow=True` (#349).
- Warn when reading from a multilayer file without specifying a layer (#362).

### Bug fixes
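To make the changelog entry above concrete, here is a minimal sketch of consuming the new default return value. This is illustrative, not part of the PR: it assumes pyarrow >= 14 (the first release that understands the PyCapsule protocol), and `example.gpkg` is a hypothetical path.

```python
# Hedged sketch: consume the stream returned by the new open_arrow default.
import pyarrow as pa

from pyogrio.raw import open_arrow

with open_arrow("example.gpkg") as (meta, stream):
    # `stream` implements __arrow_c_stream__, so any protocol-aware Arrow
    # implementation can wrap it; here pyarrow's RecordBatchReader is used.
    reader = pa.RecordBatchReader.from_stream(stream)
    table = reader.read_all()  # consume before the context manager exits

print(meta["crs"], table.num_rows)
```

Any other Arrow implementation that accepts `__arrow_c_stream__` objects would work in place of pyarrow here.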
3 changes: 2 additions & 1 deletion pyogrio/_compat.py
@@ -24,7 +24,8 @@
pandas = None


HAS_ARROW_API = __gdal_version__ >= (3, 6, 0) and pyarrow is not None
HAS_ARROW_API = __gdal_version__ >= (3, 6, 0)
HAS_PYARROW = pyarrow is not None

HAS_GEOPANDAS = geopandas is not None

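Splitting the old combined flag into `HAS_ARROW_API` (GDAL capability) and `HAS_PYARROW` (optional dependency) lets downstream code branch on the two conditions independently. An illustrative sketch (the flag names come from this diff; the messages are made up):

```python
# Illustrative gating on the two flags introduced by this PR.
from pyogrio._compat import HAS_ARROW_API, HAS_PYARROW

if not HAS_ARROW_API:
    print("GDAL < 3.6: no Arrow stream support at all")
elif not HAS_PYARROW:
    print("Arrow reading available, but only via the PyCapsule protocol")
else:
    print("Arrow reading available, including use_pyarrow=True")
```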
59 changes: 50 additions & 9 deletions pyogrio/_io.pyx
@@ -18,6 +18,8 @@ from libc.string cimport strlen
from libc.math cimport isnan

cimport cython
from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer

import numpy as np

from pyogrio._ogr cimport *
@@ -1256,6 +1258,35 @@ def ogr_read(
field_data
)


cdef void pycapsule_array_stream_deleter(object stream_capsule) noexcept:
cdef ArrowArrayStream* stream = <ArrowArrayStream*>PyCapsule_GetPointer(
stream_capsule, 'arrow_array_stream'
)
# Do not invoke the deleter on a used/moved capsule
if stream.release != NULL:
stream.release(stream)

free(stream)


cdef object alloc_c_stream(ArrowArrayStream** c_stream):
c_stream[0] = <ArrowArrayStream*> malloc(sizeof(ArrowArrayStream))
# Ensure the capsule destructor doesn't call a random release pointer
c_stream[0].release = NULL
return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_array_stream_deleter)


class _ArrowStream:
def __init__(self, capsule):
self._capsule = capsule

def __arrow_c_stream__(self, requested_schema=None):
if requested_schema is not None:
raise NotImplementedError("requested_schema is not supported")
return self._capsule


@contextlib.contextmanager
def ogr_open_arrow(
str path,
@@ -1274,7 +1305,9 @@ def ogr_open_arrow(
str sql=None,
str sql_dialect=None,
int return_fids=False,
int batch_size=0):
int batch_size=0,
use_pyarrow=False,
):

cdef int err = 0
cdef const char *path_c = NULL
@@ -1286,7 +1319,7 @@
cdef char **fields_c = NULL
cdef const char *field_c = NULL
cdef char **options = NULL
cdef ArrowArrayStream stream
cdef ArrowArrayStream* stream
cdef ArrowSchema schema

IF CTE_GDAL_VERSION < (3, 6, 0):
@@ -1419,19 +1452,23 @@
# make sure layer is read from beginning
OGR_L_ResetReading(ogr_layer)

if not OGR_L_GetArrowStream(ogr_layer, &stream, options):
raise RuntimeError("Failed to open ArrowArrayStream from Layer")
# allocate the stream struct and wrap in capsule to ensure clean-up on error
capsule = alloc_c_stream(&stream)

stream_ptr = <uintptr_t> &stream
if not OGR_L_GetArrowStream(ogr_layer, stream, options):
raise RuntimeError("Failed to open ArrowArrayStream from Layer")

if skip_features:
# only supported for GDAL >= 3.8.0; have to do this after getting
# the Arrow stream
OGR_L_SetNextByIndex(ogr_layer, skip_features)

# stream has to be consumed before the Dataset is closed
import pyarrow as pa
reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)
if use_pyarrow:
import pyarrow as pa

reader = pa.RecordBatchStreamReader._import_from_c(<uintptr_t> stream)
else:
reader = _ArrowStream(capsule)

meta = {
'crs': crs,
@@ -1442,13 +1479,16 @@
'fid_column': fid_column,
}

# stream has to be consumed before the Dataset is closed
yield meta, reader

finally:
if reader is not None:
if use_pyarrow and reader is not None:
# Mark reader as closed to prevent reading batches
reader.close()

# `stream` will be freed through `capsule` destructor

CSLDestroy(options)
if fields_c != NULL:
CSLDestroy(fields_c)
@@ -1465,6 +1505,7 @@
GDALClose(ogr_dataset)
ogr_dataset = NULL


def ogr_read_bounds(
str path,
object layer=None,
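The `_ArrowStream` class above is the whole Python-side contract of the PyCapsule protocol: an object whose `__arrow_c_stream__` returns a PyCapsule named `arrow_array_stream`, with a destructor that releases the C struct only if it hasn't already been moved. A pure-Python sketch of the same contract (`WrappedStream` is illustrative and not part of pyogrio; pyarrow >= 14 assumed), delegating to an existing pyarrow stream instead of a GDAL layer:

```python
# Minimal PyCapsule-protocol producer, delegating to an existing stream.
import pyarrow as pa


class WrappedStream:
    """Illustrative re-export of an Arrow stream via __arrow_c_stream__."""

    def __init__(self, inner):
        self._inner = inner  # any object implementing the protocol itself

    def __arrow_c_stream__(self, requested_schema=None):
        # Like _ArrowStream in this PR, schema requests are not supported.
        if requested_schema is not None:
            raise NotImplementedError("requested_schema is not supported")
        return self._inner.__arrow_c_stream__()


source = pa.table({"a": [1, 2, 3]}).to_reader()
reader = pa.RecordBatchReader.from_stream(WrappedStream(source))
assert reader.read_all().num_rows == 3
```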
5 changes: 3 additions & 2 deletions pyogrio/_ogr.pxd
@@ -190,12 +190,13 @@ cdef extern from "ogr_srs_api.h":
void OSRRelease(OGRSpatialReferenceH srs)


cdef extern from "arrow_bridge.h":
cdef extern from "arrow_bridge.h" nogil:
struct ArrowSchema:
int64_t n_children

struct ArrowArrayStream:
int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out)
int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out) noexcept
void (*release)(ArrowArrayStream*) noexcept


cdef extern from "ogr_api.h":
57 changes: 49 additions & 8 deletions pyogrio/raw.py
@@ -1,7 +1,7 @@
import warnings

from pyogrio._env import GDALEnv
from pyogrio._compat import HAS_ARROW_API
from pyogrio._compat import HAS_ARROW_API, HAS_PYARROW
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError
from pyogrio.util import (
@@ -256,6 +256,12 @@ def read_arrow(
"geometry_name": "<name of geometry column in arrow table>",
}
"""
if not HAS_PYARROW:
raise RuntimeError(
"pyarrow required to read using 'read_arrow'. You can use 'open_arrow' "
"to read data with an alternative Arrow implementation"
)

from pyarrow import Table

gdal_version = get_gdal_version()
@@ -297,6 +303,7 @@
return_fids=return_fids,
skip_features=gdal_skip_features,
batch_size=batch_size,
use_pyarrow=True,
**kwargs,
) as source:
meta, reader = source
@@ -351,17 +358,37 @@
sql_dialect=None,
return_fids=False,
batch_size=65_536,
use_pyarrow=False,
**kwargs,
):
"""
Open OGR data source as a stream of pyarrow record batches.
Open OGR data source as a stream of Arrow record batches.

See docstring of `read` for parameters.

The RecordBatchStreamReader is reading from a stream provided by OGR and must not be
The returned object is reading from a stream provided by OGR and must not be
accessed after the OGR dataset has been closed, i.e. after the context manager has
been closed.

By default this function returns a generic stream object implementing
the `Arrow PyCapsule Protocol`_ (i.e. having an ``__arrow_c_stream__``
method). This object can then be consumed by your Arrow implementation
of choice that supports this protocol.
Optionally, you can specify ``use_pyarrow=True`` to directly get the
stream as a `pyarrow.RecordBatchReader`.

.. _Arrow PyCapsule Protocol: https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

Other Parameters
----------------
batch_size : int (default: 65_536)
Maximum number of features to retrieve in a batch.
use_pyarrow : bool (default: False)
If True, return a pyarrow RecordBatchReader instead of a generic
ArrowStream object. In the default case, this stream object needs
to be passed to another library supporting the Arrow PyCapsule
Protocol to consume the stream of data.

Examples
--------

@@ -370,16 +397,29 @@
>>> import shapely
>>>
>>> with open_arrow(path) as source:
>>> meta, stream = source
>>> # wrap the arrow stream object in a pyarrow RecordBatchReader
>>> reader = pa.RecordBatchReader.from_stream(stream)
>>> for batch in reader:
>>> geometries = shapely.from_wkb(batch[meta["geometry_name"] or "wkb_geometry"])

The returned `stream` object needs to be consumed by a library implementing
the Arrow PyCapsule Protocol. In the above example, pyarrow is used through
its RecordBatchReader. For this case, you can also specify ``use_pyarrow=True``
to directly get this result as a shortcut:

>>> with open_arrow(path, use_pyarrow=True) as source:
>>> meta, reader = source
>>> for table in reader:
>>> geometries = shapely.from_wkb(table[meta["geometry_name"]])
>>> for batch in reader:
>>> geometries = shapely.from_wkb(batch[meta["geometry_name"] or "wkb_geometry"])

Returns
-------
(dict, pyarrow.RecordBatchStreamReader)
(dict, pyarrow.RecordBatchReader or ArrowStream)

Returns a tuple of meta information about the data source in a dict,
and a pyarrow RecordBatchStreamReader with data.
and a data stream object (a generic ArrowStream object, or a pyarrow
RecordBatchReader if `use_pyarrow` is set to True).

Meta is: {
"crs": "<crs>",
@@ -390,7 +430,7 @@
}
"""
if not HAS_ARROW_API:
raise RuntimeError("pyarrow and GDAL>= 3.6 required to read using arrow")
raise RuntimeError("GDAL>= 3.6 required to read using arrow")

path, buffer = get_vsi_path(path_or_buffer)

@@ -415,6 +455,7 @@
return_fids=return_fids,
dataset_kwargs=dataset_kwargs,
batch_size=batch_size,
use_pyarrow=use_pyarrow,
)
finally:
if buffer is not None:
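Two caveats of the default (non-pyarrow) mode follow from the implementation above: the C stream is one-shot (the first consumer "moves" it out of the capsule), and it must be consumed before the context manager closes the GDAL dataset. A hedged sketch (hypothetical path; pyarrow >= 14 assumed):

```python
# Illustrative caveats of the default PyCapsule mode.
import pyarrow as pa

from pyogrio.raw import open_arrow

with open_arrow("example.gpkg") as (meta, stream):  # hypothetical path
    table = pa.table(stream)  # consumes (moves) the one-shot stream
    # Calling pa.table(stream) again here would fail: the capsule's
    # stream was already released by the first consumer.
# After the block the dataset is closed; `stream` must not be touched.

print(table.schema)
```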
7 changes: 4 additions & 3 deletions pyogrio/tests/conftest.py
@@ -8,7 +8,7 @@
__version__,
list_drivers,
)
from pyogrio._compat import HAS_ARROW_API, HAS_GDAL_GEOS, HAS_SHAPELY
from pyogrio._compat import HAS_ARROW_API, HAS_GDAL_GEOS, HAS_PYARROW, HAS_SHAPELY
from pyogrio.raw import read, write


@@ -43,8 +43,9 @@ def pytest_report_header(config):


# marks to skip tests if optional dependencies are not present
requires_arrow_api = pytest.mark.skipif(
not HAS_ARROW_API, reason="GDAL>=3.6 and pyarrow required"
requires_arrow_api = pytest.mark.skipif(not HAS_ARROW_API, reason="GDAL>=3.6 required")
requires_pyarrow_api = pytest.mark.skipif(
not HAS_ARROW_API or not HAS_PYARROW, reason="GDAL>=3.6 and pyarrow required"
)

requires_gdal_geos = pytest.mark.skipif(
47 changes: 41 additions & 6 deletions pyogrio/tests/test_arrow.py
@@ -2,13 +2,15 @@
import json
import math
import os
import sys

import pytest

import numpy as np

import pyogrio
from pyogrio import __gdal_version__, read_dataframe
from pyogrio.raw import open_arrow, read_arrow, write
from pyogrio.tests.conftest import ALL_EXTS, requires_arrow_api
from pyogrio.tests.conftest import ALL_EXTS, requires_pyarrow_api

try:
import pandas as pd
Expand All @@ -20,7 +22,7 @@
pass

# skip all tests in this file if Arrow API or GeoPandas are unavailable
pytestmark = requires_arrow_api
pytestmark = requires_pyarrow_api
pytest.importorskip("geopandas")


@@ -137,8 +139,8 @@ def test_read_arrow_raw(naturalearth_lowres):
assert isinstance(table, pyarrow.Table)


def test_open_arrow(naturalearth_lowres):
with open_arrow(naturalearth_lowres) as (meta, reader):
def test_open_arrow_pyarrow(naturalearth_lowres):
with open_arrow(naturalearth_lowres, use_pyarrow=True) as (meta, reader):
assert isinstance(meta, dict)
assert isinstance(reader, pyarrow.RecordBatchReader)
assert isinstance(reader.read_all(), pyarrow.Table)
@@ -148,7 +150,10 @@ def test_open_arrow_batch_size(naturalearth_lowres):
meta, table = read_arrow(naturalearth_lowres)
batch_size = math.ceil(len(table) / 2)

with open_arrow(naturalearth_lowres, batch_size=batch_size) as (meta, reader):
with open_arrow(naturalearth_lowres, batch_size=batch_size, use_pyarrow=True) as (
meta,
reader,
):
assert isinstance(meta, dict)
assert isinstance(reader, pyarrow.RecordBatchReader)
count = 0
@@ -207,6 +212,36 @@ def test_read_arrow_geoarrow_metadata(naturalearth_lowres):
assert parsed_meta["crs"]["id"]["code"] == 4326


def test_open_arrow_capsule_protocol(naturalearth_lowres):
pytest.importorskip("pyarrow", minversion="14")

with open_arrow(naturalearth_lowres) as (meta, reader):
assert isinstance(meta, dict)
assert isinstance(reader, pyogrio._io._ArrowStream)

result = pyarrow.table(reader)

_, expected = read_arrow(naturalearth_lowres)
assert result.equals(expected)


def test_open_arrow_capsule_protocol_without_pyarrow(naturalearth_lowres):
pyarrow = pytest.importorskip("pyarrow", minversion="14")

# Make PyArrow temporarily unavailable (importing will fail)
sys.modules["pyarrow"] = None
try:
with open_arrow(naturalearth_lowres) as (meta, reader):
assert isinstance(meta, dict)
assert isinstance(reader, pyogrio._io._ArrowStream)
result = pyarrow.table(reader)
finally:
sys.modules["pyarrow"] = pyarrow

_, expected = read_arrow(naturalearth_lowres)
assert result.equals(expected)


@contextlib.contextmanager
def use_arrow_context():
original = os.environ.get("PYOGRIO_USE_ARROW", None)