Skip to content

Commit 04f746e

Browse files
committed
Add comments to the ModelSplitter class
1 parent a32e989 commit 04f746e

File tree

1 file changed

+129
-73
lines changed

1 file changed

+129
-73
lines changed

imod/mf6/multimodel/modelsplitter.py

Lines changed: 129 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -90,98 +90,36 @@ class ModelSplitter:
9090
def __init__(self, partition_info: List[PartitionInfo]) -> None:
9191
self.partition_info = partition_info
9292

93+
# Initialize mapping from original model names to partitioned models
9394
self._model_to_partitioned_model: dict[str, dict[str, IModel]] = {}
9495

9596
# Initialize mapping from partition IDs to models
9697
self._partition_id_to_models: dict[int, dict[str, IModel]] = {}
9798
for submodel_partition_info in self.partition_info:
9899
self._partition_id_to_models[submodel_partition_info.id] = {}
99100

100-
def update_packages(self) -> None:
101-
"""
102-
Update packages that need to be updated after partitioning.
103-
This includes for example the transport model names in the buoyancy package.
101+
def split(self, model_name: str, model: IModel) -> dict[str, IModel]:
104102
"""
105-
# Update buoyancy packages
106-
for _, models in self._partition_id_to_models.items():
107-
flow_model = next(
108-
(
109-
model
110-
for model_name, model in models.items()
111-
if isinstance(model, GroundwaterFlowModel)
112-
),
113-
None,
114-
)
115-
transport_model_names = [
116-
model_name
117-
for model_name, model in models.items()
118-
if isinstance(model, GroundwaterTransportModel)
119-
]
120-
if flow_model is None:
121-
raise ValueError(
122-
"Could not find a groundwater flow model for updating the buoyancy package."
123-
)
103+
Split a model into multiple partitioned models based on partition information.
124104
125-
flow_model._update_buoyancy_package(transport_model_names)
105+
Each partition creates a separate submodel containing:
106+
- All non-boundary packages from the original model, clipped to the partition's domain
107+
- Boundary packages that have active cells within the partition's domain, clipped accordingly
108+
- IAgnosticPackages are excluded if they contain no data after clipping
126109
127-
# Update ssm packages
128-
for _, models in self._partition_id_to_models.items():
129-
flow_model = next(
130-
(
131-
model
132-
for model_name, model in models.items()
133-
if isinstance(model, GroundwaterFlowModel)
134-
),
135-
None,
136-
)
137-
transport_models = [
138-
model
139-
for model_name, model in models.items()
140-
if isinstance(model, GroundwaterTransportModel)
141-
]
142-
143-
for transport_model in transport_models:
144-
ssm_key = transport_model._get_pkgkey("ssm")
145-
if ssm_key is None:
146-
continue
147-
old_ssm_package = transport_model.pop(ssm_key)
148-
state_variable_name = old_ssm_package.dataset[
149-
"auxiliary_variable_name"
150-
].values[0]
151-
ssm_package = SourceSinkMixing.from_flow_model(
152-
flow_model, state_variable_name, is_split=True
153-
)
154-
if ssm_package is not None:
155-
transport_model[ssm_key] = ssm_package
156-
157-
def update_solutions(
158-
self, original_model_name_to_solution: dict[str, Solution]
159-
) -> None:
160-
"""
161-
Update the solutions to refer to the new partitioned models.
162-
"""
163-
for model_name, new_models in self._model_to_partitioned_model.items():
164-
solution = original_model_name_to_solution[model_name]
165-
solution._remove_model_from_solution(model_name)
166-
for new_model_name, new_model in new_models.items():
167-
solution._add_model_to_solution(new_model_name)
168-
169-
def split(self, model_name: str, model: IModel) -> dict[str, IModel]:
170-
"""
171-
Split a model into multiple partitioned models based on configured partition information.
172110
Parameters
173111
----------
174112
model_name : str
175-
Base name of the input model; partition identifiers are appended to create
176-
names for each resulting submodel.
113+
Base name of the input model. Partition IDs are appended to create
114+
unique names for each submodel (e.g., "model_0", "model_1").
177115
model : IModel
178116
The input model instance to partition.
117+
179118
Returns
180119
-------
181120
dict[str, IModel]
182121
A mapping from generated submodel names to the corresponding partitioned
183-
model instances, each containing only the packages and data relevant to its
184-
active domain.
122+
model instances, each clipped to its respective active domain.
185123
"""
186124
modelclass = type(model)
187125
partitioned_models = {}
@@ -233,7 +171,79 @@ def split(self, model_name: str, model: IModel) -> dict[str, IModel]:
233171

234172
return partitioned_models
235173

174+
def update_packages(self) -> None:
175+
"""
176+
Update packages that reference other models after partitioning.
177+
178+
This method performs two updates:
179+
1. Updates buoyancy packages in flow models to reference the correct
180+
partitioned transport model names.
181+
2. Recreates Source Sink Mixing (SSM) packages in transport models
182+
based on the partitioned flow model data.
183+
"""
184+
# Update buoyancy packages
185+
for _, models in self._partition_id_to_models.items():
186+
flow_model = self._get_flow_model(models)
187+
transport_model_names = self._get_transport_model_names(models)
188+
189+
flow_model._update_buoyancy_package(transport_model_names)
190+
191+
# Update ssm packages
192+
for _, models in self._partition_id_to_models.items():
193+
flow_model = self._get_flow_model(models)
194+
transport_models = self._get_transport_models(models)
195+
196+
for transport_model in transport_models:
197+
ssm_key = transport_model._get_pkgkey("ssm")
198+
if ssm_key is None:
199+
continue
200+
old_ssm_package = transport_model.pop(ssm_key)
201+
state_variable_name = old_ssm_package.dataset[
202+
"auxiliary_variable_name"
203+
].values[0]
204+
ssm_package = SourceSinkMixing.from_flow_model(
205+
flow_model, state_variable_name, is_split=True
206+
)
207+
if ssm_package is not None:
208+
transport_model[ssm_key] = ssm_package
209+
210+
def update_solutions(
211+
self, original_model_name_to_solution: dict[str, Solution]
212+
) -> None:
213+
"""
214+
Update solution objects to reference partitioned models instead of original models.
215+
216+
For each original model that was split:
217+
1. Removes the original model reference from its solution (This was a deepcopy of the original solution and thus references the original model).
218+
2. Adds all partitioned submodel references to the same solution
219+
220+
This ensures that the Solution objects correctly reference the new partitioned
221+
model names after splitting.
222+
"""
223+
for model_name, new_models in self._model_to_partitioned_model.items():
224+
solution = original_model_name_to_solution[model_name]
225+
solution._remove_model_from_solution(model_name)
226+
for new_model_name, new_model in new_models.items():
227+
solution._add_model_to_solution(new_model_name)
228+
236229
def _get_package_domain(self, package: IPackage) -> GridDataArray | None:
230+
"""
231+
Extract the active domain of a boundary condition package.
232+
233+
For boundary condition packages, this method identifies which cells contain
234+
active boundary data by checking the package's defining variable (e.g.,
235+
"head" for CHD, "rate" for WEL). Non-boundary packages return None.
236+
237+
The active domain is determined by:
238+
1. Retrieving the variable that defines active cells (from _pkg_id_to_var_mapping)
239+
2. Removing non-spatial dimensions and the layer dimension
240+
3. Creating a boolean mask where non-null values indicate active cells
241+
242+
Special cases:
243+
- IAgnosticPackages: No domain check (return None)
244+
- Packages in _pkg_id_skip_active_domain_check: No domain check (return None)
245+
- Non-boundary packages: Return None
246+
"""
237247
pkg_id = package.pkg_id
238248
active_package_domain = None
239249

@@ -263,6 +273,19 @@ def _has_package_data_in_domain(
263273
active_package_domain: GridDataArray,
264274
partition_info: PartitionInfo,
265275
) -> bool:
276+
"""
277+
Check if a package has any active data within a partition's domain.
278+
279+
For boundary condition packages, this method determines whether the package
280+
should be included in a partitioned model by checking if its active cells
281+
overlap with the partition's active domain.
282+
283+
The method returns True in the following cases:
284+
- Package is not a BoundaryCondition (non-boundary packages are always included)
285+
- Package is an IAgnosticPackage (overlap check deferred until after slicing)
286+
- Package is in _pkg_id_skip_active_domain_check (e.g., SSM, LAK packages)
287+
- Package has at least one active cell overlapping with the partition domain
288+
"""
266289
pkg_id = package.pkg_id
267290
has_overlap = True
268291
if isinstance(package, BoundaryCondition):
@@ -278,3 +301,36 @@ def _has_package_data_in_domain(
278301
).any() # type: ignore
279302

280303
return has_overlap
304+
305+
def _get_flow_model(self, models: dict[str, IModel]) -> GroundwaterFlowModel:
306+
flow_model = next(
307+
(
308+
model
309+
for model_name, model in models.items()
310+
if isinstance(model, GroundwaterFlowModel)
311+
),
312+
None,
313+
)
314+
315+
if flow_model is None:
316+
raise ValueError(
317+
"Could not find a groundwater flow model for updating the buoyancy package."
318+
)
319+
320+
return flow_model
321+
322+
def _get_transport_model_names(self, models: dict[str, IModel]) -> List[str]:
323+
return [
324+
model_name
325+
for model_name, model in models.items()
326+
if isinstance(model, GroundwaterTransportModel)
327+
]
328+
329+
def _get_transport_models(
330+
self, models: dict[str, IModel]
331+
) -> List[GroundwaterTransportModel]:
332+
return [
333+
model
334+
for model_name, model in models.items()
335+
if isinstance(model, GroundwaterTransportModel)
336+
]

0 commit comments

Comments
 (0)