Commit

BUG: Properly handle encoding/decoding scales and offsets (#821)
snowman2 authored Nov 4, 2024
1 parent 8a642fb commit 6a29c36
Showing 5 changed files with 102 additions and 30 deletions.
4 changes: 4 additions & 0 deletions rioxarray/_io.py
@@ -957,8 +957,12 @@ def _handle_encoding(
             variables.pop_to(
                 result.attrs, result.encoding, "scale_factor", name=da_name
             )
+        if "scales" in result.attrs:
+            variables.pop_to(result.attrs, result.encoding, "scales", name=da_name)
         if "add_offset" in result.attrs:
             variables.pop_to(result.attrs, result.encoding, "add_offset", name=da_name)
+        if "offsets" in result.attrs:
+            variables.pop_to(result.attrs, result.encoding, "offsets", name=da_name)
     if masked:
         if "_FillValue" in result.attrs:
             variables.pop_to(result.attrs, result.encoding, "_FillValue", name=da_name)
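Side note (not part of the commit): the practical effect of this hunk is that, when a raster is opened with mask_and_scale=True, the per-band "scales"/"offsets" metadata ends up in .encoding instead of staying in .attrs, mirroring what already happened for scale_factor/add_offset. A minimal usage sketch, assuming a hypothetical GeoTIFF "scaled.tif" whose bands carry GDAL scale/offset metadata:

import rioxarray

xda = rioxarray.open_rasterio("scaled.tif", mask_and_scale=True)
# the raw per-band factors are popped out of attrs and preserved in encoding
print(xda.encoding.get("scales"), xda.encoding.get("offsets"))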
19 changes: 11 additions & 8 deletions rioxarray/merge.py
@@ -42,12 +42,16 @@ def __init__(self, xds: DataArray):
             "crs": self.crs,
             "nodata": self.nodatavals[0],
         }
-        self._scale_factor = self._xds.encoding.get("scale_factor", 1.0)
-        self._add_offset = self._xds.encoding.get("add_offset", 0.0)
+        valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any(
+            scale != 1 for scale in self._xds.encoding.get("scales", (1,))
+        )
+        valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any(
+            offset != 0 for offset in self._xds.encoding.get("offsets", (0,))
+        )
         self._mask_and_scale = (
             self._xds.rio.encoded_nodata is not None
-            or self._scale_factor != 1
-            or self._add_offset != 0
+            or valid_scale_factor
+            or valid_offset
             or self._xds.encoding.get("_Unsigned") is not None
         )

@@ -70,10 +74,9 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
             kwargs["masked"] = True
             out = dataset.read(*args, **kwargs)
             if self._mask_and_scale:
-                if self._scale_factor != 1:
-                    out = out * self._scale_factor
-                if self._add_offset != 0:
-                    out = out + self._add_offset
+                out = out.astype(self._xds.dtype)
+                for iii in range(self.count):
+                    out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii]
             return out
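Side note (not part of the commit): read() now decodes with the per-band scales/offsets reported by the rasterio dataset it reads from, instead of a single scale_factor/add_offset pair. A small sketch of the per-band arithmetic, with purely illustrative values:

import numpy

raw = numpy.ma.masked_array([[[10, 20]], [[30, 40]]], dtype="int16")  # (band, y, x)
scales = (0.1, 0.2)
offsets = (220.0, 110.0)

decoded = raw.astype("float64")
for band in range(decoded.shape[0]):
    decoded[band] = decoded[band] * scales[band] + offsets[band]
# band 0: 10 * 0.1 + 220.0 -> 221.0; band 1: 30 * 0.2 + 110.0 -> 116.0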


43 changes: 33 additions & 10 deletions rioxarray/raster_dataset.py
@@ -509,35 +509,58 @@ def to_raster(
             is True. Otherwise None is returned.
         """
         # pylint: disable=too-many-locals
         variable_dim = f"band_{uuid4()}"
         data_array = self._obj.to_array(dim=variable_dim)
         # ensure raster metadata preserved
-        scales = []
-        offsets = []
-        nodatavals = []
+        attr_scales = []
+        attr_offsets = []
+        attr_nodatavals = []
+        encoded_scales = []
+        encoded_offsets = []
+        encoded_nodatavals = []
         band_tags = []
         long_name = []
         for data_var in data_array[variable_dim].values:
-            scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
-            offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
+            try:
+                encoded_scales.append(self._obj[data_var].encoding["scale_factor"])
+            except KeyError:
+                attr_scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
+            try:
+                encoded_offsets.append(self._obj[data_var].encoding["add_offset"])
+            except KeyError:
+                attr_offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
             long_name.append(self._obj[data_var].attrs.get("long_name", data_var))
-            nodatavals.append(self._obj[data_var].rio.nodata)
+            if self._obj[data_var].rio.encoded_nodata is not None:
+                encoded_nodatavals.append(self._obj[data_var].rio.encoded_nodata)
+            else:
+                attr_nodatavals.append(self._obj[data_var].rio.nodata)
             band_tags.append(self._obj[data_var].attrs.copy())
-        data_array.attrs["scales"] = scales
-        data_array.attrs["offsets"] = offsets
+        if encoded_scales:
+            data_array.encoding["scales"] = encoded_scales
+        else:
+            data_array.attrs["scales"] = attr_scales
+        if encoded_offsets:
+            data_array.encoding["offsets"] = encoded_offsets
+        else:
+            data_array.attrs["offsets"] = attr_offsets
         data_array.attrs["band_tags"] = band_tags
         data_array.attrs["long_name"] = long_name

+        use_encoded_nodatavals = bool(encoded_nodatavals)
+        nodatavals = encoded_nodatavals if use_encoded_nodatavals else attr_nodatavals
         nodata = nodatavals[0]
         if (
             all(nodataval == nodata for nodataval in nodatavals)
             or numpy.isnan(nodatavals).all()
         ):
-            data_array.rio.write_nodata(nodata, inplace=True)
+            data_array.rio.write_nodata(
+                nodata, inplace=True, encoded=use_encoded_nodatavals
+            )
         else:
             raise RioXarrayError(
                 "All nodata values must be the same when exporting to raster. "
-                f"Current values: {nodatavals}"
+                f"Current values: {attr_nodatavals}"
             )
         if self.crs is not None:
             data_array.rio.write_crs(self.crs, inplace=True)
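Side note (not part of the commit): to_raster() now looks for each variable's scale_factor/add_offset in .encoding first (where xarray keeps them after decoding with mask_and_scale=True) and only falls back to .attrs, and the nodata is written as encoded nodata whenever any variable carries one. A hedged sketch of the workflow this supports, with hypothetical file and variable names:

import rioxarray

# hypothetical multi-variable netCDF opened with decoding enabled
ds = rioxarray.open_rasterio("two_vars.nc", mask_and_scale=True)
# decoded variables keep their original factors in .encoding ...
print(ds["var1"].encoding.get("scale_factor"), ds["var1"].encoding.get("add_offset"))
# ... and to_raster() writes them back as that band's scale/offset
ds.rio.to_raster("two_vars.tif")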
19 changes: 11 additions & 8 deletions rioxarray/raster_writer.py
@@ -94,20 +94,23 @@ def _write_metatata_to_raster(*, raster_handle, xarray_dataset, tags):
     )

     # write scales and offsets
-    try:
-        raster_handle.scales = tags["scales"]
-    except KeyError:
+    scales = tags.get("scales", xarray_dataset.encoding.get("scales"))
+    if scales is None:
         scale_factor = tags.get(
             "scale_factor", xarray_dataset.encoding.get("scale_factor")
         )
         if scale_factor is not None:
-            raster_handle.scales = (scale_factor,) * raster_handle.count
-    try:
-        raster_handle.offsets = tags["offsets"]
-    except KeyError:
+            scales = (scale_factor,) * raster_handle.count
+    if scales is not None:
+        raster_handle.scales = scales
+
+    offsets = tags.get("offsets", xarray_dataset.encoding.get("offsets"))
+    if offsets is None:
         add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
         if add_offset is not None:
-            raster_handle.offsets = (add_offset,) * raster_handle.count
+            offsets = (add_offset,) * raster_handle.count
+    if offsets is not None:
+        raster_handle.offsets = offsets

     _write_tags(raster_handle=raster_handle, tags=tags)
     _write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)
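Side note (not part of the commit): the rewrite makes the lookup order explicit: per-band "scales"/"offsets" from the tags or the dataset encoding win, otherwise a scalar scale_factor/add_offset is broadcast to every band, and nothing is assigned if neither exists. A small self-contained sketch that mimics the new precedence for scales (the helper name is made up for illustration):

def resolve_scales(tags: dict, encoding: dict, band_count: int):
    """Mimic the lookup order now used before assigning raster_handle.scales."""
    scales = tags.get("scales", encoding.get("scales"))
    if scales is None:
        scale_factor = tags.get("scale_factor", encoding.get("scale_factor"))
        if scale_factor is not None:
            scales = (scale_factor,) * band_count
    return scales  # None means the handle's defaults are left untouched


assert resolve_scales({"scales": (0.1, 0.2)}, {}, 2) == (0.1, 0.2)
assert resolve_scales({}, {"scale_factor": 0.1}, 2) == (0.1, 0.1)
assert resolve_scales({}, {}, 2) is None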
47 changes: 43 additions & 4 deletions test/integration/test_integration_rioxarray.py
@@ -1771,7 +1771,8 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir):
     tmp_raster = tmpdir.join("air_temp_offset.tif")

     with rioxarray.open_rasterio(
-        os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks
+        os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
+        chunks=chunks,
     ) as rds:
         rds = _ensure_dataset(rds)
         attrs = dict(rds.air_temperature.attrs)
@@ -1795,6 +1796,38 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir):
         assert rds.rio.nodata == 32767.0


+@pytest.mark.parametrize("mask_and_scale", [True, False])
+def test_to_raster__scales__offsets(mask_and_scale, tmpdir):
+    tmp_raster = tmpdir.join("air_temp_offset.tif")
+
+    with rioxarray.open_rasterio(
+        os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
+        mask_and_scale=mask_and_scale,
+    ) as rds:
+        rds = _ensure_dataset(rds)
+        rds["air_temperature_2"] = rds.air_temperature.copy()
+        if mask_and_scale:
+            rds.air_temperature_2.encoding["scale_factor"] = 0.2
+            rds.air_temperature_2.encoding["add_offset"] = 110.0
+        else:
+            rds.air_temperature_2.attrs["scale_factor"] = 0.2
+            rds.air_temperature_2.attrs["add_offset"] = 110.0
+        rds.squeeze(dim="band", drop=True).rio.to_raster(str(tmp_raster))
+
+    with rasterio.open(str(tmp_raster)) as rds:
+        assert rds.scales == (0.1, 0.2)
+        assert rds.offsets == (220.0, 110.0)
+
+    # test roundtrip
+    with rioxarray.open_rasterio(str(tmp_raster), mask_and_scale=mask_and_scale) as rds:
+        if mask_and_scale:
+            assert rds.encoding["scales"] == (0.1, 0.2)
+            assert rds.encoding["offsets"] == (220.0, 110.0)
+        else:
+            assert rds.attrs["scales"] == (0.1, 0.2)
+            assert rds.attrs["offsets"] == (220.0, 110.0)
+
+
 def test_to_raster__custom_description__wrong(tmpdir):
     tmp_raster = tmpdir.join("planet_3d_raster.tif")
     with xarray.open_dataset(
@@ -1857,11 +1890,14 @@ def test_to_raster__dataset(tmpdir):
         assert numpy.isnan(rdscompare.rio.nodata)


+@pytest.mark.parametrize("mask_and_scale", [True, False])
 @pytest.mark.parametrize("chunks", [True, None])
-def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
+def test_to_raster__dataset__mask_and_scale(chunks, mask_and_scale, tmpdir):
     output_raster = tmpdir.join("tmmx_20190121.tif")
     with rioxarray.open_rasterio(
-        os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks
+        os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
+        chunks=chunks,
+        mask_and_scale=mask_and_scale,
     ) as rds:
         rds = _ensure_dataset(rds)
         rds.isel(band=0).rio.to_raster(str(output_raster))
@@ -1871,7 +1907,10 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
         assert rdscompare.add_offset == 220.0
         assert rdscompare.long_name == "tmmx"
         assert rdscompare.rio.crs == rds.rio.crs
-        assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata
+        if mask_and_scale:
+            assert rdscompare.rio.nodata == rds.air_temperature.rio.encoded_nodata
+        else:
+            assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata


 def test_to_raster__dataset__different_crs(tmpdir):
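Side note (not part of the commit): the new branch in the assertion above reflects how rioxarray reports nodata. With mask_and_scale=True the raw fill value is masked away, rio.nodata becomes NaN, and the original fill value is exposed as rio.encoded_nodata; without masking, rio.nodata holds the raw fill value directly. A hedged sketch, with a hypothetical file path:

import numpy
import rioxarray

masked = rioxarray.open_rasterio("tmmx_20190121.tif", mask_and_scale=True)
raw = rioxarray.open_rasterio("tmmx_20190121.tif")
assert numpy.isnan(masked.rio.nodata)
print(masked.rio.encoded_nodata, raw.rio.nodata)  # both report the raw fill value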
