Skip to content

Conversation

vschaffn
Copy link
Contributor

@vschaffn vschaffn commented Mar 6, 2025

Resolves #648.

Summary

This PR introduces the multiprocessing possibility for raster reprojection, improving memory efficiency and processing speed when performing reprojection by leveraging tiled processing and multiprocessing.

Features

  • Implements _multiproc_reproject for tiled raster reprojection with multiprocessing and direct disk writing.
  • Adds _wrapper_multiproc_reproject_per_block and _reproject_per_block to enable independent per-block reprojection using Dask structures.
  • Introduces distributing_computed in raster.__init__.

Changes

  • Fixes reprojection by correctly passing GDAL warp options: XSCALE=1, YSCALE=1, tolerance=0.
  • Refactors Dask-based reprojection into modular functions used by _multiproc_reproject.
  • Enables multi-band reprojection within Dask and multiprocessing workflows.
  • Extracts the writing logic from map_overlap_multiproc_save for reuse in _multiproc_reproject.
  • Integrates multiproc_reproject into Raster.reproject and _reproject, reorganizing internal logic to better support multiprocessing.
  • Applies minor fixes and cleanup in multiproc.py.

Tests

  • Adds test cases for multiprocessing reprojection, covering single and multi-band rasters with bounds, resolution and crs changes, and assert that the raster is never fully loaded during the process of reprojection.
  • Adds a test for map_multiproc_collect with tile coordinate return.
  • Temporarily skips test_multiproc_reproject (pending next rasterio release).
    image

Documentation

  • Adds a note about out-of-memory reprojection in raster_class.md.
  • Adds MultiprocConfig in the API.
  • Fixes imports and structure in multiprocessing.md.

@rhugonnet
Copy link
Member

rhugonnet commented Mar 10, 2025

Hi @vschaffn @adebardo, nice to see the progress!

Summary: Upon reading the code in delayed_multiproc, I'm fairly sure that in many cases the reproject won't work, and that interp might work but won't be computationally optimal.

Why?
For reproject, it's due to shape deformation during CRS reprojection, you can't work simply on extents. You need the equivalent of the GeoGrid classes, a vectorized format of the grid shape. Because of the nature of the problem, writing these steps as embedded functions would probably take three times the code size including many loops and be hard to grasp, which is why I wrote it this way (and they also did in other packages, like OpenDataCube-Geo).
For interp, it's because loading many tiny windows will be extremely slow. It's better to load a whole block that fits in memory, and perform interpolation on all the points of this block at once.

Moving forward:
Joining my previous comment (#655 (comment)), I suggest to simply re-use all the internal steps of dask_reproject and dask_interp functions that were already developed and extensively tested, and only replace the bits about loading/tiling that work differently without Dask and using Multiprocessing instead.
For example, in dask_reproject, nothing uses Dask until this line:

blocks = darr.to_delayed().ravel()

So all the projected grid matching calculations to map the input chunks to output chunks written above that line (using the GeoGrid classes I mentioned) can simply stay the same 😉

@rhugonnet
Copy link
Member

Also note that Rasterio has reprojection errors that can appear (especially when comparing a tiled reprojection to a reprojection on the full raster), due to internal inconsistencies...
But those look stripped, like this: rasterio/rasterio#2052 (comment)

@rhugonnet
Copy link
Member

To clarify one more remark: In multiprocessing_reproject(), you'd have to slightly change the logic you currently have within your tile loop (line 404).

Due to shape deformations, the mapping of input-output tiles between any CRS requires more than an overlap value, and also depends on the input/output chunksizes (what the GeoGrid code mentioned above does).
So you'll have to loop for each output tile (always 1 by 1), and every time open a combination of input tiles (could be 1, 2, 3, 4, 5, ...).

You can essentially turn this list comprehension here into your loop for multiprocessing_reproject:

# Run the delayed reprojection, looping for each destination block

And let reproject_block stick the pieces of input tiles in the right places (you might not have a perfect square with all the tiles you opened), and give you the output:
def _delayed_reproject_per_block(

@rhugonnet
Copy link
Member

I should have thought about this last week (I was mostly off work and only popped for the comment above, so I didn't think about the big picture, sorry!):
As I just mentioned in GlacioHack/xdem#699, the current code for reproject developed by @vschaffn will still be extremely useful, as it is essentially the multiprocessing equivalent of dask.map_overlap that we need absolutely everywhere (all terrain attributes such as the slope, all kernel-based filters, re-gridding in Coreg.apply(), and much more 😊).

It is only reproject that is a tricky exception! 😅

@vschaffn vschaffn force-pushed the 648-multiprocessing2 branch from 9b7fe6c to 06fbcfc Compare March 28, 2025 14:07
@rhugonnet
Copy link
Member

rhugonnet commented Mar 29, 2025

I looked at the Dask code again, and it didn't look too hard to adapt for Multiprocessing when already familiar with the underlying functions, so I gave it a go to save you time 😉.

It's only about 10 lines of new or adapted code (ignoring the writing part of the Multiprocessing that we write the same as in #669). The rest is an exact copy of the Dask code.

The most critical aspect was to call dst_block_ids = np.array(dst_geotiling.get_block_locations()) to easily map the output grid before writing, and to pass the existing src_block_ids to a wrapper reproject-block function to directly get the coordinates to read chunks with Raster.icrop.

Here it is, with comments explaining the changes with the Dask code:

### (NO CHANGE) BLOCK FUNCTION FULLY COPIED FROM THE DASK ONE ABOVE, JUST WITHOUT THE @DELAYED DECORATOR
def _multiproc_reproject_per_block(
    *src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any
) -> NDArrayNum:
    """
    Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks).
    """

    # If no source chunk intersects, we return a chunk of destination nodata values
    if len(src_arrs) == 0:
        # We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype
        dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32"))
        dst_arr[:] = kwargs["dst_nodata"]
        return dst_arr

    # First, we build an empty array with the combined shape, only with nodata values
    comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype)
    comb_src_arr[:] = kwargs["src_nodata"]

    # Then fill it with the source chunks values
    for i, arr in enumerate(src_arrs):
        bid = block_ids[i]
        comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr

    # Now, we can simply call Rasterio!

    # We build the combined transform from tuple
    src_transform = rio.transform.Affine(*combined_meta["src_transform"])
    dst_transform = rio.transform.Affine(*combined_meta["dst_transform"])

    # Reproject
    dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype)

    _ = rio.warp.reproject(
        comb_src_arr,
        dst_arr,
        src_transform=src_transform,
        src_crs=kwargs["src_crs"],
        dst_transform=dst_transform,
        dst_crs=kwargs["dst_crs"],
        resampling=kwargs["resampling"],
        src_nodata=kwargs["src_nodata"],
        dst_nodata=kwargs["dst_nodata"],
        num_threads=1,  # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading
    )

    return dst_arr

### (NEW WRAPPER) TO RUN BLOCK FUNCTION FOR MULTIPROC CALL: ADD READING FOR RASTER BLOCK + RETURN DESTINATION BLOCK ID
def _wrapper_multiproc_reproject_per_block(rst, src_block_ids, dst_block_id, idx_d2s, block_ids, combined_meta, kwargs):
    """Wrapper to use reproject_per_block for multiprocessing."""

    # Get source array block for each destination block
    s = src_block_ids
    src_arrs = (rst.icrop(bbox=(s[idx]["xs"], s[idx]["ys"], s[idx]["xe"], s[idx]["ye"])).data for idx in idx_d2s)

    # Call reproject per block
    dst_block_arr = _multiproc_reproject_per_block(*src_arrs, block_ids=block_ids, combined_meta=combined_meta, **kwargs)

    return dst_block_arr, dst_block_id

### (SMALL CHANGES, MOSTLY AT THE END) FINAL REPROJECT FUNCTION
def reproject_multiproc(
    rst: Any,  # NEW INPUT FOR MULTIPROC (INSTEAD OF DASK ARRAY)
    config: Any,  # NEW INPUT FOR MULTIPROC
    src_chunksizes: tuple[int, int],  # NEW INPUT FOR MULTIPROC (INSTEAD OF BEING SAVED WITHIN DASK ARRAY)
    src_transform: rio.transform.Affine,
    src_crs: rio.crs.CRS,
    dst_transform: rio.transform.Affine,
    dst_shape: tuple[int, int],
    dst_crs: rio.crs.CRS,
    resampling: rio.enums.Resampling,
    src_nodata: int | float | None = None,
    dst_nodata: int | float | None = None,
    dst_chunksizes: tuple[int, int] | None = None,
    **kwargs: Any,
) -> None:
    """
    Reproject georeferenced raster on out-of-memory chunks with multiprocessing
    """

    # 1/ Define source and destination chunked georeferenced grid through simple classes storing CRS/transform/shape,
    # which allow to consistently derive shape/transform for each block and their CRS-projected footprints

    # Define georeferenced grids for source/destination array
    src_geogrid = GeoGrid(transform=src_transform, shape=rst.shape, crs=src_crs)
    dst_geogrid = GeoGrid(transform=dst_transform, shape=dst_shape, crs=dst_crs)

    # Add the chunking
    # For source, we can use the .chunks attribute

    # SMALL CHANGE: GENERATE THE TUPLES OF CHUNKS SIMILARLY AS FOR DASK, BASED ON CHUNKSIZE INPUT
    chunks_x = tuple((src_chunksizes[0] if i<=rst.shape[0] else rst.shape[0] % src_chunksizes[0])
                     for i in np.arange(src_chunksizes[0], rst.shape[0] + src_chunksizes[0], src_chunksizes[0]))
    chunks_y = tuple((src_chunksizes[1] if i<=rst.shape[1] else rst.shape[1] % src_chunksizes[1])
                     for i in np.arange(src_chunksizes[1], rst.shape[1] + src_chunksizes[1], src_chunksizes[1]))
    src_chunks = (chunks_x, chunks_y)

    src_geotiling = ChunkedGeoGrid(grid=src_geogrid, chunks=src_chunks)

    # For destination, we need to create the chunks based on destination chunksizes
    if dst_chunksizes is None:
        dst_chunksizes = src_chunksizes
    dst_chunks = _chunks2d_from_chunksizes_shape(chunksizes=dst_chunksizes, shape=dst_shape)
    dst_geotiling = ChunkedGeoGrid(grid=dst_geogrid, chunks=dst_chunks)

    # 2/ Get footprints of tiles in CRS of destination array, with a buffer of 2 pixels for destination ones to ensure
    # overlap, then map indexes of source blocks that intersect a given destination block
    src_footprints = src_geotiling.get_block_footprints(crs=dst_crs)
    dst_footprints = dst_geotiling.get_block_footprints().buffer(2 * max(dst_geogrid.res))
    dest2source = [np.where(dst.intersects(src_footprints).values)[0] for dst in dst_footprints]

    # 3/ To reconstruct a square source array during chunked reprojection, we need to derive the combined shape and
    # transform of each tuples of source blocks
    src_block_ids = np.array(src_geotiling.get_block_locations())
    meta_params = [
        (
            _combined_blocks_shape_transform(sub_block_ids=src_block_ids[sbid], src_geogrid=src_geogrid)  # type: ignore
            if len(sbid) > 0
            else ({}, [])
        )
        for sbid in dest2source
    ]
    # We also add the output transform/shape for this destination chunk in the combined meta
    # (those are the only two that are chunk-specific)
    dst_block_geogrids = dst_geotiling.get_blocks_as_geogrids()
    for i, (c, _) in enumerate(meta_params):
        c.update({"dst_shape": dst_block_geogrids[i].shape, "dst_transform": tuple(dst_block_geogrids[i].transform)})

    # 4/ Call a delayed function that uses rio.warp to reproject the combined source block(s) to each destination block

    # Add fixed arguments to keywords
    kwargs.update(
        {
            "src_nodata": src_nodata,
            "dst_nodata": dst_nodata,
            "resampling": resampling,
            "src_crs": src_crs,
            "dst_crs": dst_crs,
        }
    )

    ### FROM HERE: ADAPTED CODE FOR MULTIPROC

    # SMALL CHANGE: RETURN BLOCK LOCATIONS TO EASILY WRITE OUTPUT
    # (WASN'T NEEDED FOR DASK THAT JUST CONCATENATED BLOCKS IN THE RIGHT ORDER)
    # Get location of destination blocks to write file
    dst_block_ids = np.array(dst_geotiling.get_block_locations())

    # MODIFY DASK LIST COMPREHENSION INTO LOOP TO LAUNCH TASKS
    # ADDING SRC BLOCK IDS TO THE WRAPPER CALL TO LOAD WITH ICROP
    # Create tasks for multiprocessing
    tasks = []
    for i in range(len(dest2source)):
        tasks.append(
            config.cluster.launch_task(
                fun=_wrapper_multiproc_reproject_per_block, args=[rst, src_block_ids, dst_block_ids[i], dest2source[i], meta_params[i][1], meta_params[i][0], kwargs],
            )
        )

    result_list = []
    # get first tile to retrieve dtype and nodata
    result_tile0, _ = config.cluster.get_res(tasks[0])

    # WRITE OUTPUT AS IN GENERIC MULTIPROC PR
    # Create a new raster file to save the processed results
    with rio.open(
            config.outfile,
            "w",
            driver="GTiff",
            height=dst_shape[0],
            width=dst_shape[1],
            count=1,
            dtype=rst.dtype,
            crs=dst_crs,
            transform=dst_transform,
            nodata=dst_nodata,
    ) as dst:
        try:
            # Iterate over the tasks and retrieve the processed tiles
            for results in tasks:

                dst_block_arr, dst_block_id = config.cluster.get_res(results)

                # Define the window in the output file where the tile should be written
                dst_window = rio.windows.Window(
                    col_off=dst_block_id["xs"],
                    row_off=dst_block_id["ys"],
                    width=dst_block_id["xe"] - dst_block_id["xs"],
                    height=dst_block_id["ye"] - dst_block_id["ys"],
                )

                # Cast to 3D
                dst_block_arr = dst_block_arr[np.newaxis, :, :]

                # Write the processed tile to the appropriate location in the output file
                dst.write(dst_block_arr, window=dst_window)
        except Exception as e:
            raise RuntimeError(f"Error retrieving data from multiprocessing tasks: {e}")

@rhugonnet
Copy link
Member

rhugonnet commented Mar 29, 2025

It runs but I didn't test it thoroughly, so hopefully I didn't add any bug and it works well quickly 😉.
We should be able to easily adjust existing tests (the ones of this PR or in test_delayed) to check that.

Also, I should justify why delayed_reproject() (and also this multiproc one) don't have the same function call as Raster.reproject():
It's because the idea was, like for DEM.slope(), to add them within _reproject's last step directly once #446 is merged, here:

# 4/ Perform reprojection

This way, they don't have to go through the _user_input_reproject and _get_reproj_params steps which stay consistent no matter if the method runs on the full array, or on chunks of the array.

I think we should do the same with the final multiproc function of this PR 🙂

@rhugonnet
Copy link
Member

Also, if you want to test the internals of this reproject function (to get a sense of the variables dest2source or dst_block_ids, which are hard to grasp by just reading the code), you can uncomment the following test lines here + run them and all other inputs lines below (until comment starting with "2/"):

# Keeping this commented here if we need to redo local tests due to Rasterio errors

This will define all the arguments needed to run the reproject. If testing the multiproc function and not the delayed one, the darr needs to be replaced by a Raster input, and the src_transform/crs by the Raster.transform/crs.

@vschaffn
Copy link
Contributor Author

vschaffn commented Apr 2, 2025

@rhugonnet many thanks for your work and your explications.
However, using your version of reproject_multiproc, I still have some small shifts when reprojecting on a new crs.

Here is my test code:

import matplotlib.pyplot as plt
import rasterio as rio

import geoutils as gu
from geoutils import Raster
from geoutils.raster.distributed_computing import MultiprocConfig
from geoutils.raster.distributed_computing.delayed_multiproc import reproject_multiproc

example = gu.examples.get_path("exploradores_aster_dem")
outfile = "test.tif"
config = MultiprocConfig(chunk_size=200, outfile=outfile)

r = Raster(example)

# - Test reprojection with CRS change -
out_crs = rio.crs.CRS.from_epsg(4326)

# Single-process reprojection
r_single = r.reproject(crs=out_crs)

# Multiprocessing reprojection
reproject_multiproc(r, config, crs=out_crs)
r_multi = Raster(outfile)

plt.figure()
r_single.plot()
plt.figure()
r_multi.plot()
plt.figure()
diff1 = r_single - r_multi
diff1.plot()
plt.show()

Here is the reproject_multiproc I used (same as yours with some adaptations for the inputs):

### (SMALL CHANGES, MOSTLY AT THE END) FINAL REPROJECT FUNCTION
def reproject_multiproc(
    rst: RasterType,
    config: MultiprocConfig,
    ref: RasterType | str | None = None,
    crs: CRS | str | int | None = None,
    res: float | abc.Iterable[float] | None = None,
    grid_size: tuple[int, int] | None = None,
    bounds: rio.coords.BoundingBox | None = None,
    nodata: int | float | None = None,
    dtype: DTypeLike | None = None,
    resampling: Resampling | str = Resampling.bilinear,
    force_source_nodata: int | float | None = None,
    **kwargs: Any,
) -> None:
    """
    Reproject georeferenced raster on out-of-memory chunks with multiprocessing
    """
    # Process user inputs
    dst_crs, dst_dtype, src_nodata, dst_nodata, dst_res, dst_bounds = _user_input_reproject(
        source_raster=rst,
        ref=ref,
        crs=crs,
        bounds=bounds,
        res=res,
        nodata=nodata,
        dtype=dtype,
        force_source_nodata=force_source_nodata,
    )

    # Retrieve transform and grid_size
    dst_transform, dst_grid_size = _get_target_georeferenced_grid(
        rst, crs=dst_crs, grid_size=grid_size, res=dst_res, bounds=dst_bounds
    )
    dst_width, dst_height = dst_grid_size
    dst_shape = (dst_height, dst_width)

    # 1/ Define source and destination chunked georeferenced grid through simple classes storing CRS/transform/shape,
    # which allow to consistently derive shape/transform for each block and their CRS-projected footprints

    # Define georeferenced grids for source/destination array
    src_geogrid = GeoGrid(transform=rst.transform, shape=rst.shape, crs=rst.crs)
    dst_geogrid = GeoGrid(transform=dst_transform, shape=dst_shape, crs=dst_crs)

    # Add the chunking
    # For source, we can use the .chunks attribute

    # SMALL CHANGE: GENERATE THE TUPLES OF CHUNKS SIMILARLY AS FOR DASK, BASED ON CHUNKSIZE INPUT
    chunks_x = tuple((config.chunk_size if i<=rst.shape[0] else rst.shape[0] % config.chunk_size)
                     for i in np.arange(config.chunk_size, rst.shape[0] + config.chunk_size, config.chunk_size))
    chunks_y = tuple((config.chunk_size if i<=rst.shape[1] else rst.shape[1] % config.chunk_size)
                     for i in np.arange(config.chunk_size, rst.shape[1] + config.chunk_size, config.chunk_size))
    src_chunks = (chunks_x, chunks_y)

    src_geotiling = ChunkedGeoGrid(grid=src_geogrid, chunks=src_chunks)

    # For destination, we need to create the chunks based on destination chunksizes
    dst_chunks = _chunks2d_from_chunksizes_shape(chunksizes=(config.chunk_size, config.chunk_size), shape=dst_shape)
    dst_geotiling = ChunkedGeoGrid(grid=dst_geogrid, chunks=dst_chunks)

    # 2/ Get footprints of tiles in CRS of destination array, with a buffer of 2 pixels for destination ones to ensure
    # overlap, then map indexes of source blocks that intersect a given destination block
    src_footprints = src_geotiling.get_block_footprints(crs=dst_crs)
    dst_footprints = dst_geotiling.get_block_footprints().buffer(2 * max(dst_geogrid.res))
    dest2source = [np.where(dst.intersects(src_footprints).values)[0] for dst in dst_footprints]

    # 3/ To reconstruct a square source array during chunked reprojection, we need to derive the combined shape and
    # transform of each tuples of source blocks
    src_block_ids = np.array(src_geotiling.get_block_locations())
    meta_params = [
        (
            _combined_blocks_shape_transform(sub_block_ids=src_block_ids[sbid], src_geogrid=src_geogrid)  # type: ignore
            if len(sbid) > 0
            else ({}, [])
        )
        for sbid in dest2source
    ]
    # We also add the output transform/shape for this destination chunk in the combined meta
    # (those are the only two that are chunk-specific)
    dst_block_geogrids = dst_geotiling.get_blocks_as_geogrids()
    for i, (c, _) in enumerate(meta_params):
        c.update({"dst_shape": dst_block_geogrids[i].shape, "dst_transform": tuple(dst_block_geogrids[i].transform)})

    # 4/ Call a delayed function that uses rio.warp to reproject the combined source block(s) to each destination block

    # Add fixed arguments to keywords
    kwargs.update(
        {
            "src_nodata": src_nodata,
            "dst_nodata": dst_nodata,
            "resampling": resampling,
            "src_crs": rst.crs,
            "dst_crs": dst_crs,
        }
    )

    ### FROM HERE: ADAPTED CODE FOR MULTIPROC

    # SMALL CHANGE: RETURN BLOCK LOCATIONS TO EASILY WRITE OUTPUT
    # (WASN'T NEEDED FOR DASK THAT JUST CONCATENATED BLOCKS IN THE RIGHT ORDER)
    # Get location of destination blocks to write file
    dst_block_ids = np.array(dst_geotiling.get_block_locations())

    # MODIFY DASK LIST COMPREHENSION INTO LOOP TO LAUNCH TASKS
    # ADDING SRC BLOCK IDS TO THE WRAPPER CALL TO LOAD WITH ICROP
    # Create tasks for multiprocessing
    tasks = []
    for i in range(len(dest2source)):
        tasks.append(
            config.cluster.launch_task(
                fun=_wrapper_multiproc_reproject_per_block, args=[rst, src_block_ids, dst_block_ids[i], dest2source[i], meta_params[i][1], meta_params[i][0], kwargs],
            )
        )

    result_list = []
    # get first tile to retrieve dtype and nodata
    result_tile0, _ = config.cluster.get_res(tasks[0])

    # WRITE OUTPUT AS IN GENERIC MULTIPROC PR
    # Create a new raster file to save the processed results
    with rio.open(
            config.outfile,
            "w",
            driver="GTiff",
            height=dst_height,
            width=dst_width,
            count=1,
            dtype=rst.dtype,
            crs=dst_crs,
            transform=dst_transform,
            nodata=dst_nodata,
    ) as dst:
        try:
            # Iterate over the tasks and retrieve the processed tiles
            for results in tasks:

                dst_block_arr, dst_block_id = config.cluster.get_res(results)

                # Define the window in the output file where the tile should be written
                dst_window = rio.windows.Window(
                    col_off=dst_block_id["xs"],
                    row_off=dst_block_id["ys"],
                    width=dst_block_id["xe"] - dst_block_id["xs"],
                    height=dst_block_id["ye"] - dst_block_id["ys"],
                )

                # Cast to 3D
                dst_block_arr = dst_block_arr[np.newaxis, :, :]

                # Write the processed tile to the appropriate location in the output file
                dst.write(dst_block_arr, window=dst_window)
        except Exception as e:
            raise RuntimeError(f"Error retrieving data from multiprocessing tasks: {e}")

Here is the difference plot:
image

@rhugonnet
Copy link
Member

@vschaffn I can't reproduce this, my diff plot is as large as the DEMs, looks like the tiles might be inverted. Did you fix something in my code that is not above?

test

@vschaffn
Copy link
Contributor Author

vschaffn commented Apr 2, 2025

@rhugonnet it certainly comes from the icrop, as I fixed the order of the coordinates in the last PR which was merged yesterday (#660), you should probably rebase your branch on it 😄

@rhugonnet
Copy link
Member

Thanks, I forgot, that fixed it! 😉

OK, so I've investigated a bit.

First, I ran tests between the _reproject_multiproc and rio.warp.reproject directly (same as in test_delayed), to ensure no error is tied to our steps in georeferencing subfunctions. It is not, exact same difference with this test.

Then, I ran tests between the _reproject_multiproc and _reproject_dask (delayed_reproject): Exact same result between both, so we didn't introduce any differences with the multiprocessing icrop logic.
For both, there is no error for a same-CRS reprojection, only for differing ones.

test2_multiproc
test2_dask

So I think the cause is very likely the Rasterio errors I mention above in #661 (comment), even though they didn't take this shape last time I tested it, but maybe it's because I used random placeholder data and not DEM data. Here it looks dependent on the gradient...

To test if this is indeed the case or an error in our implementation, we should try to run gdal.Reproject() under-the-hood instead of rio.warp, the same way we do it here: https://github.com/GlacioHack/xdem-data/blob/98004a09f84def4c78b253d41b212baca2b3cccb/generate_ground_truth_gdal.py#L54

@rhugonnet
Copy link
Member

Looking at it right now...

@rhugonnet
Copy link
Member

It's indeed an error in Rasterio, here's the same difference running gdal.Reproject() instead of rio.warp.reproject() (on the full raster, and in the function _multiproc_reproject_per_block):

test2_multiproc_gdal

Here's the code to do a 1-on-1 replacement between the two reproject:

def _reproject_gdal(src_arr, src_transform, src_crs, dst_transform, dst_shape, dst_crs, resampling, src_nodata, dst_nodata):
    from osgeo import gdal, gdalconst

    resampling_mapping = {"nearest":  gdalconst.GRA_NearestNeighbour, "bilinear": gdalconst.GRA_Bilinear,
                  "cubic": gdalconst.GRA_Cubic, "cubic_spline": gdalconst.GRA_CubicSpline}

    gdal_resampling = resampling_mapping[resampling]

    def _rio_to_gdal_geotransform(gt: rio.Affine) -> tuple[float, ...]:

        dx, b, xmin, d, dy, ymax = list(gt)[:6]

        gdal_gt = (xmin, dx, 0, ymax, 0, dy)

        return gdal_gt

    # Create input raster from array shape
    drv = gdal.GetDriverByName('MEM')
    src_ds = drv.Create('', src_arr.shape[1], src_arr.shape[0], 1, gdal.GDT_Float32)
    # Create, define and set projection
    src_ds.SetProjection(src_crs.to_wkt())
    # Convert and set geotransform
    src_gt = _rio_to_gdal_geotransform(src_transform)
    src_ds.SetGeoTransform(src_gt)
    # Write array and nodata value on first band
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(src_arr)
    src_band.SetNoDataValue(src_nodata)

    # Create output raster from destination shape
    dst_ds = drv.Create("", dst_shape[1], dst_shape[0], 1, gdal.GDT_Float32)
    # Create output projection
    dst_ds.SetProjection(dst_crs.to_wkt())
    dst_gt = _rio_to_gdal_geotransform(dst_transform)
    dst_ds.SetGeoTransform(dst_gt)

    # Copy the raster metadata of the source to dest
    dst_ds.GetRasterBand(1).SetNoDataValue(dst_nodata)
    dst_ds.GetRasterBand(1).Fill(dst_nodata)

    # Reproject with resampling
    gdal.ReprojectImage(src_ds, dst_ds, src_crs.to_wkt(), dst_crs.to_wkt(), gdal_resampling)

    # Extract reprojected array
    array = dst_ds.GetRasterBand(1).ReadAsArray().astype("float32")
    array[array == dst_nodata] = np.nan

    return array

@rhugonnet
Copy link
Member

Funnily, there's a tiny area at the bottom left that still has some differences...

@rhugonnet
Copy link
Member

rhugonnet commented Apr 2, 2025

Actually the artefact above appears for the destination chunk after (600, XXX).
Testing other destination CRSs, like South Polar Stereo (EPSG:3412), there are also differences appearing even with GDAL, and that follow the tiling of (200, 200)...

test2

I wonder if it's possible to test GDAL against itself (tiled versus all raster at once), to ensure that these last errors are from us.

Maybe the issue is the precision of the dst_transform and src_transform in each chunk... Rounding problems or something, maybe around there:

if distance_unit == "pixel":

To be sure of this, we could try with placeholder data where we control the transform...

@rhugonnet
Copy link
Member

Reminds me a bit of #357 and #354.

@rhugonnet
Copy link
Member

Tried to tweak floating point precision, without success...

I got me thinking: In the tiled Rasterio vs full Rasterio, maybe one result is correct, and not the other.
GDAL is definitely more tested/robust, and given its (almost) consistent results for the exact same inputs, it would make sense that it is the correct one.

Interestingly, it's the Rasterio tiled reprojection (Multiprocessing or Dask) that it very close to GDAL:

test2_gdal_all_rio_block

And the full-array Rasterio reprojection that is very different:

test2_gdal_rio_all

@rhugonnet
Copy link
Member

rhugonnet commented Apr 2, 2025

Some potential reasons: OSGeo/gdal#2810
See explanations here too: OSGeo/gdal#1620 (comment)

We could use gdal.Warp and pass more thorough gdal.WarpOptions to ensure that it's not default parameters changing depending on tile location that create those artifacts!
In particular the tap option mentioned in the issues above, which in gdal.WarpOptions is: targetAlignedPixels -- whether to force output bounds to be multiple of output resolution
https://gdal.org/en/stable/api/python/utilities.html#osgeo.gdal.WarpOptions

We could also open an issue in GDAL to get feedback.

@rhugonnet
Copy link
Member

ALELUJAH, GOT IT! 🥳 🎈

Two issues were the cause:
1/ GDAL automatically computes internal "XSCALE" and "YSCALE" factors when resampling, which depend on array size. That creates the problem when chunking!
This is explained in the documentation of the C++ WarpOptions here (lines XSCALE/YSCALE): https://gdal.org/en/stable/api/gdalwarp_cpp.html#_CPPv415GDALWarpOptions

2/ GDAL uses different PROJ transformations based on the error threshold passed by the user. When using gdal.ReprojectImage as above, it defaults to an error threshold of 0, which is why it gives a different (better) result than rio.warp.reproject()!
But, most importantly, gdal.Warp without the option errorThreshold defined (see Python API https://gdal.org/en/stable/api/python/utilities.html#osgeo.gdal.WarpOptions) actually gives the same result as rio.warp.reproject(), which explains the discrepancy we had between the two. It's simply because they weren't using the same transformation.

Now, combining the two issues above, we can solve all of our problems.
We can define a GDAL reprojection without inconsistent transformations, and without reprojection scaling effects on smaller chunks.
Here's the code replace the gdal.ReprojectImage line above:

warp_opts = {"XSCALE":1, "YSCALE": 1}
opts = gdal.WarpOptions(srcSRS=src_crs.to_wkt(), dstSRS=dst_crs.to_wkt(), 
    resampleAlg=gdal_resampling, errorThreshold=0, warpOptions=warp_opts)
gdal.Warp(dst_ds, src_ds, options=opts)

Now with this, we have a perfect match for both EPSG:4326 (that had the lower left corner wrong) and EPSG:3412 (that was completely off) shown above:

test2_gdal_epsg4326
test2_gdal_epsg3412

And, for our purpose, we need to do the same with Rasterio.
Thankfully this is possible, by passing kwargs arguments to rio.warp.reproject() which are passed to gdal.WarpOptions, see: https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.reproject.
Should be able to make it work with Rasterio too tomorrow! 😁

@rhugonnet
Copy link
Member

Looks like OpenDataCube came to the same conclusion: opendatacube/datacube-core#1456
Also the linked issue in Rasterio: rasterio/rasterio#2995

@rhugonnet
Copy link
Member

OK, so passing XSCALE and YSCALE to GDAL through Rasterio works well.
However, the second factor that creates inconsistencies between chunked reprojection and full-array reprojection, the error_threshold, cannot be parametrized in Rasterio.
Its value is a constant, defined here: https://github.com/rasterio/rasterio/blob/b59373148294a6b3a729397c5828e28ab1fd4d53/rasterio/_warp.pyx#L324

I'll open a PR to make it user-defined.

@vschaffn
Copy link
Contributor Author

vschaffn commented Apr 3, 2025

@rhugonnet Many thanks for your investigations and your work, It would have taken me weeks to track down the source of the problem 😄. I saw that you opened the PR on rasterio rasterio/rasterio#3325, I hope it will be validated soon 🤞
To sum up what remains to be done:

  • Add the warp options {"XSCALE": 1, "YSCALE": 1, "tolerance": 0} to the rasterio reprojection for both tiled and classic reprojection
  • Add tests for the tiled reprojection
  • Add GDAL reprojection ground truth in xdem-data to test it against rasterio ?

Does that sound right ?

@rhugonnet
Copy link
Member

rhugonnet commented Apr 3, 2025

@vschaffn Yes, exactly! 😁
We shouldn't wait for this PR to be merged/released in Rasterio, so I think we can just do something along the lines of:

kwargs = {"XSCALE": 1, "YSCALE": 1}
from packaging.version import Version
if Version(rasterio.__version__) > Version("1.4.3"):
    kwargs.update({"tolerance": 0})
_reproject_rio(..., **kwargs)

For the tests: We can probably mirror a lot of stuff from here. And it's a good idea to add real data like you did (it's only synthetic in the Dask tests I think).

For the GDAL code: As Rasterio directly calls GDAL.Warp, it's supposed to be a 1 on 1 (were it not for that option!), so maybe less useful here than for other cases like in xDEM, where we compare GDAL to SciPy or to our own terrain function. Yet GeoUtils does still use GDAL for one other test: Raster.proximity() here, so it could be the occasion to move everything at once to geoutils-data like we did in xdem-data (but it's not a priority, we can also just open an issue on it).

And I see also these minor changes:

  • Adding multiproc_config to Raster.reproject() and add _reproject_multiproc within the last step _reproject(), triggered when the multiproc_config argument is passed,
  • Restructuring delayed_reproject into subfunction(s) (around the steps 1-2-3/ re-used in both) to minimize copy/pasting, have shorter code and everything consistent between Dask and Multiproc implementations!
  • Maybe also restructuring the reprojection step of _reproject() into a _reproject_rio() subfunction that we can re-use in the _multiproc_reproject_per_block and _delayed_reproject_per_block, to avoid repeating the same call everywhere (and define the XSCALE/YSCALE arguments only there). 😉

@rhugonnet
Copy link
Member

Ah and I just thought that we can't make the tests np.allclose(reproj_arr_chunked, reproj_arr_full) pass without being on the Rasterio branch from my PR. 😅
I guess we should still write the tests with np.allclose to be rigorous, check they pass locally while on the Rasterio branch with tolerance=0, and add a @pytest.mark.skip() depending on the Rasterio version so that they trigger in the CI in the future.

@vschaffn vschaffn force-pushed the 648-multiprocessing2 branch from 06fbcfc to 33a890c Compare April 4, 2025 16:31
@vschaffn
Copy link
Contributor Author

vschaffn commented Apr 7, 2025

@rhugonnet As discussed I have adapted the structure to reduce at most duplication between multiprocessing, dask and reproject, and the multiprocessing is now available in Raster.reproject. I have updated the PR descriptions with all the changes 😄

@rhugonnet
Copy link
Member

Amazing!! 🥳 🎈
A big piece of work! I really like the MultiprocConfig design we converged on 😄

Great idea to add support for multi-band, this is extremely useful for reprojecting large raster stacks that are increasingly common 😉 (we might even have to allow the user to specify a "band" chunksize at some point! this can come later), and to verify this thoroughly in the tests!

@adebardo
Copy link
Contributor

adebardo commented Apr 8, 2025

The work you both have done is impressive it's truly a privilege to witness it live. I don't feel comfortable merging this ticket myself. @rhugonnet, it's up to you. Do it whenever you feel ready.

@rhugonnet
Copy link
Member

I think it's almost ready to merge! 😄

My last thought is that, like for #669, we should probably wait to show the changes in the documentation until we have finished converging on an API internally, so that it remains stable for users. For instance, we just realized now about the driver argument, and we might find many more little changes/additions while using this on the xDEM side.
We can also document once we've added the (much easier!) cases of interp_points, subsample, and and other functions that can be wrapped out-of-the-box with map_overlap like the get_terrain_attribute in xDEM (this likely includes rasterize, create_mask, and maybe proximity()?).

We can keep the nice additions to the /doc/ (api.md and raster.md) commented and wait to release the new doc all at once later 😉

@rhugonnet
Copy link
Member

A new release documenting all of these multiprocessing features would deserve to bump the package by a new minor version: GeoUtils 0.2!

@rhugonnet
Copy link
Member

For instance, I've been thinking that it might be practical to return a Raster object directly after the multiprocessing call, by opening the outfile with Raster again when MultiprocConfig is passed?

This would allow to do this:

rst = Raster(starting_file)
rst_reproj = rst.reproject(config=config1)
rst_reproj_prox = rst_reproj.proximity(config=config2)

Instead of having to do this:

rst = Raster(starting_file)
rst.reproject(config=...)
rst_reproj = Raster(config1.outfile)
rst_reproj.proximity(config=config2)
rst_reproj_prox = Raster(config2.outfile)

I'll open a separate issue where we can list all our ideas towards a final API linked to multiprocessing calls! 😉

@rhugonnet
Copy link
Member

Are we commenting the new doc lines (or moving them to a single md not rendered) so that we can merge this?

@vschaffn
Copy link
Contributor Author

@rhugonnet I agree, especially after having manipulating a few times the multiproc reprojection. I made some quick changes:

  • multiproc_map_overlap_save and _multiproc_reproject return the Raster by opening the outfile
  • A tempfile is called by default in MultiprocConfig when no ouftile is provided
  • Small updates in the doc

Which doc lines do you want to comment ?

@rhugonnet
Copy link
Member

Great! 😄
For the doc: We can comment (or not list in index.md) everything until we release the documentation later all at once, we often do this for a new feature 😉

@vschaffn
Copy link
Contributor Author

@rhugonnet all right, then I have commented the changes in api.md and raster_class.md, and I will open a new PR for #677.
I think we can merge this one 😄

@rhugonnet
Copy link
Member

Perfect, thanks! I'll merge after the CI run 😄.

FYI: I noticed an error randomly failing in CI since a couple days. It's linked to Multiprocessing, and seems to be happening only on "ubuntu-latest" and "python 3.12". See here: https://github.com/GlacioHack/geoutils/actions/runs/14429620562/job/40462822308#step:13:760. You can find the same error in other Action runs of the past days.
Any ideas? 🤔

@vschaffn
Copy link
Contributor Author

I think it is the tests that use a cluster, sometimes the cluster crashes (I don't know why unfortunately, I'm not a multiprocessing expert), hence the try/except block used with a time limit in functions using a cluster. I will try to investigate anyway.

@rhugonnet
Copy link
Member

Ah yes, now that you say this, I've seen some CI runs be stuck for 2 hours a couple times recently too. Actually it's the case of the one I sent you above, lasted 1h45min!
If we can ensure that when a random failure like this happens it does kill the job rapidly, like you say, that'll already sort a big portion of the problem.

@rhugonnet rhugonnet merged commit 7f8faac into GlacioHack:main Apr 14, 2025
21 checks passed
@vschaffn vschaffn deleted the 648-multiprocessing2 branch April 14, 2025 10:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[POC] first try of multiprocessing for reprojection 2/2

3 participants