Skip to content

Commit ae8ff7c

Browse files
Issue #1698 Model write optimization (#1700)
Fixes #1698 # Description This is part 2 of fixing the performance issues with large models. In part 1 (#1693) the model splitter was optimized. In this PR the focus is on wiring the partitioned model. As @Huite pointed out in #1686, the performance bottleneck had to do with the fact that the same package had to be loaded from file multiple times while only a part of the file is actually needed. After digging around for a while I discovered that this had to do with how we open the dataset. `dataset = xr.open_dataset(path, **kwargs)` In the line above we don't specify anything chunk-related. As a result, when you access the dataset the entire file has to be loaded from disk. By simply adding `chunks="auto"` this is no longer the case and a huge performance gain is achieved. There are some other changes related to setting chunking to auto. Some parts of the code don't expect to receive dask arrays. For instance, you cannot use `.item()` on a dask array; instead I now use `.values[()]`. I was also getting some errors when the `to_netcdf` method was called on the package. All of them had something to do with wrong/unsupported datatypes. In this PR you will find that an encoding is added for float16 types, and that in some packages the `from_file` method has been updated to ensure that the loaded type is converted to a supported type. An unrelated but performance-wise significant change has been applied to the `_get_transport_models_per_flow_model` method. This method is used to match gwf models to gwt models so that gwfgwt exchanges can be created. It was doing a full comparison of domains, which is expensive. There is also a method available that does the comparison on the domain level only. By switching to this method the matching algorithm becomes almost instantaneous. **NOTE** This PR has issue #1699 as a base. 
The base needs to be altered to master once that PR is in **NOTE** This PR also improves the `dump` method **NOTE** some timings: <img width="833" height="739" alt="image" src="https://github.com/user-attachments/assets/974c841c-0413-4433-8486-1abe98dc0715" /> <img width="843" height="215" alt="image" src="https://github.com/user-attachments/assets/c7082975-af35-4143-a6f9-860557b3eb09" /> <img width="842" height="705" alt="image" src="https://github.com/user-attachments/assets/383bf1a6-f028-4cb4-aa72-48ab95e84e3d" /> <!--- Before requesting review, please go through this checklist: --> - [x] Links to correct issue - [ ] Update changelog, if changes affect users - [x] PR title starts with ``Issue #nr``, e.g. ``Issue #737`` - [ ] Unit tests were added - [ ] **If feature added**: Added/extended example - [ ] **If feature added**: Added feature to API documentation - [ ] **If pixi.lock was changed**: Ran `pixi run generate-sbom` and committed changes --------- Co-authored-by: JoerivanEngelen <joerivanengelen@hotmail.com>
1 parent 5bcf9e5 commit ae8ff7c

File tree

6 files changed

+37
-9
lines changed

6 files changed

+37
-9
lines changed

imod/mf6/oc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
172172
package_times = self.dataset[datavar].coords["time"].values
173173
starts = np.searchsorted(globaltimes, package_times) + 1
174174
for i, s in enumerate(starts):
175-
setting = self.dataset[datavar].isel(time=i).item()
175+
setting = self.dataset[datavar].isel(time=i).values[()]
176176
periods[s][key] = self._get_ocsetting(setting)
177177

178178
else:

imod/mf6/pkgbase.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ def to_netcdf(
9494
kwargs.update({"encoding": self._netcdf_encoding()})
9595

9696
dataset = self.dataset
97+
98+
# Create encoding dict for float16 variables
99+
for var in dataset.data_vars:
100+
if dataset[var].dtype == np.float16:
101+
kwargs["encoding"][var] = {"dtype": "float32"}
102+
103+
# Also check coordinates
104+
for coord in dataset.coords:
105+
if dataset[coord].dtype == np.float16:
106+
kwargs["encoding"][coord] = {"dtype": "float32"}
107+
97108
if isinstance(dataset, xu.UgridDataset):
98109
if mdal_compliant:
99110
dataset = dataset.ugrid.to_dataset()
@@ -168,7 +179,7 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
168179
# TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
169180
dataset = xr.open_zarr(str(path), **kwargs)
170181
else:
171-
dataset = xr.open_dataset(path, **kwargs)
182+
dataset = xr.open_dataset(path, chunks="auto", **kwargs)
172183

173184
if dataset.ugrid_roles.topology:
174185
dataset = xu.UgridDataset(dataset)
@@ -183,4 +194,12 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
183194
if _is_scalar_nan(value):
184195
dataset[key] = None
185196

197+
# to_netcdf converts strings into NetCDF "variable‑length UTF‑8 strings"
198+
# which are loaded as dtype=object arrays. This will convert them back
199+
# to str.
200+
vars = ["species"]
201+
for var in vars:
202+
if var in dataset:
203+
dataset[var] = dataset[var].astype(str)
204+
186205
return cls._from_dataset(dataset)

imod/mf6/simulation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from imod.typing import GridDataArray, GridDataset
6161
from imod.typing.grid import (
6262
concat,
63-
is_equal,
63+
is_same_domain,
6464
is_unstructured,
6565
merge_partitions,
6666
)
@@ -1622,10 +1622,16 @@ def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]:
16221622

16231623
for flow_model_name in flow_models:
16241624
flow_model = self[flow_model_name]
1625+
1626+
matched_tsp_models = []
16251627
for tpt_model_name in transport_models:
16261628
tpt_model = self[tpt_model_name]
1627-
if is_equal(tpt_model.domain, flow_model.domain):
1629+
if is_same_domain(tpt_model.domain, flow_model.domain):
16281630
result[flow_model_name].append(tpt_model_name)
1631+
matched_tsp_models.append(tpt_model_name)
1632+
for tpt_model_name in matched_tsp_models:
1633+
transport_models.pop(tpt_model_name)
1634+
16291635
return result
16301636

16311637
def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]:

imod/mf6/wel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def _to_mf6_package_information(
540540
else:
541541
message += " The first 10 unplaced wells are: \n"
542542

543-
is_filtered = self.dataset["id"].isin([filtered_wells])
543+
is_filtered = self.dataset["id"].compute().isin(filtered_wells)
544544
for i in range(min(10, len(filtered_wells))):
545545
ids = filtered_wells[i]
546546
x = self.dataset["x"].data[is_filtered][i]
@@ -1073,9 +1073,9 @@ def _assign_wells_to_layers(
10731073
) -> pd.DataFrame:
10741074
# Ensure top, bottom & k
10751075
# are broadcasted to 3d grid
1076-
like = ones_like(active)
1077-
bottom = like * bottom
1078-
top_2d = (like * top).sel(layer=1)
1076+
like = ones_like(active.compute())
1077+
bottom = like * bottom.compute()
1078+
top_2d = (like * top.compute()).sel(layer=1)
10791079
top_3d = bottom.shift(layer=1).fillna(top_2d)
10801080
k = like * k
10811081

imod/typing/grid.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,15 @@ def is_spatial_grid(_: Any) -> bool: # noqa: F811
316316

317317
@dispatch
318318
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
319+
if not is_same_domain(array1, array2):
320+
return False
319321
return array1.equals(array2) and array1.ugrid.grid.equals(array2.ugrid.grid)
320322

321323

322324
@dispatch # type: ignore[no-redef]
323325
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool: # noqa: F811
326+
if not is_same_domain(array1, array2):
327+
return False
324328
return array1.equals(array2)
325329

326330

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ ignore = [
8080
"E501", # line-too-long. This rule can't be fullfilled by the ruff formatter. The same behavior as black.
8181
"PD003",
8282
"PD004",
83-
"PD901",
8483
"PD011",
8584
"PD013",
8685
"PD015",

0 commit comments

Comments
 (0)