Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 169 additions & 78 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,171 @@ void serialize(Archive &ar, amici::Model &m, unsigned int version);
} // namespace boost

namespace amici {
/**
* @brief Container for model dimensions.
*
* Holds number of states, observables, etc.
*/
struct ModelDimensions {
/** Default ctor */
ModelDimensions() = default;

/**
* @brief Constructor with model dimensions
* @param nx_rdata Number of state variables
* @param nxtrue_rdata Number of state variables of the non-augmented model
* @param nx_solver Number of state variables with conservation laws applied
* @param nxtrue_solver Number of state variables of the non-augmented model
* with conservation laws applied
* @param nx_solver_reinit Number of state variables with conservation laws
* subject to reinitialization
* @param np Number of parameters
* @param nk Number of constants
* @param ny Number of observables
* @param nytrue Number of observables of the non-augmented model
* @param nz Number of event observables
* @param nztrue Number of event observables of the non-augmented model
* @param ne Number of events
* @param nJ Number of objective functions
* @param nw Number of repeating elements
* @param ndwdx Number of nonzero elements in the `x` derivative of the
* repeating elements
* @param ndwdp Number of nonzero elements in the `p` derivative of the
* repeating elements
* @param ndwdw Number of nonzero elements in the `w` derivative of the
* repeating elements
* @param ndxdotdw Number of nonzero elements in the \f$w\f$ derivative of
* \f$xdot\f$
* @param ndJydy Number of nonzero elements in the \f$y\f$ derivative of
* \f$dJy\f$ (dimension `nytrue`)
* @param nnz Number of nonzero elements in Jacobian
* @param ubw Upper matrix bandwidth in the Jacobian
* @param lbw Lower matrix bandwidth in the Jacobian
*/
ModelDimensions(
const int nx_rdata, const int nxtrue_rdata, const int nx_solver,
const int nxtrue_solver, const int nx_solver_reinit, const int np,
const int nk, const int ny,
const int nytrue, const int nz, const int nztrue, const int ne,
const int nJ, const int nw, const int ndwdx, const int ndwdp,
const int ndwdw, const int ndxdotdw, std::vector<int> ndJydy,
const int nnz, const int ubw, const int lbw)
: nx_rdata(nx_rdata), nxtrue_rdata(nxtrue_rdata), nx_solver(nx_solver),
nxtrue_solver(nxtrue_solver), nx_solver_reinit(nx_solver_reinit),
np(np), nk(nk),
ny(ny), nytrue(nytrue), nz(nz), nztrue(nztrue),
ne(ne), nw(nw), ndwdx(ndwdx), ndwdp(ndwdp), ndwdw(ndwdw),
ndxdotdw(ndxdotdw), ndJydy(std::move(ndJydy)),
nnz(nnz), nJ(nJ), ubw(ubw), lbw(lbw) {
Expects(nxtrue_rdata >= 0);
Expects(nxtrue_rdata <= nx_rdata);
Expects(nxtrue_solver >= 0);
Expects(nx_solver <= nx_rdata);
Expects(nxtrue_solver <= nx_solver);
Expects(nx_solver_reinit >= 0);
Expects(nx_solver_reinit <= nx_solver);
Expects(np >= 0);
Expects(nk >= 0);
Expects(nytrue <= ny);
Expects(nytrue >= 0);
Expects(nztrue >= 0);
Expects(nztrue <= nz);
Expects(ne >= 0);
Expects(nw >= 0);
Expects(ndwdx >= 0);
Expects(ndwdx <= nw * nx_solver);
Expects(ndwdp >= 0);
Expects(ndwdp <= nw * np);
Expects(ndwdw >= 0);
Expects(ndwdw <= nw * nw);
Expects(ndxdotdw >= 0);
Expects(nnz >= 0);
Expects(nJ >= 0);
Expects(ubw >= 0);
Expects(lbw >= 0);
}

/** Number of states */
int nx_rdata{0};

/** Number of states in the unaugmented system */
int nxtrue_rdata{0};

/** Number of states with conservation laws applied */
int nx_solver{0};

/**
* Number of states in the unaugmented system with conservation laws
* applied
*/
int nxtrue_solver{0};

/** Number of solver states subject to reinitialization */
int nx_solver_reinit{0};

/** Number of parameters */
int np{0};

/** Number of constants */
int nk{0};

/** Number of observables */
int ny{0};

/** Number of observables in the unaugmented system */
int nytrue{0};

/** Number of event outputs */
int nz{0};

/** Number of event outputs in the unaugmented system */
int nztrue{0};

/** Number of events */
int ne{0};

/** Number of common expressions */
int nw{0};

/**
* Number of nonzero elements in the `x` derivative of the
* repeating elements
*/
int ndwdx {0};

/**
* Number of nonzero elements in the `p` derivative of the
* repeating elements
*/
int ndwdp {0};

/**
* Number of nonzero elements in the `w` derivative of the
* repeating elements
*/
int ndwdw {0};

/** Number of nonzero elements in the \f$w\f$ derivative of \f$xdot\f$ */
int ndxdotdw {0};

/**
* Number of nonzero elements in the \f$y\f$ derivative of
* \f$dJy\f$ (dimension `nytrue`)
*/
std::vector<int> ndJydy;

/** Number of nonzero entries in Jacobian */
int nnz{0};

/** Dimension of the augmented objective function for 2nd order ASA */
int nJ{0};

/** Upper bandwidth of the Jacobian */
int ubw{0};

/** Lower bandwidth of the Jacobian */
int lbw{0};
};

/**
* @brief Exchange format to store and transfer the state of the
Expand Down Expand Up @@ -74,40 +239,14 @@ struct ModelState {
* The model can compute various model related quantities based on symbolically
* generated code.
*/
class Model : public AbstractModel {
class Model : public AbstractModel, public ModelDimensions {
public:
/** Default constructor */
Model() = default;

/**
* @brief Constructor with model dimensions.
* @param nx_rdata Number of state variables
* @param nxtrue_rdata Number of state variables of the non-augmented model
* @param nx_solver Number of state variables with conservation laws applied
* @param nxtrue_solver Number of state variables of the non-augmented model
* with conservation laws applied
* @param nx_solver_reinit Number of state variables with conservation laws
* subject to reinitialization
* @param ny Number of observables
* @param nytrue Number of observables of the non-augmented model
* @param nz Number of event observables
* @param nztrue Number of event observables of the non-augmented model
* @param ne Number of events
* @param nJ Number of objective functions
* @param nw Number of repeating elements
* @param ndwdx Number of nonzero elements in the `x` derivative of the
* repeating elements
* @param ndwdp Number of nonzero elements in the `p` derivative of the
* repeating elements
* @param ndwdw Number of nonzero elements in the `w` derivative of the
* repeating elements
* @param ndxdotdw Number of nonzero elements in the \f$w\f$ derivative of
* \f$xdot\f$
* @param ndJydy Number of nonzero elements in the \f$y\f$ derivative of
* \f$dJy\f$ (dimension `nytrue`)
* @param nnz Number of nonzero elements in Jacobian
* @param ubw Upper matrix bandwidth in the Jacobian
* @param lbw Lower matrix bandwidth in the Jacobian
* @param model_dimensions Model dimensions
* @param o2mode Second order sensitivity mode
* @param p Parameters
* @param k Constants
Expand All @@ -119,10 +258,7 @@ class Model : public AbstractModel {
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
*/
Model(int nx_rdata, int nxtrue_rdata, int nx_solver, int nxtrue_solver,
int nx_solver_reinit, int ny, int nytrue, int nz, int nztrue, int ne,
int nJ, int nw, int ndwdx, int ndwdp, int ndwdw, int ndxdotdw,
std::vector<int> ndJydy, int nnz, int ubw, int lbw,
Model(ModelDimensions const& model_dimensions,
amici::SecondOrderMode o2mode,
const std::vector<amici::realtype> &p, std::vector<amici::realtype> k,
const std::vector<int> &plist, std::vector<amici::realtype> idlist,
Expand Down Expand Up @@ -1275,53 +1411,7 @@ class Model : public AbstractModel {
*/
void fsx_rdata(AmiVectorArray &sx_rdata, const AmiVectorArray &sx_solver);

/** Number of states */
int nx_rdata{0};

/** Number of states in the unaugmented system */
int nxtrue_rdata{0};

/** Number of states with conservation laws applied */
int nx_solver{0};

/**
* Number of states in the unaugmented system with conservation laws
* applied
*/
int nxtrue_solver{0};

/** Number of solver states subject to reinitialization */
int nx_solver_reinit{0};

/** Number of observables */
int ny{0};

/** Number of observables in the unaugmented system */
int nytrue{0};

/** Number of event outputs */
int nz{0};

/** Number of event outputs in the unaugmented system */
int nztrue{0};

/** Number of events */
int ne{0};

/** Number of common expressions */
int nw{0};

/** Number of nonzero entries in Jacobian */
int nnz{0};

/** Dimension of the augmented objective function for 2nd order ASA */
int nJ{0};

/** Upper bandwidth of the Jacobian */
int ubw{0};

/** Lower bandwidth of the Jacobian */
int lbw{0};

/** Flag indicating Matlab- or Python-based model generation */
bool pythonGenerated;
Expand Down Expand Up @@ -2041,6 +2131,7 @@ class Model : public AbstractModel {
};

bool operator==(const Model &a, const Model &b);
bool operator==(const ModelDimensions &a, const ModelDimensions &b);

} // namespace amici

Expand Down
37 changes: 3 additions & 34 deletions include/amici/model_dae.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,7 @@ class Model_DAE : public Model {

/**
* @brief Constructor with model dimensions
* @param nx_rdata number of state variables
* @param nxtrue_rdata number of state variables of the non-augmented model
* @param nx_solver number of state variables with conservation laws applied
* @param nxtrue_solver number of state variables of the non-augmented model
with conservation laws applied
* @param nx_solver_reinit number of state variables with conservation laws
* subject to reinitialization
* @param ny number of observables
* @param nytrue number of observables of the non-augmented model
* @param nz number of event observables
* @param nztrue number of event observables of the non-augmented model
* @param ne number of events
* @param nJ number of objective functions
* @param nw number of repeating elements
* @param ndwdx number of nonzero elements in the x derivative of the
* repeating elements
* @param ndwdp number of nonzero elements in the p derivative of the
* repeating elements
* @param ndwdw number of nonzero elements in the w derivative of the
* repeating elements
* @param ndxdotdw number of nonzero elements dxdotdw
* @param ndJydy number of nonzero elements dJydy
* @param nnz number of nonzero elements in Jacobian
* @param ubw upper matrix bandwidth in the Jacobian
* @param lbw lower matrix bandwidth in the Jacobian
* @param model_dimensions Model dimensions
* @param o2mode second order sensitivity mode
* @param p parameters
* @param k constants
Expand All @@ -65,20 +41,13 @@ class Model_DAE : public Model {
* @param pythonGenerated flag indicating matlab or python wrapping
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
*/
Model_DAE(const int nx_rdata, const int nxtrue_rdata, const int nx_solver,
const int nxtrue_solver, const int nx_solver_reinit, const int ny, const int nytrue,
const int nz, const int nztrue, const int ne, const int nJ,
const int nw, const int ndwdx, const int ndwdp, const int ndwdw,
const int ndxdotdw, std::vector<int> ndJydy, const int nnz,
const int ubw, const int lbw, const SecondOrderMode o2mode,
Model_DAE(const ModelDimensions &model_dimensions, const SecondOrderMode o2mode,
std::vector<realtype> const &p, std::vector<realtype> const &k,
std::vector<int> const &plist,
std::vector<realtype> const &idlist,
std::vector<int> const &z2event, const bool pythonGenerated=false,
const int ndxdotdp_explicit=0)
: Model(nx_rdata, nxtrue_rdata, nx_solver, nxtrue_solver,
nx_solver_reinit, ny, nytrue, nz, nztrue, ne, nJ, nw, ndwdx,
ndwdp, ndwdw, ndxdotdw, std::move(ndJydy), nnz, ubw, lbw,
: Model(model_dimensions,
o2mode, p, k, plist, idlist, z2event, pythonGenerated,
ndxdotdp_explicit) {
M_ = SUNMatrixWrapper(nx_solver, nx_solver);
Expand Down
Loading