33import logging
44import os
55import re
6- from typing import Iterable , Optional , Callable , Union , Any , Sequence , List
6+ from typing import (
7+ Iterable , Optional , Callable , Union , Any , Sequence , List , Dict ,
8+ )
79from warnings import warn
810
911import numpy as np
1719 'write_visualization_df' , 'get_notnull_columns' ,
1820 'flatten_timepoint_specific_output_overrides' ,
1921 'concat_tables' , 'to_float_if_float' , 'is_empty' ,
20- 'create_combine_archive' , 'unique_preserve_order' ]
22+ 'create_combine_archive' , 'unique_preserve_order' ,
23+ 'unflatten_simulation_df' ]
24+
# Measurement table columns whose values may distinguish the observables of a
# flattened PEtab problem. The order matters: `get_observable_replacement_id`
# walks this list in order when it assembles a replacement ID.
POSSIBLE_GROUPVARS_FLATTENED_PROBLEM = [
    OBSERVABLE_ID, OBSERVABLE_PARAMETERS, NOISE_PARAMETERS,
    SIMULATION_CONDITION_ID, PREEQUILIBRATION_CONDITION_ID,
]
2132
2233
2334def get_simulation_df (simulation_file : Union [str , Path ]) -> pd .DataFrame :
@@ -90,6 +101,99 @@ def get_notnull_columns(df: pd.DataFrame, candidates: Iterable):
90101 if col in df and not np .all (df [col ].isnull ())]
91102
92103
def get_observable_replacement_id(groupvars, groupvar) -> str:
    """Build the ID of the synthetic observable for one measurement group.

    The values of the grouping columns are visited in the order given by
    ``POSSIBLE_GROUPVARS_FLATTENED_PROBLEM``. Each value is converted to a
    string in which ``PARAMETER_SEPARATOR`` and ``.`` are replaced by ``_``,
    and the non-empty results are joined with ``__``.

    Arguments:
        groupvars:
            The columns of a PEtab measurement table that should be unique
            between observables in a flattened PEtab problem.
        groupvar:
            A specific grouping of `groupvars`, i.e. one value per entry of
            `groupvars`, in the same order.

    Returns:
        The observable replacement ID.
    """
    id_parts = []
    for column in POSSIBLE_GROUPVARS_FLATTENED_PROBLEM:
        if column not in groupvars:
            continue
        raw_value = groupvar[groupvars.index(column)]
        sanitized = str(raw_value).replace(PARAMETER_SEPARATOR, '_')
        sanitized = sanitized.replace('.', '_')
        id_parts.append(sanitized)
    # Empty values contribute nothing, not even a separator.
    return '__'.join(part for part in id_parts if part != '')
127+
128+
def get_hyperparameter_replacement_id(
        hyperparameter_type,
        observable_replacement_id,
):
    """Build the regex replacement template for a flattened hyperparameter.

    Arguments:
        hyperparameter_type:
            The type of hyperparameter, e.g. `noiseParameter`.
        observable_replacement_id:
            The observable replacement ID, e.g. the output of
            `get_observable_replacement_id`.

    Returns:
        The hyperparameter replacement ID. It contains the backreference
        ``\\1``, which a regex substitution replaces with the first matched
        group of the pattern.
    """
    # `\1` is kept literally so that `re.sub` can fill in the matched group.
    replacement_template = '{}\\1_{}'
    return replacement_template.format(
        hyperparameter_type,
        observable_replacement_id,
    )
147+
148+
def get_flattened_id_mappings(
    petab_problem: 'petab.problem.Problem',
) -> Dict[str, Dict[str, str]]:
    """Get mappings from flattened IDs back to the unflattened problem.

    Arguments:
        petab_problem:
            The unflattened PEtab problem.

    Returns:
        A dictionary of dictionaries with the keys ``OBSERVABLE_ID``,
        ``NOISE_PARAMETERS`` and ``OBSERVABLE_PARAMETERS``:

        * ``OBSERVABLE_ID`` maps each flattened (replacement) observable ID
          to the original observable ID.
        * ``NOISE_PARAMETERS`` and ``OBSERVABLE_PARAMETERS`` map each
          hyperparameter replacement ID (a regex replacement template, see
          `get_hyperparameter_replacement_id`) to the regex pattern that
          matches the corresponding hyperparameter IDs of the original
          observable.

    Raises:
        RuntimeError:
            If a replacement ID is already present in the index of the
            observable table.
    """
    groupvars = get_notnull_columns(petab_problem.measurement_df,
                                    POSSIBLE_GROUPVARS_FLATTENED_PROBLEM)
    mappings = {
        OBSERVABLE_ID: {},
        NOISE_PARAMETERS: {},
        OBSERVABLE_PARAMETERS: {},
    }
    hyperparameter_types = [
        (NOISE_PARAMETERS, 'noiseParameter'),
        (OBSERVABLE_PARAMETERS, 'observableParameter'),
    ]
    for groupvar, measurements in \
            petab_problem.measurement_df.groupby(groupvars, dropna=False):
        # Older pandas versions yield a scalar key instead of a 1-tuple when
        # grouping by a single column. Normalize the key, so that indexing
        # below selects a column value and not a character of a string.
        if not isinstance(groupvar, tuple):
            groupvar = (groupvar,)

        observable_id = groupvar[groupvars.index(OBSERVABLE_ID)]
        observable_replacement_id = \
            get_observable_replacement_id(groupvars, groupvar)

        logger.debug('Creating synthetic observable %s', observable_id)
        if observable_replacement_id in petab_problem.observable_df.index:
            raise RuntimeError('could not create synthetic observables '
                               f'since {observable_replacement_id} was '
                               'already present in observable table')

        mappings[OBSERVABLE_ID][observable_replacement_id] = observable_id

        for field, hyperparameter_type in hyperparameter_types:
            # Only columns that exist in the measurement table are mapped.
            if field in measurements:
                mappings[field][get_hyperparameter_replacement_id(
                    hyperparameter_type=hyperparameter_type,
                    observable_replacement_id=observable_replacement_id,
                )] = fr'{hyperparameter_type}([0-9]+)_{observable_id}'
    return mappings
195+
196+
93197def flatten_timepoint_specific_output_overrides (
94198 petab_problem : 'petab.problem.Problem' ,
95199) -> None :
@@ -109,44 +213,38 @@ def flatten_timepoint_specific_output_overrides(
109213 """
110214 new_measurement_dfs = []
111215 new_observable_dfs = []
112- possible_groupvars = [OBSERVABLE_ID , OBSERVABLE_PARAMETERS ,
113- NOISE_PARAMETERS , SIMULATION_CONDITION_ID ,
114- PREEQUILIBRATION_CONDITION_ID ]
115216 groupvars = get_notnull_columns (petab_problem .measurement_df ,
116- possible_groupvars )
217+ POSSIBLE_GROUPVARS_FLATTENED_PROBLEM )
218+
219+ mappings = get_flattened_id_mappings (petab_problem )
220+
117221 for groupvar , measurements in \
118222 petab_problem .measurement_df .groupby (groupvars , dropna = False ):
119223 obs_id = groupvar [groupvars .index (OBSERVABLE_ID )]
120- # construct replacement id
121- replacement_id = ''
122- for field in possible_groupvars :
123- if field in groupvars :
124- val = str (groupvar [groupvars .index (field )])\
125- .replace (PARAMETER_SEPARATOR , '_' ).replace ('.' , '_' )
126- if replacement_id == '' :
127- replacement_id = val
128- elif val != '' :
129- replacement_id += f'__{ val } '
130-
131- logger .debug (f'Creating synthetic observable { obs_id } ' )
132- if replacement_id in petab_problem .observable_df .index :
133- raise RuntimeError ('could not create synthetic observables '
134- f'since { replacement_id } was already '
135- 'present in observable table' )
224+ observable_replacement_id = \
225+ get_observable_replacement_id (groupvars , groupvar )
226+
136227 observable = petab_problem .observable_df .loc [obs_id ].copy ()
137- observable .name = replacement_id
138- for field , parname , target in [
228+ observable .name = observable_replacement_id
229+ for field , hyperparameter_type , target in [
139230 (NOISE_PARAMETERS , 'noiseParameter' , NOISE_FORMULA ),
140231 (OBSERVABLE_PARAMETERS , 'observableParameter' , OBSERVABLE_FORMULA )
141232 ]:
142233 if field in measurements :
234+ hyperparameter_replacement_id = \
235+ get_hyperparameter_replacement_id (
236+ hyperparameter_type = hyperparameter_type ,
237+ observable_replacement_id = observable_replacement_id ,
238+ )
239+ hyperparameter_id = \
240+ mappings [field ][hyperparameter_replacement_id ]
143241 observable [target ] = re .sub (
144- fr' { parname } ([0-9]+)_ { obs_id } ' ,
145- f' { parname } \\ 1_ { replacement_id } ' ,
146- observable [target ]
242+ hyperparameter_id ,
243+ hyperparameter_replacement_id ,
244+ observable [target ],
147245 )
148246
149- measurements [OBSERVABLE_ID ] = replacement_id
247+ measurements [OBSERVABLE_ID ] = observable_replacement_id
150248 new_measurement_dfs .append (measurements )
151249 new_observable_dfs .append (observable )
152250
@@ -155,6 +253,37 @@ def flatten_timepoint_specific_output_overrides(
155253 petab_problem .measurement_df = pd .concat (new_measurement_dfs )
156254
157255
def unflatten_simulation_df(
    simulation_df: pd.DataFrame,
    petab_problem: 'petab.problem.Problem',
) -> pd.DataFrame:
    """Unflatten simulations from a flattened PEtab problem.

    A flattened PEtab problem is the output of applying
    :func:`flatten_timepoint_specific_output_overrides` to a PEtab problem.

    Arguments:
        simulation_df:
            The simulation dataframe. A dataframe in the same format as a PEtab
            measurements table, but with the ``measurement`` column switched
            with a ``simulation`` column.
        petab_problem:
            The unflattened PEtab problem.

    Returns:
        The simulation dataframe for the unflattened PEtab problem. This is a
        new dataframe in which the flattened observable IDs are replaced by
        the original observable IDs; `simulation_df` itself is not modified.
    """
    mappings = get_flattened_id_mappings(petab_problem)
    # The observable mapping is keyed by flattened ID, so replacing with it
    # restores the original observable IDs.
    original_observable_ids = (
        simulation_df[OBSERVABLE_ID]
        .replace(mappings[OBSERVABLE_ID])
    )
    unflattened_simulation_df = simulation_df.assign(**{
        OBSERVABLE_ID: original_observable_ids,
    })
    return unflattened_simulation_df
285+
286+
158287def concat_tables (
159288 tables : Union [str , Path , pd .DataFrame ,
160289 Iterable [Union [pd .DataFrame , str , Path ]]],
0 commit comments