Skip to content
Merged
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
------
- ENH: RPCs: Load and write RPCs (#837)
- ENH: Convert xarray structure back to rasterio Dataset (#309, #777)

0.20.0
------
Expand Down
25 changes: 25 additions & 0 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import copy
import os
from collections.abc import Hashable, Iterable, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Literal, Optional, Union

Expand All @@ -25,6 +26,7 @@
from rasterio.dtypes import dtype_rev
from rasterio.enums import Resampling
from rasterio.features import geometry_mask
from rasterio.io import DatasetReader, MemoryFile
from xarray.backends.file_manager import FileManager
from xarray.core.dtypes import get_fill_value

Expand Down Expand Up @@ -1270,3 +1272,26 @@ def to_raster(
compute=compute,
**out_profile,
)

@contextmanager
def to_rasterio_dataset(self) -> DatasetReader:
"""
Return the xarray.Dataset or xarray.DataArray as a rasterio.Dataset.

As rioxarray is able to ingest a rasterio.Dataset, this function is its counterpart.

To be used as a context manager.

.. versionadded:: 0.21

Example
-------

>>> with xds.to_rasterio_dataset() as rio_ds:
>>> rio_ds.count

"""
with MemoryFile() as memfile:
self.to_raster(memfile.name)
with memfile.open() as src_ds:
yield src_ds
59 changes: 59 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numpy.testing import assert_almost_equal, assert_array_equal
from packaging import version
from pyproj import CRS as pCRS
from rasterio import DatasetReader
from rasterio.control import GroundControlPoint
from rasterio.crs import CRS
from rasterio.windows import Window
Expand Down Expand Up @@ -3354,3 +3355,61 @@ def test_non_rectilinear__reproject(rename, open_rasterio):
2818720.0,
)
)


def test_to_rasterio_dataset():
in_rpc_path = os.path.join(TEST_INPUT_DATA_DIR, "cog.tif")
in_rpc = rioxarray.open_rasterio(in_rpc_path)
with in_rpc.rio.to_rasterio_dataset() as riox_ds, rasterio.open(
in_rpc_path
) as rio_ds:
# object type
assert isinstance(riox_ds, DatasetReader), "Error in object type"

# metadata vs rioxarray
assert in_rpc.rio.crs == riox_ds.crs, "Error in CRS vs rio accessor"
assert in_rpc.rio.shape == riox_ds.shape, "Error in shape vs rio accessor"
assert in_rpc.dtype == riox_ds.meta["dtype"], "Error in dtype vs rio accessor"
assert (
in_rpc.rio.transform() == riox_ds.transform
), "Error in transform vs rio accessor"
assert riox_ds.profile["driver"] == "GTiff", "Error in driver vs rio accessor"

# metadata vs rioxarray
assert rio_ds.crs == riox_ds.crs, "Error in CRS vs rasterio"
assert rio_ds.shape == riox_ds.shape, "Error in shape vs rasterio"
assert (
rio_ds.meta["dtype"] == riox_ds.meta["dtype"]
), "Error in dtype vs rasterio"
assert rio_ds.transform == riox_ds.transform, "Error in transform vs rasterio"
assert rio_ds.profile == riox_ds.profile, "Error in driver vs rasterio"


def test_to_rasterio_dataset_rpcs():
in_rpc_path = os.path.join(TEST_INPUT_DATA_DIR, "test_rpcs.tif")
in_rpc = rioxarray.open_rasterio(in_rpc_path)
with in_rpc.rio.to_rasterio_dataset() as riox_ds, rasterio.open(
in_rpc_path
) as rio_ds:
# object type
assert isinstance(riox_ds, DatasetReader), "Error in object type"

# metadata vs rioxarray
assert in_rpc.rio.get_rpcs() == riox_ds.rpcs, "Error in RPCs vs rio accessor"
assert in_rpc.rio.crs == riox_ds.crs, "Error in CRS vs rio accessor"
assert in_rpc.rio.shape == riox_ds.shape, "Error in shape vs rio accessor"
assert in_rpc.dtype == riox_ds.meta["dtype"], "Error in dtype vs rio accessor"
assert (
in_rpc.rio.transform() == riox_ds.transform
), "Error in transform vs rio accessor"
assert riox_ds.profile["driver"] == "GTiff", "Error in driver vs rio accessor"

# metadata vs rioxarray
assert rio_ds.rpcs == riox_ds.rpcs, "Error in RPCs vs rasterio"
assert rio_ds.crs == riox_ds.crs, "Error in CRS vs rasterio"
assert rio_ds.shape == riox_ds.shape, "Error in shape vs rasterio"
assert (
rio_ds.meta["dtype"] == riox_ds.meta["dtype"]
), "Error in dtype vs rasterio"
assert rio_ds.transform == riox_ds.transform, "Error in transform vs rasterio"
assert rio_ds.profile == riox_ds.profile, "Error in driver vs rasterio"
Loading