Skip to content

Commit 9c5134a

Browse files
committed
Add chunking when opening netcdf files. Handle errors for code that didn't expect to recieve dask objects. Optimize the flow-transport model matcher
1 parent 04f746e commit 9c5134a

File tree

7 files changed

+51
-9
lines changed

7 files changed

+51
-9
lines changed

imod/mf6/hfb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def to_netcdf(
557557
558558
"""
559559
kwargs.update({"encoding": self._netcdf_encoding()})
560+
kwargs.update({"format": "NETCDF4"})
560561

561562
new = deepcopy(self)
562563
new.dataset["geometry"] = new.line_data.to_json()

imod/mf6/oc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
172172
package_times = self.dataset[datavar].coords["time"].values
173173
starts = np.searchsorted(globaltimes, package_times) + 1
174174
for i, s in enumerate(starts):
175-
setting = self.dataset[datavar].isel(time=i).item()
175+
setting = self.dataset[datavar].isel(time=i).values[()]
176176
periods[s][key] = self._get_ocsetting(setting)
177177

178178
else:

imod/mf6/pkgbase.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,20 @@ def to_netcdf(
9292
9393
"""
9494
kwargs.update({"encoding": self._netcdf_encoding()})
95+
kwargs.update({"format": "NETCDF4"})
9596

9697
dataset = self.dataset
98+
99+
# Create encoding dict for float16 variables
100+
for var in dataset.data_vars:
101+
if dataset[var].dtype == np.float16:
102+
kwargs["encoding"][var] = {"dtype": "float32"}
103+
104+
# Also check coordinates
105+
for coord in dataset.coords:
106+
if dataset[coord].dtype == np.float16:
107+
kwargs["encoding"][coord] = {"dtype": "float32"}
108+
97109
if isinstance(dataset, xu.UgridDataset):
98110
if mdal_compliant:
99111
dataset = dataset.ugrid.to_dataset()
@@ -168,7 +180,7 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
168180
# TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
169181
dataset = xr.open_zarr(str(path), **kwargs)
170182
else:
171-
dataset = xr.open_dataset(path, **kwargs)
183+
dataset = xr.open_dataset(path, chunks="auto", **kwargs)
172184

173185
if dataset.ugrid_roles.topology:
174186
dataset = xu.UgridDataset(dataset)

imod/mf6/rch.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime
2-
from typing import Optional
2+
from pathlib import Path
3+
from typing import Optional, Self
34

45
import numpy as np
56
import xarray as xr
@@ -166,6 +167,22 @@ def __init__(
166167
super().__init__(dict_dataset)
167168
self._validate_init_schemata(validate)
168169

170+
@classmethod
171+
def from_file(cls, path: str | Path, **kwargs) -> Self:
172+
instance = super().from_file(path, **kwargs)
173+
174+
# to_netcdf converts strings into NetCDF "variable‑length UTF‑8 strings"
175+
# which are loaded as dtype=object arrays.
176+
# This will convert them back to str.
177+
vars = [
178+
"species",
179+
]
180+
for var in vars:
181+
if var in instance.dataset:
182+
instance.dataset[var] = instance.dataset[var].astype(str)
183+
184+
return instance
185+
169186
def _validate(self, schemata, **kwargs):
170187
# Insert additional kwargs
171188
kwargs["rate"] = self["rate"]

imod/mf6/simulation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from imod.typing import GridDataArray, GridDataset
6161
from imod.typing.grid import (
6262
concat,
63-
is_equal,
63+
is_same_domain,
6464
is_unstructured,
6565
merge_partitions,
6666
)
@@ -1037,12 +1037,14 @@ def dump(
10371037
_, filename, _, _ = exchange_package.get_specification()
10381038
exchange_class_short = type(exchange_package).__name__
10391039
path = f"{filename}.nc"
1040-
exchange_package.dataset.to_netcdf(directory / path)
1040+
exchange_package.dataset.to_netcdf(
1041+
directory / path, format="NETCDF4"
1042+
)
10411043
toml_content[key][exchange_class_short].append(path)
10421044

10431045
else:
10441046
path = f"{key}.nc"
1045-
value.dataset.to_netcdf(directory / path)
1047+
value.dataset.to_netcdf(directory / path, format="NETCDF4")
10461048
toml_content[cls_name][key] = path
10471049

10481050
with open(directory / f"{self.name}.toml", "wb") as f:
@@ -1620,10 +1622,16 @@ def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]:
16201622

16211623
for flow_model_name in flow_models:
16221624
flow_model = self[flow_model_name]
1625+
1626+
matched_tsp_models = []
16231627
for tpt_model_name in transport_models:
16241628
tpt_model = self[tpt_model_name]
1625-
if is_equal(tpt_model.domain, flow_model.domain):
1629+
if is_same_domain(tpt_model.domain, flow_model.domain):
16261630
result[flow_model_name].append(tpt_model_name)
1631+
matched_tsp_models.append(tpt_model_name)
1632+
for tpt_model_name in matched_tsp_models:
1633+
transport_models.pop(tpt_model_name)
1634+
16271635
return result
16281636

16291637
def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]:

imod/mf6/wel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def _to_mf6_package_information(
540540
else:
541541
message += " The first 10 unplaced wells are: \n"
542542

543-
is_filtered = self.dataset["id"].isin([filtered_wells])
543+
is_filtered = self.dataset["id"].compute().isin(filtered_wells)
544544
for i in range(min(10, len(filtered_wells))):
545545
ids = filtered_wells[i]
546546
x = self.dataset["x"].data[is_filtered][i]
@@ -1076,7 +1076,7 @@ def _assign_wells_to_layers(
10761076
like = ones_like(active)
10771077
bottom = like * bottom
10781078
top_2d = (like * top).sel(layer=1)
1079-
top_3d = bottom.shift(layer=1).fillna(top_2d)
1079+
top_3d = bottom.compute().shift(layer=1).fillna(top_2d)
10801080
k = like * k
10811081

10821082
index_names = wells_df.index.names

imod/typing/grid.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,15 @@ def is_spatial_grid(_: Any) -> bool: # noqa: F811
316316

317317
@dispatch
318318
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
319+
if not is_same_domain(array1, array2):
320+
return False
319321
return array1.equals(array2) and array1.ugrid.grid.equals(array2.ugrid.grid)
320322

321323

322324
@dispatch # type: ignore[no-redef]
323325
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool: # noqa: F811
326+
if not is_same_domain(array1, array2):
327+
return False
324328
return array1.equals(array2)
325329

326330

0 commit comments

Comments
 (0)