Skip to content

Commit

Permalink
Geomedian refactor (#50)
Browse files Browse the repository at this point in the history
* Added separate buffering for each cloud class. The nodata and cloud masks (including the buffering) are now applied before reprojection.

* Apply the contiguity flag as part of masking bad data.

* Remove the redundant input_data function.
  • Loading branch information
tebadi authored Jun 21, 2022
1 parent 0c352f8 commit 3d014da
Showing 1 changed file with 32 additions and 47 deletions.
79 changes: 32 additions & 47 deletions odc/stats/plugins/gm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
Geomedian
"""
from typing import Optional, Mapping, Sequence, Tuple, Iterable
from typing import Optional, Mapping, Sequence, Tuple, Iterable, Dict
import xarray as xr
from datacube.model import Dataset
from datacube.utils.geometry import GeoBox
from odc.algo import erase_bad, geomedian_with_mads
from odc.algo.io import load_enum_filtered
from ._registry import StatsPluginInterface, register
from odc.algo import enum_to_bool, mask_cleanup


class StatsGM(StatsPluginInterface):
Expand All @@ -20,9 +20,8 @@ def __init__(
self,
bands: Tuple[str, ...],
mask_band: str,
cloud_classes: Tuple[str, ...],
nodata_classes: Optional[Tuple[str, ...]] = None,
filters: Optional[Iterable[Tuple[str, int]]] = None,
cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = None,
basis_band=None,
aux_names=dict(smad="smad", emad="emad", bcmad="bcmad", count="count"),
work_chunks: Tuple[int, int] = (400, 400),
Expand All @@ -38,18 +37,14 @@ def __init__(
# NOTE: this ends up loading Mask band twice, once to compute
# ``.erase`` band and once to compute ``nodata`` mask.
input_bands = (*input_bands, self._mask_band)

super().__init__(
input_bands=input_bands,
basis=basis_band or self.bands[0],
**kwargs)

self.cloud_classes = tuple(cloud_classes)
# if filters:
# self.filters: Optional[Mapping] = dict(filters)
self.filters = filters
# else:
# self.filters = None
**kwargs,
)

self.cloud_filters = cloud_filters
self._renames = aux_names
self.aux_bands = tuple(
self._renames.get(k, k) for k in ("smad", "emad", "bcmad", "count")
Expand All @@ -67,29 +62,21 @@ def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
if not self._mask_band in xx.data_vars:
return xx

# Apply the contiguity flag
non_contiguent = xx["nbart_contiguity"] == 0

# Erase Data Pixels for which mask == nodata
#
# xx[mask == nodata] = nodata
mask = xx[self._mask_band]
xx = xx.drop_vars([self._mask_band])
keeps = enum_to_bool(mask, self._nodata_classes, invert=True)
xx = keep_good_only(xx, keeps)
bad = enum_to_bool(mask, self._nodata_classes)
bad = bad | non_contiguent

return xx
for cloud_class, filter in self.cloud_filters.items():
cloud_mask = enum_to_bool(mask, (cloud_class,))
cloud_mask_buffered = mask_cleanup(cloud_mask, mask_filters=filter)
bad = cloud_mask_buffered | bad

def input_data(self, datasets: Sequence[Dataset], geobox: GeoBox) -> xr.Dataset:
erased = load_enum_filtered(
datasets,
self._mask_band,
geobox,
categories=self.cloud_classes,
filters=self.filters,
groupby=self.group_by,
resampling=self.resampling,
chunks={},
)
xx = super().input_data(datasets, geobox)
xx = erase_bad(xx, erased)
xx = xx.drop_vars([self._mask_band] + ["nbart_contiguity"])
xx = keep_good_only(xx, ~bad)
return xx

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
Expand Down Expand Up @@ -120,21 +107,21 @@ class StatsGMS2(StatsGM):
SHORT_NAME = NAME
VERSION = "0.0.0"
PRODUCT_FAMILY = "geomedian"
DEFAULT_FILTER = [("opening", 2), ("dilation", 5)]

def __init__(
self,
bands: Optional[Tuple[str, ...]] = None,
mask_band: str = "SCL",
cloud_classes: Tuple[str, ...] = (
"cloud shadows",
"cloud medium probability",
"cloud high probability",
"thin cirrus",
),
filters: Optional[Iterable[Tuple[str, int]]] = [("opening", 2), ("dilation",5)],
cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = {
"cloud shadows": DEFAULT_FILTER,
"cloud medium probability": DEFAULT_FILTER,
"cloud high probability": DEFAULT_FILTER,
"thin cirrus": DEFAULT_FILTER,
},
aux_names=dict(smad="SMAD", emad="EMAD", bcmad="BCMAD", count="COUNT"),
rgb_bands=None,
**kwargs
**kwargs,
):
if bands is None:
bands = (
Expand All @@ -155,11 +142,10 @@ def __init__(
super().__init__(
bands=bands,
mask_band=mask_band,
cloud_classes=cloud_classes,
filters=filters,
cloud_filters=cloud_filters,
aux_names=aux_names,
rgb_bands=rgb_bands,
**kwargs
**kwargs,
)


Expand All @@ -176,12 +162,11 @@ def __init__(
self,
bands: Optional[Tuple[str, ...]] = None,
mask_band: str = "fmask",
cloud_classes: Tuple[str, ...] = ("cloud", "shadow"),
nodata_classes: Optional[Tuple[str, ...]] = ("nodata",),
filters: Optional[Iterable[Tuple[str, int]]] = None,
cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = None,
aux_names=dict(smad="sdev", emad="edev", bcmad="bcdev", count="count"),
rgb_bands=None,
**kwargs
**kwargs,
):
if bands is None:
bands = (
Expand All @@ -191,15 +176,15 @@ def __init__(
"nir",
"swir1",
"swir2",
"nbart_contiguity",
)
if rgb_bands is None:
rgb_bands = ("red", "green", "blue")

super().__init__(
bands=bands,
mask_band=mask_band,
cloud_classes=cloud_classes,
filters=filters,
cloud_filters=cloud_filters,
nodata_classes=nodata_classes,
aux_names=aux_names,
rgb_bands=rgb_bands,
Expand Down

0 comments on commit 3d014da

Please sign in to comment.