Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
// Count problematic values in psi
let mut nan_count = 0;
let mut inf_count = 0;
let is_log_space = self.psi().is_log_space();
let is_log_space = match self.psi().space() {
crate::structs::psi::Space::Linear => false,
crate::structs::psi::Space::Log => true,
};

let psi = self.psi().matrix().as_ref().into_ndarray();
// First coerce all NaN and infinite in psi to 0.0 (or NEG_INFINITY for log-space)
Expand Down
20 changes: 8 additions & 12 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::routines::math::logsumexp;
use crate::routines::settings::Settings;

use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
use crate::structs::psi::{calculate_psi_dispatch, Psi};
use crate::structs::psi::{calculate_psi, Psi};
use crate::structs::theta::Theta;
use crate::structs::weights::Weights;

Expand Down Expand Up @@ -162,7 +162,7 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
self.eps /= 2.;
if self.eps <= THETA_E {
// Compute f1 = sum(log(pyl)) where pyl = psi * w
self.f1 = if self.psi.is_log_space() {
self.f1 = if self.psi.space() == crate::structs::psi::Space::Log {
// For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())
Expand Down Expand Up @@ -214,16 +214,14 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
}

fn estimation(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.psi = calculate_psi_dispatch(
self.psi = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&self.error_models,
self.cycle == 1 && self.settings.config().progress,
self.cycle != 1,
use_log_space,
self.settings.advanced().space,
)?;

if let Err(err) = self.validate_psi() {
Expand Down Expand Up @@ -296,8 +294,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
}

fn optimizations(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.error_models
.clone()
.iter_mut()
Expand All @@ -318,24 +314,24 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
let mut error_model_down = self.error_models.clone();
error_model_down.set_factor(outeq, gamma_down)?;

let psi_up = calculate_psi_dispatch(
let psi_up = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&error_model_up,
false,
true,
use_log_space,
self.settings.advanced().space,
)?;

let psi_down = calculate_psi_dispatch(
let psi_down = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&error_model_down,
false,
true,
use_log_space,
self.settings.advanced().space,
)?;

let (lambda_up, objf_up) = burke_ipm(&psi_up)
Expand Down
47 changes: 20 additions & 27 deletions src/algorithms/npod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::algorithms::StopReason;
use crate::routines::initialization::sample_space;
use crate::routines::math::logsumexp;
use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
use crate::structs::psi::calculate_psi;
use crate::structs::weights::Weights;
use crate::{
algorithms::Status,
Expand All @@ -12,10 +13,7 @@ use crate::{
settings::Settings,
},
},
structs::{
psi::{calculate_psi_dispatch, Psi},
theta::Theta,
},
structs::{psi::Psi, theta::Theta},
};
use pharmsol::SppOptimizer;

Expand Down Expand Up @@ -205,16 +203,14 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
}

fn estimation(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.psi = calculate_psi_dispatch(
self.psi = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&self.error_models,
self.cycle == 1 && self.settings.config().progress,
self.cycle != 1,
use_log_space,
self.settings.advanced().space,
)?;

if let Err(err) = self.validate_psi() {
Expand Down Expand Up @@ -282,8 +278,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
}

fn optimizations(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.error_models
.clone()
.iter_mut()
Expand All @@ -306,23 +300,23 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
let mut error_model_down = self.error_models.clone();
error_model_down.set_factor(outeq, gamma_down)?;

let psi_up = calculate_psi_dispatch(
let psi_up = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&error_model_up,
false,
true,
use_log_space,
self.settings.advanced().space,
)?;
let psi_down = calculate_psi_dispatch(
let psi_down = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&error_model_down,
false,
true,
use_log_space,
self.settings.advanced().space,
)?;

let (lambda_up, objf_up) = burke_ipm(&psi_up)
Expand Down Expand Up @@ -365,20 +359,19 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {

// Compute pyl = P(Y|L) for each subject
// In log-space, we need to use logsumexp and then exp to get regular pyl
let pyl = if self.psi.is_log_space() {
// pyl[i] = sum_j(exp(log_psi[i,j]) * w[j]) = sum_j(exp(log_psi[i,j] + log(w[j])))
// Using logsumexp for stability, then exp to get regular values
let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
let mut pyl = Array1::zeros(psi_mat.nrows());
for i in 0..psi_mat.nrows() {
let combined: Vec<f64> = (0..psi_mat.ncols())
.map(|j| psi_mat[[i, j]] + log_w[j])
.collect();
pyl[i] = logsumexp(&combined).exp();
let pyl = match self.settings.advanced().space {
crate::structs::psi::Space::Log => {
let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
let mut pyl = Array1::zeros(psi_mat.nrows());
for i in 0..psi_mat.nrows() {
let combined: Vec<f64> = (0..psi_mat.ncols())
.map(|j| psi_mat[[i, j]] + log_w[j])
.collect();
pyl[i] = logsumexp(&combined).exp();
}
pyl
}
pyl
} else {
psi_mat.dot(&w)
crate::structs::psi::Space::Linear => psi_mat.dot(&w),
};

// Add new point to theta based on the optimization of the D function
Expand Down
8 changes: 3 additions & 5 deletions src/algorithms/postprob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
algorithms::{Status, StopReason},
prelude::algorithms::Algorithms,
structs::{
psi::{calculate_psi_dispatch, Psi},
psi::{calculate_psi, Psi},
theta::Theta,
weights::Weights,
},
Expand Down Expand Up @@ -119,16 +119,14 @@ impl<E: Equation + Send + 'static> Algorithms<E> for POSTPROB<E> {
}

fn estimation(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.psi = calculate_psi_dispatch(
self.psi = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&self.error_models,
false,
false,
use_log_space,
self.settings.advanced().space,
)?;

(self.w, self.objf) = burke_ipm(&self.psi).context("Error in IPM")?;
Expand Down
14 changes: 6 additions & 8 deletions src/bestdose/posterior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ use crate::algorithms::Algorithms;
use crate::algorithms::Status;
use crate::prelude::*;
use crate::routines::estimation::ipm::burke_ipm;
use crate::structs::psi::calculate_psi_dispatch;
use crate::structs::psi::calculate_psi;
use crate::structs::psi::Space;
use crate::structs::theta::Theta;
use crate::structs::weights::Weights;
use pharmsol::prelude::*;
Expand Down Expand Up @@ -95,20 +96,20 @@ pub fn npagfull11_filter(
past_data: &Data,
eq: &ODE,
error_models: &ErrorModels,
use_log_space: bool,
space: Space,
) -> Result<(Theta, Weights, Weights)> {
tracing::info!("Stage 1.1: NPAGFULL11 Bayesian filtering");

// Calculate psi matrix P(data|theta_i) for all support points
// Use log-space or regular space based on setting
let psi = calculate_psi_dispatch(
let psi = calculate_psi(
eq,
past_data,
population_theta,
error_models,
false,
true,
use_log_space,
space,
)?;

// First burke call to get initial posterior probabilities
Expand Down Expand Up @@ -327,9 +328,6 @@ pub fn calculate_two_step_posterior(
) -> Result<(Theta, Weights, Weights)> {
tracing::info!("=== STAGE 1: Posterior Density Calculation ===");

// Use log-space based on settings
let use_log_space = settings.advanced().log_space;

// Step 1.1: NPAGFULL11 filtering (returns filtered posterior AND filtered prior)
let (filtered_theta, filtered_posterior_weights, filtered_population_weights) =
npagfull11_filter(
Expand All @@ -338,7 +336,7 @@ pub fn calculate_two_step_posterior(
past_data,
eq,
error_models,
use_log_space,
settings.advanced().space,
)?;

// Step 1.2: NPAGFULL refinement
Expand Down
31 changes: 16 additions & 15 deletions src/routines/estimation/ipm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,9 @@ pub fn burke_log(log_psi: &Psi) -> anyhow::Result<(Weights, f64)> {
///
/// Returns an error if the underlying IPM optimization fails.
pub fn burke_ipm(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
if psi.is_log_space() {
burke_log(psi)
} else {
burke(psi)
match psi.space() {
crate::structs::psi::Space::Linear => burke(psi),
crate::structs::psi::Space::Log => burke_log(psi),
}
}

Expand Down Expand Up @@ -813,7 +812,6 @@ mod tests {
fn test_burke_log_identity() {
// Test with identity matrix converted to log space
// log(1) = 0, log(0) = -inf, but we use a small positive value instead
use crate::structs::psi::PsiBuilder;
use ndarray::Array2;

let n = 10;
Expand All @@ -826,7 +824,8 @@ mod tests {
}
});

let psi = PsiBuilder::new(log_mat).log_space(true).build();
let mat = Mat::from_fn(n, n, |i, j| log_mat[(i, j)]);
let psi = Psi::new_log(mat);

let (lam, obj) = burke_log(&psi).unwrap();

Expand All @@ -847,14 +846,15 @@ mod tests {
fn test_burke_log_uniform() {
// Test with uniform matrix in log space
// log(1) = 0 everywhere
use crate::structs::psi::PsiBuilder;

use ndarray::Array2;

let n_sub = 10;
let n_point = 10;
let log_mat = Array2::from_shape_fn((n_sub, n_point), |_| 0.0); // log(1) = 0

let psi = PsiBuilder::new(log_mat).log_space(true).build();
let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
let psi = Psi::new_log(mat);

let (lam, obj) = burke_log(&psi).unwrap();

Expand All @@ -875,7 +875,6 @@ mod tests {
fn test_burke_log_consistency_with_regular() {
// Test that burke_log produces the same results as burke
// when given equivalent inputs
use crate::structs::psi::PsiBuilder;
use ndarray::Array2;

let n_sub = 5;
Expand All @@ -889,11 +888,13 @@ mod tests {

// Create the equivalent log-space matrix
let log_mat = regular_mat.mapv(|x| x.ln());
let log_psi = PsiBuilder::new(log_mat).log_space(true).build();

let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
let psi = Psi::new_log(mat);

// Run both algorithms
let (lam_regular, obj_regular) = burke(&regular_psi).unwrap();
let (lam_log, obj_log) = burke_log(&log_psi).unwrap();
let (lam_log, obj_log) = burke_log(&psi).unwrap();

// The weights should be very similar
for i in 0..n_point {
Expand All @@ -908,7 +909,6 @@ mod tests {
fn test_burke_log_handles_very_small_likelihoods() {
// Test that log-space IPM handles very small likelihoods that would
// underflow in regular space
use crate::structs::psi::PsiBuilder;
use ndarray::Array2;

let n_sub = 5;
Expand All @@ -920,7 +920,8 @@ mod tests {
-500.0 + (i as f64) * 0.1 + (j as f64) * 0.05
});

let psi = PsiBuilder::new(log_mat).log_space(true).build();
let mat = Mat::from_fn(n_point, n_sub, |i, j| log_mat[(i, j)]);
let psi = Psi::new_log(mat);

// This should succeed without underflow issues
let result = burke_log(&psi);
Expand All @@ -944,7 +945,6 @@ mod tests {
#[test]
fn test_burke_log_with_varying_magnitudes() {
// Test with log-likelihoods of varying magnitudes
use crate::structs::psi::PsiBuilder;
use ndarray::Array2;

let n_sub = 8;
Expand All @@ -960,7 +960,8 @@ mod tests {
}
});

let psi = PsiBuilder::new(log_mat).log_space(true).build();
let mat = Mat::from_fn(n_point, n_sub, |i, j| log_mat[(j, i)]);
let psi = Psi::new_log(mat);

let (lam, obj) = burke_log(&psi).unwrap();

Expand Down
Loading
Loading