Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/algorithms/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use super::{initialization, output::CycleLog, NonParametricAlgorithm};
/// Maximum a posteriori (MAP) estimation
///
/// Calculate the MAP estimate of the parameters of the model given the data.
#[derive(Debug, Clone)]
pub struct MAP<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
33 changes: 13 additions & 20 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::Path;

use crate::prelude::{self, settings::Settings};

use anyhow::{bail, Result};
use anyhow::Result;
use anyhow::{Context, Error};
use map::MAP;
use ndarray::Array2;
Expand All @@ -23,26 +23,14 @@ pub mod routines;
/// Supported algorithms by `PMcore`
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum Algorithm {
NonParametric(NonParametric),
Parametric(Parametric),
}

/// Supported non-parametric algorithms
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum NonParametric {
// Non-parametric algorithms
NPAG,
NPOD,
MAP,
// Parametric algorithms
}

/// Supported parametric algorithms
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum Parametric {
FOCE,
NPSA,
}

/// This traint defines the methods for non-parametric (NP) algorithms
/// This trait defines the methods for non-parametric (NP) algorithms
pub trait NonParametricAlgorithm<E: Equation> {
fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>, Error>
where
Expand Down Expand Up @@ -122,15 +110,20 @@ pub trait NonParametricAlgorithm<E: Equation> {
fn into_npresult(&self) -> NPResult<E>;
}

pub trait ParametricAlgorithm<E: Equation> {
fn fit(&mut self) -> Result<()> {
unimplemented!()
}
}

pub fn dispatch_algorithm<E: Equation>(
settings: Settings,
equation: E,
data: Data,
) -> Result<Box<dyn NonParametricAlgorithm<E>>> {
match settings.config().algorithm {
Algorithm::NonParametric(NonParametric::NPAG) => Ok(NPAG::new(settings, equation, data)?),
Algorithm::NonParametric(NonParametric::NPOD) => Ok(NPOD::new(settings, equation, data)?),
Algorithm::NonParametric(NonParametric::MAP) => Ok(MAP::new(settings, equation, data)?),
_ => bail!("Unsupported algorithm"),
Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
Algorithm::MAP => Ok(MAP::new(settings, equation, data)?),
}
}
2 changes: 1 addition & 1 deletion src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const THETA_G: f64 = 1e-4; // Objective function convergence criteria
const THETA_F: f64 = 1e-2;
const THETA_D: f64 = 1e-4;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NPAG<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
1 change: 1 addition & 0 deletions src/algorithms/npod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use super::{
const THETA_F: f64 = 1e-2;
const THETA_D: f64 = 1e-4;

#[derive(Debug, Clone)]
pub struct NPOD<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
11 changes: 4 additions & 7 deletions src/algorithms/routines/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl Default for Config {
fn default() -> Self {
Config {
cycles: 100,
algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG),
algorithm: Algorithm::NPAG,
cache: true,
}
}
Expand Down Expand Up @@ -793,7 +793,7 @@ impl SettingsBuilder<ErrorSet> {

mod tests {
use super::*;
use crate::algorithms::{Algorithm, NonParametric};
use crate::algorithms::Algorithm;
use pharmsol::prelude::data::ErrorType;

#[test]
Expand All @@ -805,7 +805,7 @@ mod tests {
.unwrap();

let settings = SettingsBuilder::new()
.set_algorithm(Algorithm::NonParametric(NonParametric::NPAG)) // Step 1: Define algorithm
.set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm
.set_parameters(parameters) // Step 2: Define parameters
.set_error_model(Error {
value: 0.1,
Expand All @@ -814,9 +814,6 @@ mod tests {
}) // Step 3: Define error model
.build(); // Final step

assert_eq!(
settings.config.algorithm,
Algorithm::NonParametric(NonParametric::NPAG,)
);
assert_eq!(settings.config.algorithm, Algorithm::NPAG);
}
}
Loading