Skip to content

Commit 79ae152

Browse files
jubickerxsaschakoDavidKerkmann
authored
1081 rework ABM state transitions (#1107)
Parameters can now use any type of distribution the user wants to. Signed-off-by: DavidKerkmann <44698825+DavidKerkmann@users.noreply.github.com> Co-authored-by: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Co-authored-by: DavidKerkmann <44698825+DavidKerkmann@users.noreply.github.com>
1 parent 3390762 commit 79ae152

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2396
-1042
lines changed

cpp/examples/abm_history_object.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
#include "abm/simulation.h"
2323
#include "abm/model.h"
2424
#include "abm/location_type.h"
25+
#include "memilio/utils/abstract_parameter_distribution.h"
2526
#include "memilio/io/history.h"
27+
#include "memilio/utils/parameter_distributions.h"
2628

2729
#include <fstream>
2830
#include <string>
@@ -69,9 +71,9 @@ int main()
6971

7072
// Create the model with 4 age groups.
7173
auto model = mio::abm::Model(num_age_groups);
72-
73-
// Set same infection parameter for all age groups. For example, the incubation period is 4 days.
74-
model.parameters.get<mio::abm::IncubationPeriod>() = 4.;
74+
mio::ParameterDistributionLogNormal log_norm(4., 1.);
75+
// Set same infection parameter for all age groups. For example, the incubation period is log normally distributed with parameters 4 and 1.
76+
model.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionLogNormal(4., 1.);
7577

7678
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
7779
model.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
@@ -134,7 +136,7 @@ int main()
134136
auto test_parameters = model.parameters.get<mio::abm::TestData>()[test_type];
135137
auto testing_criteria_work = mio::abm::TestingCriteria();
136138
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
137-
test_parameters, probability);
139+
test_parameters, probability);
138140
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
139141

140142
// Assign infection state to each person.

cpp/examples/abm_minimal.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "abm/lockdown_rules.h"
2222
#include "abm/model.h"
2323
#include "abm/common_abm_loggers.h"
24+
#include "memilio/utils/abstract_parameter_distribution.h"
2425

2526
#include <fstream>
2627

@@ -37,9 +38,8 @@ int main()
3738

3839
// Create the model with 4 age groups.
3940
auto model = mio::abm::Model(num_age_groups);
40-
41-
// Set same infection parameter for all age groups. For example, the incubation period is 4 days.
42-
model.parameters.get<mio::abm::IncubationPeriod>() = 4.;
41+
// Set same infection parameter for all age groups. For example, the incubation period is log normally distributed with parameters 4 and 1.
42+
model.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionLogNormal(4., 1.);
4343

4444
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
4545
model.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
@@ -113,7 +113,7 @@ int main()
113113
auto test_parameters = model.parameters.get<mio::abm::TestData>()[test_type];
114114
auto testing_criteria_work = mio::abm::TestingCriteria();
115115
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
116-
test_parameters, probability);
116+
test_parameters, probability);
117117
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
118118

119119
// Assign infection state to each person.

cpp/examples/graph_abm.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "graph_abm/graph_abmodel.h"
2929
#include "memilio/io/history.h"
3030
#include "memilio/mobility/graph.h"
31+
#include "memilio/utils/abstract_parameter_distribution.h"
32+
#include "memilio/utils/parameter_distributions.h"
3133
#include <cstddef>
3234
#include <cstdint>
3335
#include <map>
@@ -77,15 +79,15 @@ int main()
7779
auto model1 = mio::GraphABModel(num_age_groups, 0);
7880

7981
//Set infection parameters
80-
model1.parameters.get<mio::abm::IncubationPeriod>() = 4.;
81-
model1.parameters.get<mio::abm::InfectedNoSymptomsToSymptoms>() = 2.;
82-
model1.parameters.get<mio::abm::InfectedNoSymptomsToRecovered>() = 4.;
83-
model1.parameters.get<mio::abm::InfectedSymptomsToRecovered>() = 5.;
84-
model1.parameters.get<mio::abm::InfectedSymptomsToSevere>() = 6.;
85-
model1.parameters.get<mio::abm::SevereToRecovered>() = 8.;
86-
model1.parameters.get<mio::abm::SevereToCritical>() = 7.;
87-
model1.parameters.get<mio::abm::CriticalToRecovered>() = 10.;
88-
model1.parameters.get<mio::abm::CriticalToDead>() = 11.;
82+
model1.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionConstant(4.);
83+
model1.parameters.get<mio::abm::TimeInfectedNoSymptomsToSymptoms>() = mio::ParameterDistributionConstant(2.);
84+
model1.parameters.get<mio::abm::TimeInfectedNoSymptomsToRecovered>() = mio::ParameterDistributionConstant(4.);
85+
model1.parameters.get<mio::abm::TimeInfectedSymptomsToRecovered>() = mio::ParameterDistributionConstant(5.);
86+
model1.parameters.get<mio::abm::TimeInfectedSymptomsToSevere>() = mio::ParameterDistributionConstant(6.);
87+
model1.parameters.get<mio::abm::TimeInfectedSevereToRecovered>() = mio::ParameterDistributionConstant(8.);
88+
model1.parameters.get<mio::abm::TimeInfectedSevereToCritical>() = mio::ParameterDistributionConstant(7.);
89+
model1.parameters.get<mio::abm::TimeInfectedCriticalToRecovered>() = mio::ParameterDistributionConstant(10.);
90+
model1.parameters.get<mio::abm::TimeInfectedCriticalToDead>() = mio::ParameterDistributionConstant(11.);
8991

9092
//Age group 0 goes to school and age group 1 goes to work
9193
model1.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_children] = true;
@@ -135,15 +137,15 @@ int main()
135137
auto model2 = mio::GraphABModel(num_age_groups, 1);
136138

137139
//Set infection parameters
138-
model2.parameters.get<mio::abm::IncubationPeriod>() = 4.;
139-
model2.parameters.get<mio::abm::InfectedNoSymptomsToSymptoms>() = 2.;
140-
model2.parameters.get<mio::abm::InfectedNoSymptomsToRecovered>() = 4.;
141-
model2.parameters.get<mio::abm::InfectedSymptomsToRecovered>() = 5.;
142-
model2.parameters.get<mio::abm::InfectedSymptomsToSevere>() = 6.;
143-
model2.parameters.get<mio::abm::SevereToRecovered>() = 8.;
144-
model2.parameters.get<mio::abm::SevereToCritical>() = 7.;
145-
model2.parameters.get<mio::abm::CriticalToRecovered>() = 10.;
146-
model2.parameters.get<mio::abm::CriticalToDead>() = 11.;
140+
model2.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionConstant(4.);
141+
model2.parameters.get<mio::abm::TimeInfectedNoSymptomsToSymptoms>() = mio::ParameterDistributionConstant(2.);
142+
model2.parameters.get<mio::abm::TimeInfectedNoSymptomsToRecovered>() = mio::ParameterDistributionConstant(4.);
143+
model2.parameters.get<mio::abm::TimeInfectedSymptomsToRecovered>() = mio::ParameterDistributionConstant(5.);
144+
model2.parameters.get<mio::abm::TimeInfectedSymptomsToSevere>() = mio::ParameterDistributionConstant(6.);
145+
model2.parameters.get<mio::abm::TimeInfectedSevereToRecovered>() = mio::ParameterDistributionConstant(8.);
146+
model2.parameters.get<mio::abm::TimeInfectedSevereToCritical>() = mio::ParameterDistributionConstant(7.);
147+
model2.parameters.get<mio::abm::TimeInfectedCriticalToRecovered>() = mio::ParameterDistributionConstant(10.);
148+
model2.parameters.get<mio::abm::TimeInfectedCriticalToDead>() = mio::ParameterDistributionConstant(11.);
147149

148150
//Age group 0 goes to school and age group 1 goes to work
149151
model2.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_children] = true;

cpp/examples/ode_secir_parameter_sampling.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
* limitations under the License.
1919
*/
2020
#include "memilio/utils/parameter_distributions.h"
21+
#include "memilio/utils/random_number_generator.h"
2122
#include "ode_secir/parameter_space.h"
2223
#include "ode_secir/model.h"
2324

@@ -41,7 +42,7 @@ int main()
4142
printf("\n N(%.0f,%.0f)-distribution with sampling only in [%.0f,%.0f]", mean, stddev, min, max);
4243
int counter[10] = {0};
4344
for (int i = 0; i < 1000; i++) {
44-
int rounded = (int)(some_parameter.get_sample() - 1);
45+
int rounded = (int)(some_parameter.get_sample(mio::thread_local_rng()) - 1);
4546
if (rounded >= 0 && rounded < 10) {
4647
counter[rounded]++;
4748
}
@@ -59,7 +60,7 @@ int main()
5960

6061
double counter_unif[10] = {0};
6162
for (int i = 0; i < 1000; i++) {
62-
int rounded = (int)(some_other_parameter.get_sample() - 1);
63+
int rounded = (int)(some_other_parameter.get_sample(mio::thread_local_rng()) - 1);
6364
if (rounded >= 0 && rounded < 10) {
6465
counter_unif[rounded]++;
6566
}

cpp/memilio/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ add_library(memilio
7979
utils/custom_index_array.h
8080
utils/memory.h
8181
utils/parameter_distributions.h
82+
utils/abstract_parameter_distribution.h
8283
utils/time_series.h
8384
utils/time_series.cpp
8485
utils/span.h

cpp/memilio/epidemiology/damping_sampling.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define EPI_SECIR_DAMPING_SAMPLING_H
2222

2323
#include "memilio/epidemiology/damping.h"
24+
#include "memilio/utils/random_number_generator.h"
2425
#include "memilio/utils/uncertain_value.h"
2526
#include <memory>
2627

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
/*
2+
* Copyright (C) 2020-2025 MEmilio
3+
*
4+
* Authors: Julia Bicker
5+
*
6+
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
#ifndef ABSTRACT_PARAMETER_DISTRIBUTION_H
21+
#define ABSTRACT_PARAMETER_DISTRIBUTION_H
22+
23+
#include "memilio/io/io.h"
24+
#include "memilio/utils/compiler_diagnostics.h"
25+
#include "memilio/utils/logging.h"
26+
#include "memilio/utils/random_number_generator.h"
27+
#include "parameter_distributions.h"
28+
#include <memory>
29+
#include <string>
30+
31+
namespace mio
32+
{
33+
34+
/**
35+
* @brief This class represents an arbitrary ParameterDistribution.
36+
* @see mio::ParameterDistribution
37+
* This class can for instance be used for model parameters that should have an arbitrary distribution.
38+
*/
39+
class AbstractParameterDistribution
40+
{
41+
42+
public:
43+
/**
44+
* The implementation handed to the constructor should have get_sample function
45+
* overloaded with mio::RandomNumberGenerator and mio::abm::PersonalRandomNumberGenerator as input arguments
46+
*/
47+
template <class Impl>
48+
AbstractParameterDistribution(Impl&& dist)
49+
: m_dist(std::make_shared<Impl>(std::move(dist)))
50+
, sample_impl1([](void* d, RandomNumberGenerator& rng) {
51+
return static_cast<Impl*>(d)->get_sample(rng);
52+
})
53+
, sample_impl2([](void* d, abm::PersonalRandomNumberGenerator& rng) {
54+
return static_cast<Impl*>(d)->get_sample(rng);
55+
})
56+
{
57+
}
58+
59+
AbstractParameterDistribution(AbstractParameterDistribution& other)
60+
: m_dist(other.m_dist)
61+
, sample_impl1(other.sample_impl1)
62+
, sample_impl2(other.sample_impl2)
63+
{
64+
}
65+
66+
AbstractParameterDistribution(AbstractParameterDistribution&& other)
67+
: m_dist(other.m_dist)
68+
, sample_impl1(other.sample_impl1)
69+
, sample_impl2(other.sample_impl2)
70+
{
71+
}
72+
73+
AbstractParameterDistribution(const AbstractParameterDistribution& other)
74+
: m_dist(other.m_dist)
75+
, sample_impl1(other.sample_impl1)
76+
, sample_impl2(other.sample_impl2)
77+
{
78+
}
79+
80+
AbstractParameterDistribution()
81+
: m_dist(nullptr)
82+
, sample_impl1([](void* /*dist*/, RandomNumberGenerator& /*rng*/) {
83+
log_critical("AbstractParameterDistribution does not hold a distribution.");
84+
if (true) {
85+
exit(static_cast<int>(StatusCode::UnknownError));
86+
}
87+
else {
88+
return -1.;
89+
}
90+
})
91+
, sample_impl2([](void* /*dist*/, abm::PersonalRandomNumberGenerator& /*rng*/) {
92+
log_critical("AbstractParameterDistribution does not hold a distribution.");
93+
if (true) {
94+
exit(static_cast<int>(StatusCode::UnknownError));
95+
}
96+
else {
97+
return -1.;
98+
}
99+
})
100+
{
101+
}
102+
103+
AbstractParameterDistribution& operator=(AbstractParameterDistribution&& other) = default;
104+
105+
AbstractParameterDistribution& operator=(const AbstractParameterDistribution& other) = default;
106+
107+
bool operator<(const AbstractParameterDistribution& other) const
108+
{
109+
return static_cast<ParameterDistribution*>(m_dist.get())
110+
->smaller_impl(*static_cast<ParameterDistribution*>(other.m_dist.get()));
111+
}
112+
113+
/**
114+
* @brief Returns a value sampled with the given distribution.
115+
* @param[in] rng RandomNumberGenerator used for sampling.
116+
*/
117+
double get(RandomNumberGenerator& rng) const
118+
{
119+
return sample_impl1(m_dist.get(), rng);
120+
}
121+
122+
/**
123+
* @brief Returns a value sampled with the given distribution.
124+
* @param[in] rng abm::PersonalRandomNumberGenerator used for sampling.
125+
*/
126+
double get(abm::PersonalRandomNumberGenerator& rng) const
127+
{
128+
return sample_impl2(m_dist.get(), rng);
129+
}
130+
131+
/**
132+
* @brief Get the parameters of the given distribution.
133+
*/
134+
std::vector<double> params() const
135+
{
136+
return static_cast<ParameterDistribution*>(m_dist.get())->params();
137+
}
138+
139+
/**
140+
* serialize an AbstractParameterDistribution.
141+
* @see mio::serialize
142+
*/
143+
template <class IOContext>
144+
void serialize(IOContext& io) const
145+
{
146+
static_cast<ParameterDistribution*>(m_dist.get())->serialize(io);
147+
}
148+
149+
private:
150+
std::shared_ptr<void> m_dist; ///< Underlying distribtuion.
151+
double (*sample_impl1)(
152+
void*,
153+
RandomNumberGenerator&); ///< Sample function of the distribution which gets a RandomNumberGenerator as rng.
154+
double (*sample_impl2)(
155+
void*,
156+
abm::
157+
PersonalRandomNumberGenerator&); ///< Sample function of the distribution which gets a abm::PersonalRandomNumberGenerator as rng.
158+
};
159+
160+
/**
161+
* deserialize a AbstractParameterDistribution.
162+
* @see mio::deserialize
163+
*/
164+
template <class IOContext>
165+
IOResult<AbstractParameterDistribution> deserialize_internal(IOContext& io, Tag<AbstractParameterDistribution>)
166+
{
167+
auto obj = io.expect_object("ParameterDistribution");
168+
auto type = obj.expect_element("Type", Tag<std::string>{});
169+
if (type) {
170+
if (type.value() == "Uniform") {
171+
BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionUniform::deserialize_elements(io, obj));
172+
return mio::success(AbstractParameterDistribution(std::move(r)));
173+
}
174+
else if (type.value() == "Normal") {
175+
BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionNormal::deserialize_elements(io, obj));
176+
return mio::success(AbstractParameterDistribution(std::move(r)));
177+
}
178+
else if (type.value() == "LogNormal") {
179+
BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionLogNormal::deserialize_elements(io, obj));
180+
return mio::success(AbstractParameterDistribution(std::move(r)));
181+
}
182+
else if (type.value() == "Exponential") {
183+
BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionExponential::deserialize_elements(io, obj));
184+
return mio::success(AbstractParameterDistribution(std::move(r)));
185+
}
186+
else if (type.value() == "Constant") {
187+
BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionConstant::deserialize_elements(io, obj));
188+
return mio::success(AbstractParameterDistribution(std::move(r)));
189+
}
190+
else {
191+
return failure(StatusCode::InvalidValue, "Type of ParameterDistribution in AbstractParameterDistribution" +
192+
type.value() + " not valid.");
193+
}
194+
}
195+
return failure(type.error());
196+
}
197+
198+
} // namespace mio
199+
200+
#endif //ABSTRACT_PARAMETER_DISTRIBUTION_H

cpp/memilio/utils/logging.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ inline void log_error(spdlog::string_view_t fmt, const Args&... args)
103103
spdlog::default_logger_raw()->error(fmt, args...);
104104
}
105105

106+
template <typename... Args>
107+
inline void log_critical(spdlog::string_view_t fmt, const Args&... args)
108+
{
109+
spdlog::default_logger_raw()->error(fmt, args...);
110+
}
111+
106112
template <typename... Args>
107113
inline void log_warning(spdlog::string_view_t fmt, const Args&... args)
108114
{

0 commit comments

Comments
 (0)