Skip to content

Commit

Permalink
Merge pull request #85 from allenai/favyen/20241018-fix-geotiff-raster
Browse files Browse the repository at this point in the history
Fix error when loading all patches with windows that have bands at lower resolution
  • Loading branch information
favyen2 authored Nov 14, 2024
2 parents 595e8e4 + 7283f22 commit 3140a58
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
21 changes: 20 additions & 1 deletion rslearn/utils/raster_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,32 @@ def decode_raster(
bounds[2] - offset[0],
bounds[3] - offset[1],
]

# Make sure the requested bounds intersects the raster, otherwise the
# windowed read may fail.
# This is generally unexpected but can happen if we are loading a patch
# of a window that is close to the edge of the window, and when we
# downsample it for a lower resolution raster (negative zoom offset) it
# ends up being out of bounds.
if (
relative_bounds[2] < 0
or relative_bounds[3] < 0
or relative_bounds[0] >= src.width
or relative_bounds[1] >= src.height
):
return None
logger.warning(
"GeotiffRasterFormat.decode_raster got request for a window %s "
+ "outside the raster (transform=%s)",
bounds,
transform,
)
# Assume all of the bands have the same dtype, so just use first
# one (src.dtypes is list of dtype per band).
return np.zeros(
(src.count, bounds[3] - bounds[1], bounds[2] - bounds[0]),
dtype=src.dtypes[0],
)

# Now get the actual pixels we will read, which must be contained in
# the GeoTIFF.
# Padding is (before_x, before_y, after_x, after_y) and will be used to
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/utils/test_raster_format.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pathlib
from typing import Any

import numpy as np
import rasterio
from rasterio.crs import CRS
from upath import UPath

from rslearn.const import Projection
from rslearn.utils import raster_format
from rslearn.utils.raster_format import GeotiffRasterFormat


Expand Down Expand Up @@ -42,3 +44,41 @@ def test_geotiff_tiling(tmp_path: pathlib.Path) -> None:
with (path / "geotiff.tif").open("rb") as f:
with rasterio.open(f) as raster:
assert raster.profile["tiled"]


def test_geotiff_out_of_bounds(tmp_path: pathlib.Path, monkeypatch: Any) -> None:
# GeotiffRasterFormat should log warning but return zero array if we request a
# tile that is fully out of bounds.
# If it is partially out of bounds, it shouldn't warn and should just return the
# partial content.
path = UPath(tmp_path)
projection = Projection(CRS.from_epsg(3857), 1, -1)
format = GeotiffRasterFormat()

array = np.ones((1, 8, 8), dtype=np.uint8)
format.encode_raster(path, projection, (0, 0, 8, 8), array)

class TestLogger:
def __init__(self) -> None:
self.warned = False

def warning(self, *args: list[Any], **kwargs: dict[str, Any]) -> None:
self.warned = True

logger = TestLogger()
monkeypatch.setattr(raster_format, "logger", logger)
array = format.decode_raster(path, (2, 2, 6, 6))
assert array.shape == (1, 4, 4)
assert np.all(array == 1)
assert not logger.warned

array = format.decode_raster(path, (4, 4, 12, 12))
assert array.shape == (1, 8, 8)
assert np.all(array[:, 0:4, 0:4] == 1)
assert np.all(array[:, 0:8, 4:8] == 0)
assert not logger.warned

array = format.decode_raster(path, (8, 8, 12, 12))
assert array.shape == (1, 4, 4)
assert np.all(array == 0)
assert logger.warned

0 comments on commit 3140a58

Please sign in to comment.