Skip to content

Commit a245aee

Browse files
authored
Prevent NaNs through negative flows in SDE models (#1008)
1 parent 79158d1 commit a245aee

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

cpp/models/sde_sir/model.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,19 @@ class Model : public FlowModel<InfectionState, Populations<InfectionState>, Para
6363
// take the minimum of the calculated flow and the source compartment, to ensure that
6464
// no compartment attains negative values.
6565

66-
flows[get_flat_flow_index<InfectionState::Susceptible, InfectionState::Infected>()] = std::min(
66+
flows[get_flat_flow_index<InfectionState::Susceptible, InfectionState::Infected>()] = std::clamp(
6767
coeffStoI * y[(size_t)InfectionState::Susceptible] * pop[(size_t)InfectionState::Infected] +
6868
sqrt(coeffStoI * y[(size_t)InfectionState::Susceptible] * pop[(size_t)InfectionState::Infected]) /
6969
sqrt(step_size) * si,
70-
y[(size_t)InfectionState::Susceptible] / step_size);
70+
0.0, y[(size_t)InfectionState::Susceptible] / step_size);
7171

72-
flows[get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>()] = std::min(
72+
flows[get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>()] = std::clamp(
7373
(1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected] +
7474
sqrt((1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected]) / sqrt(step_size) * ir,
75-
y[(size_t)InfectionState::Infected] / step_size);
75+
0.0, y[(size_t)InfectionState::Infected] / step_size);
7676
}
7777

78-
ScalarType step_size = 0.1; ///< A step size of the model with which the stochastic process is realized.
78+
ScalarType step_size; ///< A step size of the model with which the stochastic process is realized.
7979
mutable RandomNumberGenerator rng;
8080

8181
private:

cpp/models/sde_sirs/model.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,24 @@ class Model : public FlowModel<InfectionState, Populations<InfectionState>, Para
6767
// take the minimum of the calculated flow and the source compartment, to ensure that
6868
// no compartment attains negative values.
6969

70-
flows[get_flat_flow_index<InfectionState::Susceptible, InfectionState::Infected>()] = std::min(
70+
flows[get_flat_flow_index<InfectionState::Susceptible, InfectionState::Infected>()] = std::clamp(
7171
coeffStoI * y[(size_t)InfectionState::Susceptible] * pop[(size_t)InfectionState::Infected] +
7272
sqrt(coeffStoI * y[(size_t)InfectionState::Susceptible] * pop[(size_t)InfectionState::Infected]) *
7373
inv_sqrt_dt * si,
74-
y[(size_t)InfectionState::Susceptible] / step_size);
74+
0.0, y[(size_t)InfectionState::Susceptible] / step_size);
7575

76-
flows[get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>()] = std::min(
76+
flows[get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>()] = std::clamp(
7777
(1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected] +
7878
sqrt((1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected]) * inv_sqrt_dt * ir,
79-
y[(size_t)InfectionState::Infected] / step_size);
79+
0.0, y[(size_t)InfectionState::Infected] / step_size);
8080

81-
flows[get_flat_flow_index<InfectionState::Recovered, InfectionState::Susceptible>()] = std::min(
81+
flows[get_flat_flow_index<InfectionState::Recovered, InfectionState::Susceptible>()] = std::clamp(
8282
(1.0 / params.get<TimeImmune>()) * y[(size_t)InfectionState::Recovered] +
8383
sqrt((1.0 / params.get<TimeImmune>()) * y[(size_t)InfectionState::Recovered]) * inv_sqrt_dt * rs,
84-
y[(size_t)InfectionState::Recovered] / step_size);
84+
0.0, y[(size_t)InfectionState::Recovered] / step_size);
8585
}
8686

87-
ScalarType step_size = 0.1; ///< A step size of the model with which the stochastic process is realized.
87+
ScalarType step_size; ///< A step size of the model with which the stochastic process is realized.
8888
mutable RandomNumberGenerator rng;
8989

9090
private:

0 commit comments

Comments
 (0)