Skip to content

Commit

Permalink
👌 Return of the gapfilled REMA training tiles
Browse files Browse the repository at this point in the history
The neural network wasn't training properly, and I tracked it down to the REMA input rasters having low NaN-like values... Found out the proper way to get dask DataArray masks using the dask.array.ma module, so we can reintroduce gapfilling in data_prep.selective_tile, this time using dask/xarray to vectorize the operations. The gapfilled raster is also interpolated better along the edges, as in 7fd3345, which might help with the neural network training later. Quilt hash updated from 9c8cb530df6340e257e18008b59b9d7b5f701fd9e5cef2c8436984ae49cff237 to b0b090ca35271d41ea1cf5e6afa0e6c6a3da34193c00444963dde7ad20eb7331. Not passing in a gapfill_raster_filepath (when it is needed) now errors out with nicer debugging plots that have EPSG:3031 projected coordinates on the axes!
  • Loading branch information
weiji14 committed Jun 13, 2019
1 parent 7fd3345 commit 4a074d9
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 35 deletions.
74 changes: 51 additions & 23 deletions data_prep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
"import salem\n",
"\n",
"import dask\n",
"import dask.diagnostics\n",
"import geopandas as gpd\n",
"import pygmt as gmt\n",
"import IPython.display\n",
Expand Down Expand Up @@ -1211,8 +1210,9 @@
" for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax\n",
" ]\n",
"\n",
" # Retrieve tiles from the main raster\n",
" with xr.open_rasterio(\n",
" filepath, chunks=None if out_shape is None else {}, cache=False\n",
" filepath, chunks=None if out_shape is None else {}\n",
" ) as dataset:\n",
" print(f\"Tiling: {filepath} ... \", end=\"\")\n",
"\n",
Expand All @@ -1231,21 +1231,50 @@
" )\n",
" for da in daarray_list\n",
" ]\n",
" daarray_stack = dask.array.stack(seq=daarray_list)\n",
" daarray_stack = dask.array.ma.masked_values(\n",
" x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals\n",
" )\n",
"\n",
" assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n",
" assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n",
" assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n",
" print(\"done!\")\n",
"\n",
" with dask.diagnostics.ProgressBar(minimum=5.0):\n",
" try:\n",
" out_tiles = daarray_stack.compute().astype(dtype=np.float32)\n",
" assert not np.isnan(out_tiles).any() # check that there are no NAN values\n",
" except AssertionError:\n",
" raise NotImplementedError(\"gapfilling on dask xarray not yet implemented\")\n",
" finally:\n",
" return out_tiles"
" out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)\n",
" mask = dask.array.ma.getmaskarray(daarray_stack).compute()\n",
"\n",
" # Gapfill main raster if there are blank spaces\n",
" if mask.any(): # check that there are no NAN values\n",
" nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel()\n",
"\n",
" # Replace pixels from another raster if available, else raise error\n",
" if gapfill_raster_filepath is not None:\n",
" with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:\n",
" daarray_list2 = [\n",
" dataset2.interp_like(daarray_list[idx].squeeze(), method=\"linear\")\n",
" for idx in nan_grid_indexes\n",
" ]\n",
" daarray_stack2 = dask.array.ma.masked_values(\n",
" x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals\n",
" )\n",
"\n",
" fill_tiles = (\n",
" dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32)\n",
" )\n",
" mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute()\n",
"\n",
" for i, array2 in enumerate(fill_tiles):\n",
" idx = nan_grid_indexes[i]\n",
" np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])\n",
" assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill\n",
"\n",
" else:\n",
" for i in nan_grid_indexes:\n",
" daarray_list[i].plot()\n",
" plt.show()\n",
" print(f\"WARN: Tiles have missing data, try pass in gapfill_raster_filepath\")\n",
"\n",
" return out_tiles"
]
},
{
Expand Down Expand Up @@ -1353,7 +1382,7 @@
" filepath=\"misc/REMA_100m_dem.tif\",\n",
" window_bounds=window_bounds_concat,\n",
" padding=1000,\n",
" # gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
" gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
")\n",
"print(rema.shape, rema.dtype)"
]
Expand Down Expand Up @@ -1390,7 +1419,6 @@
"output_type": "stream",
"text": [
"Tiling: misc/MEaSUREs_IceFlowSpeed_450m.tif ... done!\n",
"[########################################] | 100% Completed | 22.1s\n",
"(2347, 1, 20, 20) float32\n"
]
}
Expand Down Expand Up @@ -1485,7 +1513,7 @@
"name": "stdin",
"output_type": "stream",
"text": [
"Enter the code from the webpage: eyJjb2RlIjogIjg4ODljZTY0LTA1ODMtNGIxYS04YjE2LTQ0MjFjZDViMTQxNCIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
"Enter the code from the webpage: eyJjb2RlIjogIjg0OTA5ODJlLTM0NWYtNDljNC04Y2Q0LTUwY2FlMjhiOWNlZSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
]
}
],
Expand Down Expand Up @@ -1557,32 +1585,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 96%|█████████| 6.47G/6.74G [00:01<04:17, 1.04MB/s] "
" 94%|█████████| 6.35G/6.74G [00:01<01:40, 3.91MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n",
"Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n",
"Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n",
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n",
"Fragment 2b994ae9d13f6c01ce00c426f52c6dce0c4681f8c8aaf8a96608fd3d62f3a269 already uploaded; skipping.\n",
"Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n",
"Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n",
"Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n",
"Fragment 6ef3a2439a508de0919bd33a713976b5aa4895929a9d7981c09f722ce702e16a already uploaded; skipping.\n",
"Fragment 80c9fa41ccc69be1d2cd4a367d56168321d1079e7260a1996089810db25172f6 already uploaded; skipping.\n",
"Fragment ca9c41a8dd56097e40865d2e65c65d299c22fc17608ddb6c604c532a69936307 already uploaded; skipping.\n",
"Fragment 04a52d9a52901d8f7f74fd9ef6fc9fc215d6c9d787540511f68630f5cca16094 already uploaded; skipping.\n",
"Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n",
"Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n",
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n"
"Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n",
"Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n",
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n",
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 6.74G/6.74G [00:03<00:00, 1.77GB/s]\n"
"100%|██████████| 6.74G/6.74G [00:09<00:00, 688MB/s] \n"
]
},
{
Expand Down
53 changes: 41 additions & 12 deletions data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import salem

import dask
import dask.diagnostics
import geopandas as gpd
import pygmt as gmt
import IPython.display
Expand Down Expand Up @@ -621,8 +620,9 @@ def selective_tile(
for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax
]

# Retrieve tiles from the main raster
with xr.open_rasterio(
filepath, chunks=None if out_shape is None else {}, cache=False
filepath, chunks=None if out_shape is None else {}
) as dataset:
print(f"Tiling: {filepath} ... ", end="")

Expand All @@ -641,21 +641,50 @@ def selective_tile(
)
for da in daarray_list
]
daarray_stack = dask.array.stack(seq=daarray_list)
daarray_stack = dask.array.ma.masked_values(
x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals
)

assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)
assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)
assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)
print("done!")

with dask.diagnostics.ProgressBar(minimum=5.0):
try:
out_tiles = daarray_stack.compute().astype(dtype=np.float32)
assert not np.isnan(out_tiles).any() # check that there are no NAN values
except AssertionError:
raise NotImplementedError("gapfilling on dask xarray not yet implemented")
finally:
return out_tiles
out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)
mask = dask.array.ma.getmaskarray(daarray_stack).compute()

# Gapfill main raster if there are blank spaces
if mask.any(): # check that there are no NAN values
nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel()

# Replace pixels from another raster if available, else raise error
if gapfill_raster_filepath is not None:
with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:
daarray_list2 = [
dataset2.interp_like(daarray_list[idx].squeeze(), method="linear")
for idx in nan_grid_indexes
]
daarray_stack2 = dask.array.ma.masked_values(
x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals
)

fill_tiles = (
dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32)
)
mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute()

for i, array2 in enumerate(fill_tiles):
idx = nan_grid_indexes[i]
np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])
assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill

else:
for i in nan_grid_indexes:
daarray_list[i].plot()
plt.show()
print(f"WARN: Tiles have missing data, try pass in gapfill_raster_filepath")

return out_tiles


# %%
Expand Down Expand Up @@ -695,7 +724,7 @@ def selective_tile(
filepath="misc/REMA_100m_dem.tif",
window_bounds=window_bounds_concat,
padding=1000,
# gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
)
print(rema.shape, rema.dtype)

Expand Down

0 comments on commit 4a074d9

Please sign in to comment.