Skip to content

Commit

Permalink
filter data by sensor temporal coverage in ml
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 5, 2024
1 parent 4ce1d71 commit 5eef33a
Showing 1 changed file with 56 additions and 15 deletions.
71 changes: 56 additions & 15 deletions odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from odc.algo.io import load_with_native_transform

from odc.stats._algebra import expr_eval
from odc.stats.model import DateTimeRange
from ._registry import StatsPluginInterface
from ._worker import TreeliteModelPlugin
import tl2cgen
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
output_classes: Dict,
model_path: str,
mask_bands: Optional[Dict] = None,
temporal_coverage: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -74,6 +76,7 @@ def __init__(
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
self.output_classes = output_classes
self.mask_bands = mask_bands
self.temporal_coverage = temporal_coverage
self._log = logging.getLogger(__name__)

def input_data(
Expand Down Expand Up @@ -107,40 +110,78 @@ def input_data(
(self.chunks["x"], self.chunks["y"], -1, -1),
dtype="float32",
name=ds.type.name + "_yxbt",
).squeeze("spec", drop=True)
)
data_vars[ds.type.name] = input_array
else:
for var in xx.data_vars:
data_vars[var] = yxt_sink(
xx[var].astype("uint8"),
(self.chunks["x"], self.chunks["y"], -1),
name=ds.type.name + "_yxt",
).squeeze("spec", drop=True)
)

coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims)
return xr.Dataset(data_vars=data_vars, coords=coords)

def impute_missing_values(self, xx: xr.Dataset, image):
imputed = None
for var in xx.data_vars:
if var in self.mask_bands:
continue
nodata = xx[var].attrs.get("nodata", -999)
imputed = expr_eval(
"where((a==a)|(b<=nodata), a, b)",
{
"a": image,
"b": xx[var].squeeze("spec", drop=True).data,
},
name="impute_missing",
dtype="float32",
**{"nodata": nodata},
)
return imputed if imputed is not None else image

def preprocess_predict_input(self, xx: xr.Dataset):
images = []
veg_mask = None

def convert_dtype(var):
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": xx[var].squeeze("spec", drop=True).data,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
return image

for var in xx.data_vars:
image = xx[var].data
if var not in self.mask_bands:
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": image,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
images += [image]
if self.temporal_coverage is not None:
# filter and impute by sensors
temporal_range = [
DateTimeRange(v) for v in self.temporal_coverage.get(var)
]
for tr in temporal_range:
if xx.solar_day in tr:
self._log.info("Impute missing values of %s", var)
image = convert_dtype(var)
images += [
self.impute_missing_values(xx.drop_vars(var), image)
]
break
else:
# use data from all sensors
image = convert_dtype(var)
images += [image]

else:
veg_mask = expr_eval(
"where(a==_v, 1, 0)",
{"a": image},
{"a": xx[var].squeeze("spec", drop=True).data},
name="make_mask",
dtype="float32",
**{"_v": int(self.mask_bands[var])},
Expand Down

0 comments on commit 5eef33a

Please sign in to comment.