Skip to content

Commit ccf9f21

Browse files
Issue #1158 remaining hfb bottlenecks (#1159)
Fixes #1158 # Description Fix remaining HFB bottlenecks. This reduces writing the HFB package from the LHM from 12.5 minutes to 2 minutes. - Cache results ``xu.Ugrid2d.from_structured``, as this is a costly operation - Use ``and`` instead of ``&`` operator in the ``scalar_None`` function, to enable shortcutting. - Remove ``mask_all_packages`` function in ``from_imod5_data``, call in tests. - Create separate pixi task for slow unittests, where just in time compilation is enabled. ``pixi run unittests`` still runs all unittests, by starting two pixi tasks ``unittests_njit`` & ``unittests_jit``. # Checklist - [x] Links to correct issue - [x] Update changelog, if changes affect users - [x] PR title starts with ``Issue #nr``, e.g. ``Issue #737`` - [x] Unit tests were added - [ ] **If feature added**: Added/extended example
1 parent 4f3b716 commit ccf9f21

File tree

10 files changed

+157
-22
lines changed

10 files changed

+157
-22
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,4 @@ examples/data
141141
.pixi
142142

143143
/imod/tests/mydask.png
144-
/imod/tests/unittest_report.xml
145-
/imod/tests/examples_report.xml
144+
/imod/tests/*_report.xml

docs/api/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Fixed
1717
flow barrier for MODFLOW 6
1818
- Bug where error would be thrown when barriers in a ``HorizontalFlowBarrier``
1919
would be snapped to the same cell edge. These are now summed.
20+
- Improve performance validation upon Package initialization
21+
- Improve performance writing ``HorizontalFlowBarrier`` objects
2022

2123
Changed
2224
~~~~~~~

imod/mf6/simulation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,8 +1378,5 @@ def from_imod5_data(
13781378
)
13791379
simulation["ims"] = solution
13801380

1381-
# cleanup packages for validation
1382-
idomain = groundwaterFlowModel.domain
1383-
simulation.mask_all_models(idomain)
13841381
simulation.create_time_discretization(additional_times=times)
13851382
return simulation

imod/schemata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def scalar_None(obj):
7474
if not isinstance(obj, (xr.DataArray, xu.UgridDataArray)):
7575
return False
7676
else:
77-
return (len(obj.shape) == 0) & (~obj.notnull()).all()
77+
return (len(obj.shape) == 0) and (obj.isnull()).all()
7878

7979

8080
def align_other_obj_with_coords(

imod/tests/test_mf6/test_mf6_chd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def test_from_imod5_shd(imod5_dataset, tmp_path):
238238
chd_shd.write("chd_shd", [1], write_context)
239239

240240

241+
@pytest.mark.unittest_jit
241242
@pytest.mark.parametrize("remove_merged_packages", [True, False])
242243
@pytest.mark.usefixtures("imod5_dataset")
243244
def test_concatenate_chd(imod5_dataset, tmp_path, remove_merged_packages):

imod/tests/test_mf6/test_mf6_simulation.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ def compare_submodel_partition_info(first: PartitionInfo, second: PartitionInfo)
476476
)
477477

478478

479+
@pytest.mark.unittest_jit
479480
@pytest.mark.usefixtures("imod5_dataset")
480481
def test_import_from_imod5(imod5_dataset, tmp_path):
481482
imod5_data = imod5_dataset[0]
@@ -495,18 +496,20 @@ def test_import_from_imod5(imod5_dataset, tmp_path):
495496
simulation["imported_model"]["oc"] = OutputControl(
496497
save_head="last", save_budget="last"
497498
)
498-
499499
simulation.create_time_discretization(["01-01-2003", "02-01-2003"])
500-
500+
# Cleanup
501501
# Remove HFB packages outside domain
502502
# TODO: Build in support for hfb packages outside domain
503503
for hfb_outside in ["hfb-24", "hfb-26"]:
504504
simulation["imported_model"].pop(hfb_outside)
505-
505+
# Align NoData to domain
506+
idomain = simulation["imported_model"].domain
507+
simulation.mask_all_models(idomain)
506508
# write and validate the simulation.
507509
simulation.write(tmp_path, binary=False, validate=True)
508510

509511

512+
@pytest.mark.unittest_jit
510513
@pytest.mark.usefixtures("imod5_dataset")
511514
def test_import_from_imod5__correct_well_type(imod5_dataset):
512515
# Unpack
@@ -537,6 +540,7 @@ def test_import_from_imod5__correct_well_type(imod5_dataset):
537540
assert isinstance(simulation["imported_model"]["wel-WELLS_L5"], LayeredWell)
538541

539542

543+
@pytest.mark.unittest_jit
540544
@pytest.mark.usefixtures("imod5_dataset")
541545
def test_import_from_imod5__nonstandard_regridding(imod5_dataset, tmp_path):
542546
imod5_data = imod5_dataset[0]
@@ -558,22 +562,23 @@ def test_import_from_imod5__nonstandard_regridding(imod5_dataset, tmp_path):
558562
times,
559563
regridding_option,
560564
)
561-
562565
simulation["imported_model"]["oc"] = OutputControl(
563566
save_head="last", save_budget="last"
564567
)
565-
566568
simulation.create_time_discretization(["01-01-2003", "02-01-2003"])
567-
569+
# Cleanup
568570
# Remove HFB packages outside domain
569571
# TODO: Build in support for hfb packages outside domain
570572
for hfb_outside in ["hfb-24", "hfb-26"]:
571573
simulation["imported_model"].pop(hfb_outside)
572-
574+
# Align NoData to domain
575+
idomain = simulation["imported_model"].domain
576+
simulation.mask_all_models(idomain)
573577
# write and validate the simulation.
574578
simulation.write(tmp_path, binary=False, validate=True)
575579

576580

581+
@pytest.mark.unittest_jit
577582
@pytest.mark.usefixtures("imod5_dataset")
578583
def test_import_from_imod5_no_storage_no_recharge(imod5_dataset, tmp_path):
579584
# this test imports an imod5 simulation, but it has no recharge and no storage package.
@@ -594,23 +599,22 @@ def test_import_from_imod5_no_storage_no_recharge(imod5_dataset, tmp_path):
594599
default_simulation_distributing_options,
595600
times,
596601
)
597-
598602
simulation["imported_model"]["oc"] = OutputControl(
599603
save_head="last", save_budget="last"
600604
)
601-
602605
simulation.create_time_discretization(["01-01-2003", "02-01-2003"])
603-
606+
# Cleanup
604607
# Remove HFB packages outside domain
605608
# TODO: Build in support for hfb packages outside domain
606609
for hfb_outside in ["hfb-24", "hfb-26"]:
607610
simulation["imported_model"].pop(hfb_outside)
608-
609611
# check storage is present and rch is absent
610612
assert not simulation["imported_model"]["sto"].dataset["transient"].values[()]
611613
package_keys = simulation["imported_model"].keys()
612614
for key in package_keys:
613615
assert key[0:3] != "rch"
614-
616+
# Align NoData to domain
617+
idomain = simulation["imported_model"].domain
618+
simulation.mask_all_models(idomain)
615619
# write and validate the simulation.
616620
simulation.write(tmp_path, binary=False, validate=True)

imod/tests/test_typing/test_typing_grid.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import numpy as np
12
import xarray as xr
23
import xugrid as xu
34

45
from imod.typing.grid import (
6+
UGRID2D_FROM_STRUCTURED_CACHE,
7+
GridCache,
8+
as_ugrid_dataarray,
59
enforce_dim_order,
610
is_planar_grid,
711
is_spatial_grid,
@@ -145,3 +149,73 @@ def test_merge_dictionary__unstructured(basic_unstructured_dis):
145149
assert isinstance(uds["bottom"], xr.DataArray)
146150
assert uds["ibound"].dims == ("layer", "mesh2d_nFaces")
147151
assert uds["bottom"].dims == ("layer",)
152+
153+
154+
def test_as_ugrid_dataarray__structured(basic_dis):
155+
# Arrange
156+
ibound, top, bottom = basic_dis
157+
top_3d = top * ibound
158+
bottom_3d = bottom * ibound
159+
# Clear cache
160+
UGRID2D_FROM_STRUCTURED_CACHE.clear()
161+
# Act
162+
ibound_disv = as_ugrid_dataarray(ibound)
163+
top_disv = as_ugrid_dataarray(top_3d)
164+
bottom_disv = as_ugrid_dataarray(bottom_3d)
165+
# Assert
166+
# Test types
167+
assert isinstance(ibound_disv, xu.UgridDataArray)
168+
assert isinstance(top_disv, xu.UgridDataArray)
169+
assert isinstance(bottom_disv, xu.UgridDataArray)
170+
# Test cache proper size
171+
assert len(UGRID2D_FROM_STRUCTURED_CACHE.grid_cache) == 1
172+
# Test that data is different
173+
assert np.all(ibound_disv != top_disv)
174+
assert np.all(top_disv != bottom_disv)
175+
# Test that grid is equal
176+
assert np.all(ibound_disv.grid == top_disv.grid)
177+
assert np.all(top_disv.grid == bottom_disv.grid)
178+
179+
180+
def test_as_ugrid_dataarray__unstructured(basic_unstructured_dis):
181+
# Arrange
182+
ibound, top, bottom = basic_unstructured_dis
183+
top_3d = enforce_dim_order(ibound * top)
184+
bottom_3d = enforce_dim_order(ibound * bottom)
185+
# Clear cache
186+
UGRID2D_FROM_STRUCTURED_CACHE.clear()
187+
# Act
188+
ibound_disv = as_ugrid_dataarray(ibound)
189+
top_disv = as_ugrid_dataarray(top_3d)
190+
bottom_disv = as_ugrid_dataarray(bottom_3d)
191+
# Assert
192+
# Test types
193+
assert isinstance(ibound_disv, xu.UgridDataArray)
194+
assert isinstance(top_disv, xu.UgridDataArray)
195+
assert isinstance(bottom_disv, xu.UgridDataArray)
196+
assert len(UGRID2D_FROM_STRUCTURED_CACHE.grid_cache) == 0
197+
198+
199+
def test_ugrid2d_cache(basic_dis):
200+
# Arrange
201+
ibound, _, _ = basic_dis
202+
# Act
203+
cache = GridCache(xu.Ugrid2d.from_structured, max_cache_size=3)
204+
for i in range(5):
205+
ugrid2d = cache.get_grid(ibound[:, i:, :])
206+
# Assert
207+
# Test types
208+
assert isinstance(ugrid2d, xu.Ugrid2d)
209+
# Test cache proper size
210+
assert cache.max_cache_size == 3
211+
assert len(cache.grid_cache) == 3
212+
# Check if smallest grid in last cache list by checking if amount of faces
213+
# correct
214+
expected_size = ibound[0, i:, :].size
215+
keys = list(cache.grid_cache.keys())
216+
last_ugrid = cache.grid_cache[keys[-1]]
217+
actual_size = last_ugrid.n_face
218+
assert expected_size == actual_size
219+
# Test clear cache
220+
cache.clear()
221+
assert len(cache.grid_cache) == 0

imod/typing/grid.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,58 @@ def is_transient_data_grid(
435435
return False
436436

437437

438+
class GridCache:
439+
"""
440+
Cache grids in this object for a specific function, lookup grids based on
441+
unique geometry hash.
442+
"""
443+
444+
def __init__(self, func: Callable, max_cache_size=5):
445+
self.max_cache_size = max_cache_size
446+
self.grid_cache: dict[int, GridDataArray] = {}
447+
self.func = func
448+
449+
def get_grid(self, grid: GridDataArray):
450+
geom_hash = get_grid_geometry_hash(grid)
451+
if geom_hash not in self.grid_cache.keys():
452+
if len(self.grid_cache.keys()) >= self.max_cache_size:
453+
self.remove_first()
454+
self.grid_cache[geom_hash] = self.func(grid)
455+
return self.grid_cache[geom_hash]
456+
457+
def remove_first(self):
458+
keys = list(self.grid_cache.keys())
459+
self.grid_cache.pop(keys[0])
460+
461+
def clear(self):
462+
self.grid_cache = {}
463+
464+
465+
UGRID2D_FROM_STRUCTURED_CACHE = GridCache(xu.Ugrid2d.from_structured)
466+
467+
438468
@typedispatch
439469
def as_ugrid_dataarray(grid: xr.DataArray) -> xu.UgridDataArray:
440-
"""Enforce GridDataArray to UgridDataArray"""
441-
return xu.UgridDataArray.from_structured(grid)
470+
"""
471+
Enforce GridDataArray to UgridDataArray, calls
472+
xu.UgridDataArray.from_structured, which is a costly operation. Therefore
473+
cache results.
474+
"""
475+
476+
topology = UGRID2D_FROM_STRUCTURED_CACHE.get_grid(grid)
477+
478+
# Copied from:
479+
# https://github.com/Deltares/xugrid/blob/3dee693763da1c4c0859a4f53ac38d4b99613a33/xugrid/core/wrap.py#L236
480+
# Note that "da" is renamed to "grid" and "grid" to "topology"
481+
dims = grid.dims[:-2]
482+
coords = {k: grid.coords[k] for k in dims}
483+
face_da = xr.DataArray(
484+
grid.data.reshape(*grid.shape[:-2], -1),
485+
coords=coords,
486+
dims=[*dims, topology.face_dimension],
487+
name=grid.name,
488+
)
489+
return xu.UgridDataArray(face_da, topology)
442490

443491

444492
@typedispatch # type: ignore[no-redef]

pixi.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ install_with_deps = "python -m pip install --editable ."
1919
format = "ruff check --fix .; ruff format ."
2020
lint = "ruff check . ; ruff format --check ."
2121
tests = { depends_on = ["unittests", "examples"] }
22-
unittests = { cmd = [
22+
unittests = { depends_on = ["unittests_njit", "unittests_jit"] }
23+
unittests_njit = { cmd = [
2324
"NUMBA_DISABLE_JIT=1",
2425
"pytest",
2526
"-n", "auto",
26-
"-m", "not example and not user_acceptance",
27+
"-m", "not example and not user_acceptance and not unittest_jit",
2728
"--cache-clear",
2829
"--verbose",
2930
"--junitxml=unittest_report.xml",
@@ -32,6 +33,14 @@ unittests = { cmd = [
3233
"--cov-report=html:coverage",
3334
"--cov-config=.coveragerc"
3435
], depends_on = ["install"], cwd = "imod/tests" }
36+
unittests_jit = { cmd = [
37+
"pytest",
38+
"-n", "auto",
39+
"-m", "unittest_jit",
40+
"--cache-clear",
41+
"--verbose",
42+
"--junitxml=unittest_jit_report.xml",
43+
], depends_on = ["install"], cwd = "imod/tests" }
3544
examples = { cmd = [
3645
"NUMBA_DISABLE_JIT=1",
3746
"pytest",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ ignore_missing_imports = true
158158
markers = [
159159
"example: marks test as example (deselect with '-m \"not example\"')",
160160
"user_acceptance: marks user acceptance tests (deselect with '-m \"not user_acceptance\"')",
161+
"unittest_jit: marks unit tests that should be jitted (deselect with '-m \"not unittest_jit\"')"
161162
]
162163

163164
[tool.hatch.version]

0 commit comments

Comments
 (0)