Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_petab_test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
# retrieve test models
- name: Download and install PEtab test suite
run: |
git clone --depth 1 --branch pysb https://github.com/FFroehlich/petab_test_suite \
git clone --depth 1 --branch develop https://github.com/PEtab-dev/petab_test_suite \
&& source ./build/venv/bin/activate \
&& cd petab_test_suite && pip3 install -e . && cd ..

Expand Down
10 changes: 8 additions & 2 deletions include/amici/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,12 @@ class AbstractModel {
* @param t initial time
* @param p parameter vector
* @param k constant vector
* @param reinitialization_state_idxs Indices of states to be reinitialized
* based on provided constants / fixed parameters.
*/
virtual void fx0_fixedParameters(realtype *x0, const realtype t,
const realtype *p, const realtype *k);
const realtype *p, const realtype *k,
gsl::span<const int> reinitialization_state_idxs);

/**
* @brief Model specific implementation of fsx0_fixedParameters
Expand All @@ -242,10 +245,13 @@ class AbstractModel {
* @param p parameter vector
* @param k constant vector
* @param ip sensitivity index
* @param reinitialization_state_idxs Indices of states to be reinitialized
* based on provided constants / fixed parameters.
*/
virtual void fsx0_fixedParameters(realtype *sx0, const realtype t,
const realtype *x0, const realtype *p,
const realtype *k, int ip);
const realtype *k, int ip,
gsl::span<const int> reinitialization_state_idxs);

/**
* @brief Model specific implementation of fsx0
Expand Down
1 change: 1 addition & 0 deletions include/amici/edata.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ class ConditionContext : public ContextManager {
std::vector<int> original_parameter_list_;
std::vector<amici::ParameterScaling> original_scaling_;
bool original_reinitialize_fixed_parameter_initial_states_;
std::vector<int> original_reinitialization_state_idxs;
};

} // namespace amici
Expand Down
12 changes: 12 additions & 0 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,19 @@ class Model : public AbstractModel, public ModelDimensions {
*/
void fsx_rdata(AmiVectorArray &sx_rdata, const AmiVectorArray &sx_solver);

/**
* @brief Set indices of states to be reinitialized based on provided
* constants / fixed parameters
* @param idxs Array of state indices
*/
void setReinitializationStateIdxs(const std::vector<int> &idxs);

/**
* @brief Return indices of states to be reinitialized based on provided
* constants / fixed parameters
* @return Those indices.
*/
std::vector<int> const& getReinitializationStateIdxs() const;

/** Flag indicating Matlab- or Python-based model generation */
bool pythonGenerated;
Expand Down
49 changes: 49 additions & 0 deletions include/amici/simulation_parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,43 @@ class SimulationParameters {
{
}

/**
* @brief Set reinitialization of all states based on model constants for
* presimulation (only meaningful if preequilibration is performed).
*
* Convenience function to populate
* `reinitialization_state_idxs_presim` and
* `reinitialization_state_idxs_sim`
*
* @param nx_rdata Number of states (Model::nx_rdata)
*/
void reinitializeAllFixedParameterDependentInitialStatesForPresimulation(int nx_rdata);

/**
* @brief Set reinitialization of all states based on model constants for
* the 'main' simulation (only meaningful if presimulation or
* preequilibration is performed).
*
* Convenience function to populate
* `reinitialization_state_idxs_presim` and
* `reinitialization_state_idxs_sim`
*
* @param nx_rdata Number of states (Model::nx_rdata)
*/
void reinitializeAllFixedParameterDependentInitialStatesForSimulation(int nx_rdata);

/**
* @brief Set reinitialization of all states based on model constants for
* all simulation phases.
*
* Convenience function to populate
* `reinitialization_state_idxs_presim` and
* `reinitialization_state_idxs_sim`
*
* @param nx_rdata Number of states (Model::nx_rdata)
*/
void reinitializeAllFixedParameterDependentInitialStates(int nx_rdata);

/**
* @brief Model constants
*
Expand Down Expand Up @@ -155,6 +192,18 @@ class SimulationParameters {
* fixed parameters is activated
*/
bool reinitializeFixedParameterInitialStates {false};

/**
* @brief Indices of states to be reinitialized based on provided
* presimulation constants / fixed parameters.
*/
std::vector<int> reinitialization_state_idxs_presim;

/**
* @brief Indices of states to be reinitialized based on provided
* constants / fixed parameters.
*/
std::vector<int> reinitialization_state_idxs_sim;
};

bool operator==(const SimulationParameters &a, const SimulationParameters &b);
Expand Down
36 changes: 25 additions & 11 deletions python/amici/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@
'x0_fixedParameters': {
'signature':
'(realtype *x0_fixedParameters, const realtype t, '
'const realtype *p, const realtype *k)',
'const realtype *p, const realtype *k, '
'gsl::span<const int> reinitialization_state_idxs)',
},
'sx0': {
'signature':
Expand All @@ -222,7 +223,7 @@
'signature':
'(realtype *sx0_fixedParameters, const realtype t, '
'const realtype *x0, const realtype *p, const realtype *k, '
'const int ip)',
'const int ip, gsl::span<const int> reinitialization_state_idxs)',
},
'xdot': {
'signature':
Expand Down Expand Up @@ -2627,6 +2628,7 @@ def _write_function_file(self, function: str) -> None:
'#include "amici/defines.h"',
'#include "sundials/sundials_types.h"',
'',
'#include <gsl/gsl-lite.hpp>',
'#include <array>',
]

Expand Down Expand Up @@ -2667,8 +2669,8 @@ def _write_function_file(self, function: str) -> None:
lines.extend([
'}',
'',
'} // namespace amici',
f'}} // namespace model_{self.model_name}',
'} // namespace amici\n',
])

# check custom functions
Expand Down Expand Up @@ -2758,8 +2760,8 @@ def _write_function_index(self, function: str, indextype: str) -> None:
lines.extend([
'}'
'',
'} // namespace amici',
f'}} // namespace model_{self.model_name}',
'} // namespace amici\n',
])

filename = f'{self.model_name}_{function}_{indextype}.cpp'
Expand Down Expand Up @@ -2815,8 +2817,11 @@ def _get_function_body(self,
# switch statement below only needs to handle non-zero entries
# (which usually reduces file size and speeds up
# compilation significantly).
" for(auto idx: _x0_fixedParameters_idxs) {",
" sx0_fixedParameters[idx] = 0.0;",
" for(auto idx: reinitialization_state_idxs) {",
" if(std::find(_x0_fixedParameters_idxs.cbegin(), "
"_x0_fixedParameters_idxs.cend(), idx) != "
"_x0_fixedParameters_idxs.cend())\n"
" sx0_fixedParameters[idx] = 0.0;",
" }"])

cases = dict()
Expand All @@ -2827,9 +2832,14 @@ def _get_function_body(self,
equations[:, ipar]
):
if not formula.is_zero:
expressions.append(
f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
expressions.extend([
f'if(std::find('
'reinitialization_state_idxs.cbegin(), '
f'reinitialization_state_idxs.cend(), {index}) != '
'reinitialization_state_idxs.cend())',
f' {function}[{index}] = '
f'{_print_with_exception(formula)};'
])
cases[ipar] = expressions
lines.extend(get_switch_statement('ip', cases, 1))

Expand All @@ -2838,8 +2848,12 @@ def _get_function_body(self,
self.model._x0_fixedParameters_idx,
equations
):
lines.append(f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
lines.append(
f' if(std::find(reinitialization_state_idxs.cbegin(), '
f'reinitialization_state_idxs.cend(), {index}) != '
'reinitialization_state_idxs.cend())\n '
f'{function}[{index}] = '
f'{_print_with_exception(formula)};')

elif function in event_functions:
outer_cases = {}
Expand Down
7 changes: 5 additions & 2 deletions python/amici/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,11 @@ def import_model_sbml(
indicator.setName(PREEQ_INDICATOR_ID)
# Can only reset parameters after preequilibration if they are fixed.
fixed_parameters.append(PREEQ_INDICATOR_ID)

for assignee_id in initial_sizes + initial_states:
logger.debug("Adding preequilibration indicator "
f"constant {PREEQ_INDICATOR_ID}")
logger.debug("Adding initial assignments for "
f"{initial_sizes + initial_states}")
for assignee_id in chain(initial_sizes, initial_states):
init_par_id_preeq = f"initial_{assignee_id}_preeq"
init_par_id_sim = f"initial_{assignee_id}_sim"
for init_par_id in [init_par_id_preeq, init_par_id_sim]:
Expand Down
16 changes: 13 additions & 3 deletions python/amici/petab_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,22 @@ def create_edata_for_condition(

##########################################################################
# enable initial parameters reinitialization
species_in_condition_table = any(
species_in_condition_table = [
col for col in petab_problem.condition_df
if petab_problem.sbml_model.getSpecies(col) is not None)
if not pd.isna(petab_problem.condition_df.loc[
condition[SIMULATION_CONDITION_ID], col])
and petab_problem.sbml_model.getSpecies(col) is not None
]
if condition.get(PREEQUILIBRATION_CONDITION_ID) \
and species_in_condition_table:
edata.reinitializeFixedParameterInitialStates = True
state_ids = amici_model.getStateIds()
state_idx_reinitalization = [state_ids.index(s)
for s in species_in_condition_table]
edata.reinitialization_state_idxs_sim = state_idx_reinitalization
logger.debug("Enabling state reinitialization for condition "
f"{condition.get(PREEQUILIBRATION_CONDITION_ID, '')} - "
f"{condition.get(SIMULATION_CONDITION_ID)} "
f"{species_in_condition_table}")

##########################################################################
# timepoints
Expand Down
6 changes: 4 additions & 2 deletions src/abstract_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void
AbstractModel::fx0_fixedParameters(realtype* /*x0*/,
const realtype /*t*/,
const realtype* /*p*/,
const realtype* /*k*/)
const realtype* /*k*/,
gsl::span<const int> /*reinitialization_state_idxs*/)
{
// no-op default implementation
}
Expand All @@ -46,7 +47,8 @@ AbstractModel::fsx0_fixedParameters(realtype* /*sx0*/,
const realtype* /*x0*/,
const realtype* /*p*/,
const realtype* /*k*/,
const int /*ip*/)
const int /*ip*/,
gsl::span<const int> /*reinitialization_state_idxs*/)
{
// no-op default implementation
}
Expand Down
23 changes: 17 additions & 6 deletions src/edata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ ExpData::ExpData(Model const &model)
: ExpData(model.nytrue, model.nztrue, model.nMaxEvent(),
model.getTimepoints(), model.getFixedParameters()) {
reinitializeFixedParameterInitialStates =
model.getReinitializeFixedParameterInitialStates();
model.getReinitializeFixedParameterInitialStates()
&& model.getReinitializationStateIdxs().empty();
reinitialization_state_idxs_sim = model.getReinitializationStateIdxs();
}

ExpData::ExpData(ReturnData const& rdata, realtype sigma_y, realtype sigma_z)
Expand Down Expand Up @@ -336,7 +338,10 @@ ConditionContext::ConditionContext(Model *model, const ExpData *edata,
original_parameter_list_(model->getParameterList()),
original_scaling_(model->getParameterScale()),
original_reinitialize_fixed_parameter_initial_states_(
model->getReinitializeFixedParameterInitialStates())
model->getReinitializeFixedParameterInitialStates()
&& model->getReinitializationStateIdxs().empty()),
original_reinitialization_state_idxs(
model->getReinitializationStateIdxs())
{
if(model->hasCustomInitialStates())
original_x0_ = model->getInitialStates();
Expand Down Expand Up @@ -370,7 +375,6 @@ void ConditionContext::applyCondition(const ExpData *edata,
" match ExpData (%zd).",
model_->np(), edata->pscale.size());
model_->setParameterScale(edata->pscale);

}

if(!edata->x0.empty()) {
Expand Down Expand Up @@ -398,6 +402,9 @@ void ConditionContext::applyCondition(const ExpData *edata,
model_->setParameters(edata->parameters);
}

model_->setReinitializeFixedParameterInitialStates(
edata->reinitializeFixedParameterInitialStates);

switch (fpc) {
case FixedParameterContext::simulation:
if (!edata->fixedParameters.empty()) {
Expand All @@ -409,6 +416,9 @@ void ConditionContext::applyCondition(const ExpData *edata,
"not match ExpData (%zd).",
model_->nk(), edata->fixedParameters.size());
model_->setFixedParameters(edata->fixedParameters);
if(!edata->reinitializeFixedParameterInitialStates)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused, don't we want to check for edata->reinitializeFixedParameterInitialStates here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is, that if edata->reinitializeFixedParameterInitialStates == true, then this has been set on Model a couple of lines before. This will populate the index array in Model. If we would execute the next line, this would overwrite them with an empty array.
Not great, but keeps backward compatibility for the moment and still allow selective reinitialization. To be cleaned up soon.

model_->setReinitializationStateIdxs(
edata->reinitialization_state_idxs_sim);
}
break;
case FixedParameterContext::preequilibration:
Expand All @@ -435,6 +445,9 @@ void ConditionContext::applyCondition(const ExpData *edata,
model_->nk(),
edata->fixedParametersPresimulation.size());
model_->setFixedParameters(edata->fixedParametersPresimulation);
if(!edata->reinitializeFixedParameterInitialStates)
model_->setReinitializationStateIdxs(
edata->reinitialization_state_idxs_presim);
}
break;
}
Expand All @@ -443,9 +456,6 @@ void ConditionContext::applyCondition(const ExpData *edata,
// fixed parameter in model are superseded by those provided in edata
model_->setTimepoints(edata->getTimepoints());
}

model_->setReinitializeFixedParameterInitialStates(
edata->reinitializeFixedParameterInitialStates);
}

void ConditionContext::restore()
Expand All @@ -466,6 +476,7 @@ void ConditionContext::restore()
model_->setTimepoints(original_timepoints_);
model_->setReinitializeFixedParameterInitialStates(
original_reinitialize_fixed_parameter_initial_states_);
model_->setReinitializationStateIdxs(original_reinitialization_state_idxs);

}

Expand Down
Loading