Skip to content

Commit 4badae5

Browse files
authored
636 make python serialization usable again (#761)
1 parent abc02d7 commit 4badae5

28 files changed

+311
-107
lines changed

cpp/memilio/epidemiology/age_group.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ struct AgeGroup : public Index<AgeGroup> {
3434
: Index<AgeGroup>(val)
3535
{
3636
}
37+
38+
/**
39+
* Override deserialize of base class
40+
* @see mio::Index::deserialize
41+
*/
42+
template <class IOContext>
43+
static IOResult<AgeGroup> deserialize(IOContext& io)
44+
{
45+
BOOST_OUTCOME_TRY(auto&& i, mio::deserialize(io, Tag<size_t>{}));
46+
return success(AgeGroup(i));
47+
}
3748
};
3849

3950
} // namespace mio

cpp/memilio/utils/metaprogramming.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ template <class B1>
120120
struct disjunction<B1> : B1 {
121121
//disjunction of one element is identity
122122
};
123-
template <class B1, class... Bn>
124-
struct disjunction<B1, Bn...> : std::conditional<bool(B1::value), B1, disjunction<Bn...>> {
123+
template<class B1, class... Bn>
124+
struct disjunction<B1, Bn...> : std::conditional_t<bool(B1::value), B1, disjunction<Bn...>> {
125125
//disjunction of mutliple elements is equal to the first element if the first element is true.
126126
//otherwise its equal to the disjunction of the remaining elements.
127127
};

cpp/models/ode_seir/model.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ class Model : public FlowModel<InfectionState, Populations<AgeGroup, InfectionSt
5454
using Base = FlowModel<InfectionState, mio::Populations<AgeGroup, InfectionState>, Parameters, Flows>;
5555

5656
public:
57+
Model(const Populations& pop, const ParameterSet& params)
58+
: Base(pop, params)
59+
{
60+
}
61+
5762
Model(int num_agegroups)
5863
: Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
5964
{
@@ -201,6 +206,36 @@ class Model : public FlowModel<InfectionState, Populations<AgeGroup, InfectionSt
201206
auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
202207
return mio::success(static_cast<ScalarType>(result));
203208
}
209+
210+
/**
211+
* serialize this.
212+
* @see mio::serialize
213+
*/
214+
template <class IOContext>
215+
void serialize(IOContext& io) const
216+
{
217+
auto obj = io.create_object("Model");
218+
obj.add_element("Parameters", parameters);
219+
obj.add_element("Populations", populations);
220+
}
221+
222+
/**
223+
* deserialize an object of this class.
224+
* @see mio::deserialize
225+
*/
226+
template <class IOContext>
227+
static IOResult<Model> deserialize(IOContext& io)
228+
{
229+
auto obj = io.expect_object("Model");
230+
auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
231+
auto pop = obj.expect_element("Populations", Tag<Populations>{});
232+
return apply(
233+
io,
234+
[](auto&& par_, auto&& pop_) {
235+
return Model{pop_, par_};
236+
},
237+
par, pop);
238+
}
204239
};
205240

206241
} // namespace oseir

cpp/models/ode_sir/model.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ class Model : public CompartmentalModel<InfectionState, Populations<AgeGroup, In
4242
using Base = CompartmentalModel<InfectionState, mio::Populations<AgeGroup, InfectionState>, Parameters>;
4343

4444
public:
45+
Model(const Populations& pop, const ParameterSet& params)
46+
: Base(pop, params)
47+
{
48+
}
49+
4550
Model(int num_agegroups)
4651
: Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
4752
{
@@ -80,6 +85,36 @@ class Model : public CompartmentalModel<InfectionState, Populations<AgeGroup, In
8085

8186
}
8287
}
88+
89+
/**
90+
* serialize this.
91+
* @see mio::serialize
92+
*/
93+
template <class IOContext>
94+
void serialize(IOContext& io) const
95+
{
96+
auto obj = io.create_object("Model");
97+
obj.add_element("Parameters", parameters);
98+
obj.add_element("Populations", populations);
99+
}
100+
101+
/**
102+
* deserialize an object of this class.
103+
* @see mio::deserialize
104+
*/
105+
template <class IOContext>
106+
static IOResult<Model> deserialize(IOContext& io)
107+
{
108+
auto obj = io.expect_object("Model");
109+
auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
110+
auto pop = obj.expect_element("Populations", Tag<Populations>{});
111+
return apply(
112+
io,
113+
[](auto&& par_, auto&& pop_) {
114+
return Model{pop_, par_};
115+
},
116+
par, pop);
117+
}
83118
};
84119

85120
} // namespace osir

pycode/memilio-simulation/memilio/simulation/abm.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace py = pybind11;
3333

3434
PYBIND11_MODULE(_simulation_abm, m)
3535
{
36-
pymio::iterable_enum<mio::abm::InfectionState>(m, "InfectionState", py::module_local{})
36+
pymio::iterable_enum<mio::abm::InfectionState>(m, "InfectionState")
3737
.value("Susceptible", mio::abm::InfectionState::Susceptible)
3838
.value("Exposed", mio::abm::InfectionState::Exposed)
3939
.value("InfectedNoSymptoms", mio::abm::InfectionState::InfectedNoSymptoms)
@@ -62,21 +62,21 @@ PYBIND11_MODULE(_simulation_abm, m)
6262
.value("PublicTransport", mio::abm::LocationType::PublicTransport)
6363
.value("TransportWithoutContact", mio::abm::LocationType::TransportWithoutContact);
6464

65-
py::class_<mio::abm::TestParameters>(m, "TestParameters")
65+
pymio::bind_class<mio::abm::TestParameters, pymio::EnablePickling::Never>(m, "TestParameters")
6666
.def(py::init<double, double>())
6767
.def_readwrite("sensitivity", &mio::abm::TestParameters::sensitivity)
6868
.def_readwrite("specificity", &mio::abm::TestParameters::specificity);
6969

7070
pymio::bind_CustomIndexArray<mio::UncertainValue, mio::abm::VirusVariant, mio::AgeGroup>(m, "_AgeParameterArray");
7171
pymio::bind_Index<mio::abm::ExposureType>(m, "ExposureTypeIndex");
72-
pymio::bind_ParameterSet<mio::abm::ParametersBase>(m, "ParametersBase");
73-
py::class_<mio::abm::Parameters, mio::abm::ParametersBase>(m, "Parameters")
72+
pymio::bind_ParameterSet<mio::abm::ParametersBase, pymio::EnablePickling::Never>(m, "ParametersBase");
73+
pymio::bind_class<mio::abm::Parameters, pymio::EnablePickling::Never, mio::abm::ParametersBase>(m, "Parameters")
7474
.def(py::init<int>())
7575
.def("check_constraints", &mio::abm::Parameters::check_constraints);
7676

77-
pymio::bind_ParameterSet<mio::abm::LocalInfectionParameters>(m, "LocalInfectionParameters").def(py::init<size_t>());
77+
pymio::bind_ParameterSet<mio::abm::LocalInfectionParameters, pymio::EnablePickling::Never>(m, "LocalInfectionParameters").def(py::init<size_t>());
7878

79-
py::class_<mio::abm::TimeSpan>(m, "TimeSpan")
79+
pymio::bind_class<mio::abm::TimeSpan, pymio::EnablePickling::Never>(m, "TimeSpan")
8080
.def(py::init<int>(), py::arg("seconds") = 0)
8181
.def_property_readonly("seconds", &mio::abm::TimeSpan::seconds)
8282
.def_property_readonly("hours", &mio::abm::TimeSpan::hours)
@@ -101,7 +101,7 @@ PYBIND11_MODULE(_simulation_abm, m)
101101
m.def("hours", &mio::abm::hours);
102102
m.def("days", py::overload_cast<int>(&mio::abm::days));
103103

104-
py::class_<mio::abm::TimePoint>(m, "TimePoint")
104+
pymio::bind_class<mio::abm::TimePoint, pymio::EnablePickling::Never>(m, "TimePoint")
105105
.def(py::init<int>(), py::arg("seconds") = 0)
106106
.def_property_readonly("seconds", &mio::abm::TimePoint::seconds)
107107
.def_property_readonly("days", &mio::abm::TimePoint::days)
@@ -121,7 +121,7 @@ PYBIND11_MODULE(_simulation_abm, m)
121121
.def(py::self - mio::abm::TimeSpan{})
122122
.def(py::self -= mio::abm::TimeSpan{});
123123

124-
py::class_<mio::abm::LocationId>(m, "LocationId")
124+
pymio::bind_class<mio::abm::LocationId, pymio::EnablePickling::Never>(m, "LocationId")
125125
.def(py::init([](uint32_t idx, mio::abm::LocationType type) {
126126
return mio::abm::LocationId{idx, type};
127127
}))
@@ -130,36 +130,36 @@ PYBIND11_MODULE(_simulation_abm, m)
130130
.def(py::self == py::self)
131131
.def(py::self != py::self);
132132

133-
py::class_<mio::abm::Person>(m, "Person")
133+
pymio::bind_class<mio::abm::Person, pymio::EnablePickling::Never>(m, "Person")
134134
.def("set_assigned_location", py::overload_cast<mio::abm::LocationId>(&mio::abm::Person::set_assigned_location))
135135
.def_property_readonly("location", py::overload_cast<>(&mio::abm::Person::get_location, py::const_))
136136
.def_property_readonly("age", &mio::abm::Person::get_age)
137137
.def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine);
138138

139-
py::class_<mio::abm::TestingCriteria>(m, "TestingCriteria")
139+
pymio::bind_class<mio::abm::TestingCriteria, pymio::EnablePickling::Never>(m, "TestingCriteria")
140140
.def(py::init<const std::vector<mio::AgeGroup>&, const std::vector<mio::abm::InfectionState>&>(),
141141
py::arg("age_groups"), py::arg("infection_states"));
142+
143+
pymio::bind_class<mio::abm::GenericTest, pymio::EnablePickling::Never>(m, "GenericTest").def(py::init<>());
144+
pymio::bind_class<mio::abm::AntigenTest, pymio::EnablePickling::Never, mio::abm::GenericTest>(m, "AntigenTest").def(py::init<>());
145+
pymio::bind_class<mio::abm::PCRTest, pymio::EnablePickling::Never, mio::abm::GenericTest>(m, "PCRTest").def(py::init<>());
142146

143-
py::class_<mio::abm::GenericTest>(m, "GenericTest").def(py::init<>());
144-
py::class_<mio::abm::AntigenTest, mio::abm::GenericTest>(m, "AntigenTest").def(py::init<>());
145-
py::class_<mio::abm::PCRTest, mio::abm::GenericTest>(m, "PCRTest").def(py::init<>());
146-
147-
py::class_<mio::abm::TestingScheme>(m, "TestingScheme")
147+
pymio::bind_class<mio::abm::TestingScheme, pymio::EnablePickling::Never>(m, "TestingScheme")
148148
.def(py::init<const mio::abm::TestingCriteria&, mio::abm::TimeSpan, mio::abm::TimePoint, mio::abm::TimePoint,
149149
const mio::abm::GenericTest&, double>(),
150150
py::arg("testing_criteria"), py::arg("testing_min_time_since_last_test"), py::arg("start_date"),
151151
py::arg("end_date"), py::arg("test_type"), py::arg("probability"))
152152
.def_property_readonly("active", &mio::abm::TestingScheme::is_active);
153153

154-
py::class_<mio::abm::Vaccination>(m, "Vaccination")
154+
pymio::bind_class<mio::abm::Vaccination, pymio::EnablePickling::Never>(m, "Vaccination")
155155
.def(py::init<mio::abm::ExposureType, mio::abm::TimePoint>(), py::arg("exposure_type"), py::arg("time"))
156156
.def_readwrite("exposure_type", &mio::abm::Vaccination::exposure_type)
157157
.def_readwrite("time", &mio::abm::Vaccination::time);
158158

159-
py::class_<mio::abm::TestingStrategy>(m, "TestingStrategy")
159+
pymio::bind_class<mio::abm::TestingStrategy, pymio::EnablePickling::Never>(m, "TestingStrategy")
160160
.def(py::init<const std::unordered_map<mio::abm::LocationId, std::vector<mio::abm::TestingScheme>>&>());
161161

162-
py::class_<mio::abm::Location>(m, "Location")
162+
pymio::bind_class<mio::abm::Location, pymio::EnablePickling::Never>(m, "Location")
163163
.def_property_readonly("type", &mio::abm::Location::get_type)
164164
.def_property_readonly("index", &mio::abm::Location::get_index)
165165
.def_property("infection_parameters",
@@ -172,7 +172,7 @@ PYBIND11_MODULE(_simulation_abm, m)
172172
pymio::bind_Range<decltype(std::declval<mio::abm::World>().get_locations())>(m, "_WorldLocationsRange");
173173
pymio::bind_Range<decltype(std::declval<mio::abm::World>().get_persons())>(m, "_WorldPersonsRange");
174174

175-
py::class_<mio::abm::Trip>(m, "Trip")
175+
pymio::bind_class<mio::abm::Trip, pymio::EnablePickling::Never>(m, "Trip")
176176
.def(py::init<uint32_t, mio::abm::TimePoint, mio::abm::LocationId, mio::abm::LocationId,
177177
std::vector<uint32_t>>(),
178178
py::arg("person_id"), py::arg("time"), py::arg("destination"), py::arg("origin"),
@@ -183,13 +183,13 @@ PYBIND11_MODULE(_simulation_abm, m)
183183
.def_readwrite("origin", &mio::abm::Trip::migration_origin)
184184
.def_readwrite("cells", &mio::abm::Trip::cells);
185185

186-
py::class_<mio::abm::TripList>(m, "TripList")
186+
pymio::bind_class<mio::abm::TripList, pymio::EnablePickling::Never>(m, "TripList")
187187
.def(py::init<>())
188188
.def("add_trip", &mio::abm::TripList::add_trip, py::arg("trip"), py::arg("weekend") = false)
189189
.def("next_trip", &mio::abm::TripList::get_next_trip, py::arg("weekend") = false)
190190
.def("num_trips", &mio::abm::TripList::num_trips, py::arg("weekend") = false);
191191

192-
py::class_<mio::abm::World>(m, "World")
192+
pymio::bind_class<mio::abm::World, pymio::EnablePickling::Never>(m, "World")
193193
.def(py::init<int32_t>())
194194
.def("add_location", &mio::abm::World::add_location, py::arg("location_type"), py::arg("num_cells") = 1)
195195
.def("add_person", &mio::abm::World::add_person, py::arg("location_id"), py::arg("age_group"),
@@ -213,7 +213,7 @@ PYBIND11_MODULE(_simulation_abm, m)
213213
},
214214
py::return_value_policy::reference_internal);
215215

216-
py::class_<mio::abm::Simulation>(m, "Simulation")
216+
pymio::bind_class<mio::abm::Simulation, pymio::EnablePickling::Never>(m, "Simulation")
217217
.def(py::init<mio::abm::TimePoint, size_t>())
218218
.def("advance",
219219
static_cast<void (mio::abm::Simulation::*)(mio::abm::TimePoint)>(&mio::abm::Simulation::advance),

pycode/memilio-simulation/memilio/simulation/compartments/compartmentalmodel.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define PYMIO_COMPARTMENTALMODEL_H
2222

2323
#include "memilio/compartments/compartmentalmodel.h"
24+
#include "pybind_util.h"
2425

2526
#include "pybind11/pybind11.h"
2627

@@ -30,22 +31,28 @@ namespace pymio
3031
/*
3132
* @brief bind a compartmental model for any Populations and Parameters class
3233
*/
33-
template <class InfectionState, class Populations, class Parameters>
34+
template <class InfectionState, class Populations, class Parameters, EnablePickling F>
3435
void bind_CompartmentalModel(pybind11::module_& m, std::string const& name)
3536
{
3637
using Model = mio::CompartmentalModel<InfectionState, Populations, Parameters>;
37-
pybind11::class_<Model>(m, name.c_str())
38+
bind_class<Model, F>(m, name.c_str())
3839
.def(pybind11::init<Populations const&, Parameters const&>())
3940
.def("apply_constraints", &Model::template apply_constraints)
4041
.def("check_constraints", &Model::template check_constraints)
4142
.def("get_initial_values", &Model::get_initial_values)
4243
.def_property(
43-
"populations", [](const Model& self) -> auto& { return self.populations; },
44+
"populations",
45+
[](const Model& self) -> auto& {
46+
return self.populations;
47+
},
4448
[](Model& self, Populations& p) {
4549
self.populations = p;
4650
})
4751
.def_property(
48-
"parameters", [](const Model& self) -> auto& { return self.parameters; },
52+
"parameters",
53+
[](const Model& self) -> auto& {
54+
return self.parameters;
55+
},
4956
[](Model& self, Parameters& p) {
5057
self.parameters = p;
5158
});

pycode/memilio-simulation/memilio/simulation/compartments/flow_simulation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@
2121
#define PYMIO_FLOW_SIMULATION_H
2222

2323
#include "memilio/compartments/flow_simulation.h"
24+
#include "pybind_util.h"
2425

2526
#include "pybind11/pybind11.h"
2627

2728
namespace pymio
2829
{
2930

30-
template <class Model>
31+
template <class Model, EnablePickling F>
3132
void bind_Flow_Simulation(pybind11::module_& m)
3233
{
33-
pybind11::class_<mio::FlowSimulation<Model>>(m, "FlowSimulation")
34+
bind_class<mio::FlowSimulation<Model>, F>(m, "FlowSimulation")
3435
.def(pybind11::init<const Model&, double, double>(), pybind11::arg("model"), pybind11::arg("t0") = 0,
3536
pybind11::arg("dt") = 0.1)
3637
.def_property_readonly("result",

pycode/memilio-simulation/memilio/simulation/compartments/simulation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define PYMIO_SIMULATION_H
2222

2323
#include "memilio/compartments/simulation.h"
24+
#include "pybind_util.h"
2425

2526
#include "pybind11/pybind11.h"
2627

@@ -33,7 +34,7 @@ namespace pymio
3334
template <class Simulation>
3435
void bind_Simulation(pybind11::module_& m, std::string const& name)
3536
{
36-
pybind11::class_<Simulation>(m, name.c_str())
37+
bind_class<Simulation, EnablePickling::IfAvailable>(m, name.c_str())
3738
.def(pybind11::init<const typename Simulation::Model&, double, double>(), pybind11::arg("model"),
3839
pybind11::arg("t0") = 0, pybind11::arg("dt") = 0.1)
3940
.def_property_readonly("result", pybind11::overload_cast<>(&Simulation::get_result, pybind11::const_),

pycode/memilio-simulation/memilio/simulation/epidemiology/damping_sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace pymio
3232

3333
void bind_damping_sampling(py::module_& m, std::string const& name)
3434
{
35-
pymio::pybind_pickle_class<mio::DampingSampling>(m, name.c_str())
35+
bind_class<mio::DampingSampling, EnablePickling::Required>(m, name.c_str())
3636
.def(py::init([](const mio::UncertainValue& value, int level, int type, double time,
3737
const std::vector<size_t>& matrices, const Eigen::Ref<const Eigen::VectorXd>& groups) {
3838
return mio::DampingSampling(value, mio::DampingLevel(level), mio::DampingType(type),

pycode/memilio-simulation/memilio/simulation/epidemiology/dynamic_npis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void bind_dynamicNPI_members(pybind11::module_& m, std::string const& name)
3636
{
3737
bind_Range<decltype(std::declval<mio::DynamicNPIs>().get_thresholds())>(m, "_ThresholdRange");
3838
using C = mio::DynamicNPIs;
39-
pybind11::class_<C>(m, name.c_str())
39+
bind_class<C, EnablePickling::Required>(m, name.c_str())
4040
.def(pybind11::init<>())
4141
.def_property(
4242
"interval",

0 commit comments

Comments
 (0)