@@ -90,98 +90,36 @@ class ModelSplitter:
9090 def __init__ (self , partition_info : List [PartitionInfo ]) -> None :
9191 self .partition_info = partition_info
9292
93+ # Initialize mapping from original model names to partitioned models
9394 self ._model_to_partitioned_model : dict [str , dict [str , IModel ]] = {}
9495
9596 # Initialize mapping from partition IDs to models
9697 self ._partition_id_to_models : dict [int , dict [str , IModel ]] = {}
9798 for submodel_partition_info in self .partition_info :
9899 self ._partition_id_to_models [submodel_partition_info .id ] = {}
99100
100- def update_packages (self ) -> None :
101- """
102- Update packages that need to be updated after partitioning.
103- This includes for example the transport model names in the buoyancy package.
101+ def split (self , model_name : str , model : IModel ) -> dict [str , IModel ]:
104102 """
105- # Update buoyancy packages
106- for _ , models in self ._partition_id_to_models .items ():
107- flow_model = next (
108- (
109- model
110- for model_name , model in models .items ()
111- if isinstance (model , GroundwaterFlowModel )
112- ),
113- None ,
114- )
115- transport_model_names = [
116- model_name
117- for model_name , model in models .items ()
118- if isinstance (model , GroundwaterTransportModel )
119- ]
120- if flow_model is None :
121- raise ValueError (
122- "Could not find a groundwater flow model for updating the buoyancy package."
123- )
103+ Split a model into multiple partitioned models based on partition information.
124104
125- flow_model ._update_buoyancy_package (transport_model_names )
105+ Each partition creates a separate submodel containing:
106+ - All non-boundary packages from the original model, clipped to the partition's domain
107+ - Boundary packages that have active cells within the partition's domain, clipped accordingly
108+ - IAgnosticPackages are excluded if they contain no data after clipping
126109
127- # Update ssm packages
128- for _ , models in self ._partition_id_to_models .items ():
129- flow_model = next (
130- (
131- model
132- for model_name , model in models .items ()
133- if isinstance (model , GroundwaterFlowModel )
134- ),
135- None ,
136- )
137- transport_models = [
138- model
139- for model_name , model in models .items ()
140- if isinstance (model , GroundwaterTransportModel )
141- ]
142-
143- for transport_model in transport_models :
144- ssm_key = transport_model ._get_pkgkey ("ssm" )
145- if ssm_key is None :
146- continue
147- old_ssm_package = transport_model .pop (ssm_key )
148- state_variable_name = old_ssm_package .dataset [
149- "auxiliary_variable_name"
150- ].values [0 ]
151- ssm_package = SourceSinkMixing .from_flow_model (
152- flow_model , state_variable_name , is_split = True
153- )
154- if ssm_package is not None :
155- transport_model [ssm_key ] = ssm_package
156-
157- def update_solutions (
158- self , original_model_name_to_solution : dict [str , Solution ]
159- ) -> None :
160- """
161- Update the solutions to refer to the new partitioned models.
162- """
163- for model_name , new_models in self ._model_to_partitioned_model .items ():
164- solution = original_model_name_to_solution [model_name ]
165- solution ._remove_model_from_solution (model_name )
166- for new_model_name , new_model in new_models .items ():
167- solution ._add_model_to_solution (new_model_name )
168-
169- def split (self , model_name : str , model : IModel ) -> dict [str , IModel ]:
170- """
171- Split a model into multiple partitioned models based on configured partition information.
172110 Parameters
173111 ----------
174112 model_name : str
175- Base name of the input model; partition identifiers are appended to create
176- names for each resulting submodel.
113+ Base name of the input model. Partition IDs are appended to create
114+ unique names for each submodel (e.g., "model_0", "model_1") .
177115 model : IModel
178116 The input model instance to partition.
117+
179118 Returns
180119 -------
181120 dict[str, IModel]
182121 A mapping from generated submodel names to the corresponding partitioned
183- model instances, each containing only the packages and data relevant to its
184- active domain.
122+ model instances, each clipped to its respective active domain.
185123 """
186124 modelclass = type (model )
187125 partitioned_models = {}
@@ -233,7 +171,79 @@ def split(self, model_name: str, model: IModel) -> dict[str, IModel]:
233171
234172 return partitioned_models
235173
174+ def update_packages (self ) -> None :
175+ """
176+ Update packages that reference other models after partitioning.
177+
178+ This method performs two updates:
179+ 1. Updates buoyancy packages in flow models to reference the correct
180+ partitioned transport model names.
181+ 2. Recreates Source Sink Mixing (SSM) packages in transport models
182+ based on the partitioned flow model data.
183+ """
184+ # Update buoyancy packages
185+ for _ , models in self ._partition_id_to_models .items ():
186+ flow_model = self ._get_flow_model (models )
187+ transport_model_names = self ._get_transport_model_names (models )
188+
189+ flow_model ._update_buoyancy_package (transport_model_names )
190+
191+ # Update ssm packages
192+ for _ , models in self ._partition_id_to_models .items ():
193+ flow_model = self ._get_flow_model (models )
194+ transport_models = self ._get_transport_models (models )
195+
196+ for transport_model in transport_models :
197+ ssm_key = transport_model ._get_pkgkey ("ssm" )
198+ if ssm_key is None :
199+ continue
200+ old_ssm_package = transport_model .pop (ssm_key )
201+ state_variable_name = old_ssm_package .dataset [
202+ "auxiliary_variable_name"
203+ ].values [0 ]
204+ ssm_package = SourceSinkMixing .from_flow_model (
205+ flow_model , state_variable_name , is_split = True
206+ )
207+ if ssm_package is not None :
208+ transport_model [ssm_key ] = ssm_package
209+
210+ def update_solutions (
211+ self , original_model_name_to_solution : dict [str , Solution ]
212+ ) -> None :
213+ """
214+ Update solution objects to reference partitioned models instead of original models.
215+
216+ For each original model that was split:
217+ 1. Removes the original model reference from its solution (This was a deepcopy of the original solution and thus references the original model).
218+ 2. Adds all partitioned submodel references to the same solution
219+
220+ This ensures that the Solution objects correctly reference the new partitioned
221+ model names after splitting.
222+ """
223+ for model_name , new_models in self ._model_to_partitioned_model .items ():
224+ solution = original_model_name_to_solution [model_name ]
225+ solution ._remove_model_from_solution (model_name )
226+ for new_model_name , new_model in new_models .items ():
227+ solution ._add_model_to_solution (new_model_name )
228+
236229 def _get_package_domain (self , package : IPackage ) -> GridDataArray | None :
230+ """
231+ Extract the active domain of a boundary condition package.
232+
233+ For boundary condition packages, this method identifies which cells contain
234+ active boundary data by checking the package's defining variable (e.g.,
235+ "head" for CHD, "rate" for WEL). Non-boundary packages return None.
236+
237+ The active domain is determined by:
238+ 1. Retrieving the variable that defines active cells (from _pkg_id_to_var_mapping)
239+ 2. Removing non-spatial dimensions and the layer dimension
240+ 3. Creating a boolean mask where non-null values indicate active cells
241+
242+ Special cases:
243+ - IAgnosticPackages: No domain check (return None)
244+ - Packages in _pkg_id_skip_active_domain_check: No domain check (return None)
245+ - Non-boundary packages: Return None
246+ """
237247 pkg_id = package .pkg_id
238248 active_package_domain = None
239249
@@ -263,6 +273,19 @@ def _has_package_data_in_domain(
263273 active_package_domain : GridDataArray ,
264274 partition_info : PartitionInfo ,
265275 ) -> bool :
276+ """
277+ Check if a package has any active data within a partition's domain.
278+
279+ For boundary condition packages, this method determines whether the package
280+ should be included in a partitioned model by checking if its active cells
281+ overlap with the partition's active domain.
282+
283+ The method returns True in the following cases:
284+ - Package is not a BoundaryCondition (non-boundary packages are always included)
285+ - Package is an IAgnosticPackage (overlap check deferred until after slicing)
286+ - Package is in _pkg_id_skip_active_domain_check (e.g., SSM, LAK packages)
287+ - Package has at least one active cell overlapping with the partition domain
288+ """
266289 pkg_id = package .pkg_id
267290 has_overlap = True
268291 if isinstance (package , BoundaryCondition ):
@@ -278,3 +301,36 @@ def _has_package_data_in_domain(
278301 ).any () # type: ignore
279302
280303 return has_overlap
304+
305+ def _get_flow_model (self , models : dict [str , IModel ]) -> GroundwaterFlowModel :
306+ flow_model = next (
307+ (
308+ model
309+ for model_name , model in models .items ()
310+ if isinstance (model , GroundwaterFlowModel )
311+ ),
312+ None ,
313+ )
314+
315+ if flow_model is None :
316+ raise ValueError (
317+ "Could not find a groundwater flow model for updating the buoyancy package."
318+ )
319+
320+ return flow_model
321+
322+ def _get_transport_model_names (self , models : dict [str , IModel ]) -> List [str ]:
323+ return [
324+ model_name
325+ for model_name , model in models .items ()
326+ if isinstance (model , GroundwaterTransportModel )
327+ ]
328+
329+ def _get_transport_models (
330+ self , models : dict [str , IModel ]
331+ ) -> List [GroundwaterTransportModel ]:
332+ return [
333+ model
334+ for model_name , model in models .items ()
335+ if isinstance (model , GroundwaterTransportModel )
336+ ]
0 commit comments