Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_petab_test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
# retrieve test models
- name: Download and install PEtab test suite
run: |
git clone --depth 1 --branch pysb https://github.com/FFroehlich/petab_test_suite \
git clone --depth 1 --branch develop https://github.com/PEtab-dev/petab_test_suite \
&& source ./build/venv/bin/activate \
&& cd petab_test_suite && pip3 install -e . && cd ..

Expand Down
14 changes: 9 additions & 5 deletions include/amici/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <sunmatrix/sunmatrix_sparse.h>

#include <memory>
#include <set>

namespace amici {

Expand Down Expand Up @@ -230,9 +231,11 @@ class AbstractModel {
* @param t initial time
* @param p parameter vector
* @param k constant vector
* @return set of indices of states that have been reset
*/
virtual void fx0_fixedParameters(realtype *x0, const realtype t,
const realtype *p, const realtype *k);
virtual std::set<int> fx0_fixedParameters(
realtype *x0, const realtype t,
const realtype *p, const realtype *k);

/**
* @brief Model specific implementation of fsx0_fixedParameters
Expand All @@ -242,10 +245,11 @@ class AbstractModel {
* @param p parameter vector
* @param k constant vector
* @param ip sensitivity index
* @param resettedStateIdxs set of indices of states have been reset
*/
virtual void fsx0_fixedParameters(realtype *sx0, const realtype t,
const realtype *x0, const realtype *p,
const realtype *k, int ip);
virtual void fsx0_fixedParameters(
realtype *sx0, const realtype t, const realtype *x0, const realtype *p,
const realtype *k, int ip, const std::set<int>& resettedStateIdxs);

/**
* @brief Model specific implementation of fsx0
Expand Down
8 changes: 6 additions & 2 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1231,8 +1231,9 @@ class Model : public AbstractModel {
* @brief Set only those initial states that are specified via
* fixed parameters.
* @param x Output buffer.
* @return set of indices of states that have been reset
*/
void fx0_fixedParameters(AmiVector &x);
std::set<int> fx0_fixedParameters(AmiVector &x);

/**
* @brief Compute/get initial value for initial state sensitivities.
Expand All @@ -1246,8 +1247,11 @@ class Model : public AbstractModel {
* from `amici::Model::fx0_fixedParameters`.
* @param sx Output buffer for state sensitivities
* @param x State variables
* @param resettedStateIdxs set of indices of states have been reset
*/
void fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x);
void fsx0_fixedParameters(AmiVectorArray &sx,
const AmiVector &x,
const std::set<int>& resettedStateIdxs);

/**
* @brief Compute sensitivity of derivative initial states sensitivities
Expand Down
54 changes: 34 additions & 20 deletions python/amici/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
'signature':
'(realtype *x0_fixedParameters, const realtype t, '
'const realtype *p, const realtype *k)',
'ret_type': 'std::set<int>'
},
'sx0': {
'signature':
Expand All @@ -222,7 +223,7 @@
'signature':
'(realtype *sx0_fixedParameters, const realtype t, '
'const realtype *x0, const realtype *p, const realtype *k, '
'const int ip)',
'const int ip, const std::set<int> &resettedStateIdxs)',
},
'xdot': {
'signature':
Expand Down Expand Up @@ -2605,11 +2606,12 @@ def _write_function_file(self, function: str) -> None:
'#include "sundials/sundials_types.h"',
'',
'#include <array>',
'#include <set>',
]

# function signature
signature = self.functions[function]['signature']

ret_type = self.functions[function].get('ret_type', 'void')
lines.append('')

for sym in self.model.sym_names():
Expand All @@ -2627,7 +2629,7 @@ def _write_function_file(self, function: str) -> None:
'',
])

lines.append(f'void {function}_{self.model_name}{signature}{{')
lines.append(f'{ret_type} {function}_{self.model_name}{signature}{{')

# function body
body = self._get_function_body(function, equations)
Expand Down Expand Up @@ -2721,9 +2723,10 @@ def _write_function_index(self, function: str, indextype: str) -> None:
lines.append(' ' + ', '.join(map(str, values)))
lines.append("};")

ret_type = self.functions[function].get('ret_type', 'void')
lines.extend([
'',
f'void {function}_{indextype}_{self.model_name}{signature}{{',
f'{ret_type} {function}_{indextype}_{self.model_name}{signature}{{',
])

if len(values):
Expand Down Expand Up @@ -2779,20 +2782,11 @@ def _get_function_body(self,
# was applied

lines.extend([
# Keep list of indices of fixed parameters occurring in x0
" static const std::array<int, "
+ str(len(self.model._x0_fixedParameters_idx))
+ "> _x0_fixedParameters_idxs = {",
" "
+ ', '.join(str(x)
for x in self.model._x0_fixedParameters_idx),
" };",
"",
# Set all parameters that are to be reset to 0, so that the
# switch statement below only needs to handle non-zero entries
# (which usually reduces file size and speeds up
# compilation significantly).
" for(auto idx: _x0_fixedParameters_idxs) {",
" for(auto idx: resettedStateIdxs) {",
" sx0_fixedParameters[idx] = 0.0;",
" }"])

Expand All @@ -2805,18 +2799,35 @@ def _get_function_body(self,
):
if not formula.is_zero:
expressions.append(
f'if(resettedStateIdxs.find({index}) != '
'resettedStateIdxs.end()) '
f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
cases[ipar] = expressions
lines.extend(get_switch_statement('ip', cases, 1))

elif function == 'x0_fixedParameters':
lines.append(" realtype tmp;\n"
" std::set<int> resettedStateIdxs;")
for index, formula in zip(
self.model._x0_fixedParameters_idx,
equations
):
lines.append(f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
lines.append(f' tmp = {_print_with_exception(formula)};\n'
' if(!std::isnan(tmp)) {\n'
f' {function}[{index}] = tmp;\n'
f' resettedStateIdxs.emplace({index});\n'
' }')
lines.append(" return resettedStateIdxs;")
elif function == 'x0':
lines.append("realtype tmp;")

lines.extend([
f' tmp = {_print_with_exception(math)};'
f' if(!std::isnan(tmp)) {function}[{index}] = tmp;'
for index, math in enumerate(equations)
if not (math == 0 or math == 0.0)
])

elif function in event_functions:
outer_cases = {}
Expand Down Expand Up @@ -3200,8 +3211,9 @@ def get_function_extern_declaration(fun: str, name: str) -> str:
c++ function definition string

"""
return \
f'extern void {fun}_{name}{functions[fun]["signature"]};'
signature = functions[fun]["signature"]
ret_type = functions[fun].get('ret_type', 'void')
return f'extern {ret_type} {fun}_{name}{signature};'


def get_sunindex_extern_declaration(fun: str, name: str,
Expand Down Expand Up @@ -3244,11 +3256,12 @@ def get_model_override_implementation(fun: str, name: str) -> str:

"""
return \
'virtual void f{fun}{signature} override {{\n' \
'virtual {ret_type} f{fun}{signature} override {{\n' \
'{ind8}{fun}_{name}{eval_signature};\n' \
'{ind4}}}\n'.format(
ind4=' '*4,
ind8=' '*8,
ret_type=functions[fun].get('ret_type', 'void'),
fun=fun,
name=name,
signature=functions[fun]["signature"],
Expand Down Expand Up @@ -3279,11 +3292,12 @@ def get_sunindex_override_implementation(fun: str, name: str,
index_arg_eval = ', index' if fun in multiobs_functions else ''

return \
'virtual void f{fun}_{indextype}{signature} override {{\n' \
'virtual {ret_type} f{fun}_{indextype}{signature} override {{\n' \
'{ind8}{fun}_{indextype}_{name}{eval_signature};\n' \
'{ind4}}}\n'.format(
ind4=' '*4,
ind8=' '*8,
ret_type=functions[fun].get('ret_type', 'void'),
fun=fun,
indextype=indextype,
name=name,
Expand Down
7 changes: 5 additions & 2 deletions src/abstract_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ AbstractModel::isFixedParameterStateReinitializationAllowed() const
return false;
}

void
std::set<int>
AbstractModel::fx0_fixedParameters(realtype* /*x0*/,
const realtype /*t*/,
const realtype* /*p*/,
const realtype* /*k*/)
{
// no-op default implementation
return std::set<int>();
}

void
Expand All @@ -46,7 +47,9 @@ AbstractModel::fsx0_fixedParameters(realtype* /*sx0*/,
const realtype* /*x0*/,
const realtype* /*p*/,
const realtype* /*k*/,
const int /*ip*/)
const int /*ip*/,
const std::set<int>& /*resettedStateIdxs*/
)
{
// no-op default implementation
}
Expand Down
20 changes: 12 additions & 8 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,19 +1166,21 @@ void Model::fx0(AmiVector &x) {
}
}

void Model::fx0_fixedParameters(AmiVector &x) {
std::set<int>Model::fx0_fixedParameters(AmiVector &x) {
if (!getReinitializeFixedParameterInitialStates())
return;
return std::set<int>();
/* we transform to the unreduced states x_rdata and then apply
x0_fixedparameters to (i) enable updates to states that were removed from
conservation laws and (ii) be able to correctly compute total abundances
after updating the state variables */
fx_rdata(x_rdata_.data(), x.data(), state_.total_cl.data());
fx0_fixedParameters(x_rdata_.data(), tstart_, state_.unscaledParameters.data(),
state_.fixedParameters.data());
auto resettedStateIdxs = fx0_fixedParameters(
x_rdata_.data(), tstart_, state_.unscaledParameters.data(),
state_.fixedParameters.data());
fx_solver(x.data(), x_rdata_.data());
/* update total abundances */
ftotal_cl(state_.total_cl.data(), x_rdata_.data());
return resettedStateIdxs;
}

void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) {
Expand All @@ -1195,18 +1197,20 @@ void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) {
}
}

void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) {
void Model::fsx0_fixedParameters(AmiVectorArray &sx,
const AmiVector &x,
const std::set<int>& resettedStateIdxs) {
if (!getReinitializeFixedParameterInitialStates())
return;
realtype *stcl = nullptr;
for (int ip = 0; ip < nplist(); ip++) {
if (ncl() > 0)
stcl = &state_.stotal_cl.at(plist(ip) * ncl());
fsx_rdata(sx_rdata_.data(), sx.data(ip), stcl, plist(ip));
fsx0_fixedParameters(sx_rdata_.data(), tstart_, x.data(),
state_.unscaledParameters.data(),
fsx0_fixedParameters(sx_rdata_.data(), tstart_,
x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(),
plist(ip));
plist(ip), resettedStateIdxs);
fsx_solver(sx.data(ip), sx_rdata_.data());
fstotal_cl(stcl, sx_rdata_.data(), plist(ip));
}
Expand Down
15 changes: 9 additions & 6 deletions src/model_header.ODE_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _amici_TPL_MODELNAME_h
#include <cmath>
#include <memory>
#include <set>

#include "amici/model_ode.h"
#include "amici/solver_cvodes.h"
Expand Down Expand Up @@ -76,7 +77,7 @@ extern void sigmay_TPL_MODELNAME(realtype *sigmay, const realtype t,
TPL_W_DEF
extern void x0_TPL_MODELNAME(realtype *x0, const realtype t, const realtype *p,
const realtype *k);
extern void x0_fixedParameters_TPL_MODELNAME(realtype *x0, const realtype t,
extern std::set<int> x0_fixedParameters_TPL_MODELNAME(realtype *x0, const realtype t,
const realtype *p,
const realtype *k);
extern void sx0_TPL_MODELNAME(realtype *sx0, const realtype t,
Expand All @@ -85,7 +86,8 @@ extern void sx0_TPL_MODELNAME(realtype *sx0, const realtype t,
extern void sx0_fixedParameters_TPL_MODELNAME(realtype *sx0, const realtype t,
const realtype *x0,
const realtype *p,
const realtype *k, const int ip);
const realtype *k, const int ip,
const std::set<int> &resettedParameterIdxs);
extern void xdot_TPL_MODELNAME(realtype *xdot, const realtype t,
const realtype *x, const realtype *p,
const realtype *k, const realtype *h,
Expand Down Expand Up @@ -597,8 +599,9 @@ class Model_TPL_MODELNAME : public amici::Model_ODE {
virtual void fsx0_fixedParameters(realtype *sx0, const realtype t,
const realtype *x0, const realtype *p,
const realtype *k,
const int ip) override {
sx0_fixedParameters_TPL_MODELNAME(sx0, t, x0, p, k, ip);
const int ip,
const std::set<int> &resettedStateIdxs) override {
sx0_fixedParameters_TPL_MODELNAME(sx0, t, x0, p, k, ip, resettedStateIdxs);
}

/** model specific implementation of fsz
Expand Down Expand Up @@ -636,10 +639,10 @@ class Model_TPL_MODELNAME : public amici::Model_ODE {
* @param p parameter vector
* @param k constant vector
**/
virtual void fx0_fixedParameters(realtype *x0, const realtype t,
virtual std::set<int> fx0_fixedParameters(realtype *x0, const realtype t,
const realtype *p,
const realtype *k) override {
x0_fixedParameters_TPL_MODELNAME(x0, t, p, k);
return x0_fixedParameters_TPL_MODELNAME(x0, t, p, k);
}

/** model specific implementation for fxdot
Expand Down
4 changes: 2 additions & 2 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,11 @@ void Solver::setupSteadystate(const realtype t0, Model *model, const AmiVector &
}

void Solver::updateAndReinitStatesAndSensitivities(Model *model) {
model->fx0_fixedParameters(x_);
auto resettedStateIdxs = model->fx0_fixedParameters(x_);
reInit(t_, x_, dx_);

if (getSensitivityOrder() >= SensitivityOrder::first) {
model->fsx0_fixedParameters(sx_, x_);
model->fsx0_fixedParameters(sx_, x_, resettedStateIdxs);
if (getSensitivityMethod() == SensitivityMethod::forward)
sensReInit(sx_, sdx_);
}
Expand Down
7 changes: 3 additions & 4 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from amici.petab_import import import_petab_problem, PysbPetabProblem
from amici.petab_objective import (
simulate_petab, rdatas_to_measurement_df, create_parameterized_edatas)
from amici import SteadyStateSensitivityMode_simulationFSA

logger = get_logger(__name__, logging.DEBUG)
set_log_level(get_logger("amici.petab_import"), logging.DEBUG)
Expand Down Expand Up @@ -127,12 +126,12 @@ def check_derivatives(problem: petab.Problem, model: amici.Model) -> None:
problem_parameters = {t.Index: getattr(t, petab.NOMINAL_VALUE) for t in
problem.parameter_df.itertuples()}
solver = model.getSolver()
solver.setSensitivityMethod(amici.SensitivityMethod_forward)
solver.setSensitivityOrder(amici.SensitivityOrder_first)
solver.setSensitivityMethod(amici.SensitivityMethod.forward)
solver.setSensitivityOrder(amici.SensitivityOrder.first)
# Required for case 9 to not fail in
# amici::NewtonSolver::computeNewtonSensis
model.setSteadyStateSensitivityMode(
SteadyStateSensitivityMode_simulationFSA)
amici.SteadyStateSensitivityMode.simulationFSA)

def assert_true(x):
assert x
Expand Down