Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7adbe86
Implement first draft of costum class bind
May 11, 2023
17513b5
Add pickling for json serializable classes
Jun 15, 2023
6583696
Change pickle class call
Jul 13, 2023
189b2c3
Merge branch 'main' into 636-make-python-serialization-usable-again
Jul 13, 2023
114d386
change function calls
Jul 27, 2023
4cca1c8
Merge branch 'main' into 636-make-python-serialization-usable-again
Aug 25, 2023
c46e5b8
Working draft for first option
Aug 25, 2023
a1653d9
Different strategys for serialization
Aug 31, 2023
7b03c3e
Fix deserialize of agegroup
Aug 31, 2023
1fefc9f
Test for all types of serialization
Sep 4, 2023
fc3ffd6
Add flag to decide on pickling strategy
Sep 28, 2023
ff7c64b
rewrite usage of tag
Oct 10, 2023
85ae09f
Rename
Oct 12, 2023
41f453a
change detection of pickling
Oct 12, 2023
d72b23e
change order of template parameter
Oct 17, 2023
d73c57c
first draft of documentation
Oct 17, 2023
dfbff79
change enum binding
Oct 26, 2023
2725198
optional arguments for pybind::class_
Oct 26, 2023
a2382b7
changes to doc
Oct 31, 2023
c05df9a
Merge branch 'main' into 636-make-python-serialization-usable-again
Oct 31, 2023
76359eb
Merge branch 'main' into 636-make-python-serialization-usable-again
Nov 28, 2023
b9c4ca5
Changes review
Jan 9, 2024
4235a5b
Merge branch 'main' into 636-make-python-serialization-usable-again
Jan 9, 2024
c4655de
Rewrite doxygen for bind_class
Apr 29, 2024
7149822
Merge branch 'main' into 636-make-python-serialization-usable-again
Apr 29, 2024
0217f3c
update new bindings from main
Apr 29, 2024
2254be3
Add pickling flag
Apr 29, 2024
7a5773c
try fix abm
May 6, 2024
b5af225
revert last commit
May 6, 2024
6c5689c
Merge branch 'main' into 636-make-python-serialization-usable-again
May 15, 2024
0059e92
Add constructor for deserialize
May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/memilio/epidemiology/age_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ struct AgeGroup : public Index<AgeGroup> {
: Index<AgeGroup>(val)
{
}

/**
* Override deserialize of base class
* @see mio::Index::deserialize
*/
template <class IOContext>
static IOResult<AgeGroup> deserialize(IOContext& io)
{
BOOST_OUTCOME_TRY(auto&& i, mio::deserialize(io, Tag<size_t>{}));
return success(AgeGroup(i));
}
};

} // namespace mio
Expand Down
4 changes: 2 additions & 2 deletions cpp/memilio/utils/metaprogramming.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ template <class B1>
struct disjunction<B1> : B1 {
//disjunction of one element is identity
};
template <class B1, class... Bn>
struct disjunction<B1, Bn...> : std::conditional<bool(B1::value), B1, disjunction<Bn...>> {
template<class B1, class... Bn>
struct disjunction<B1, Bn...> : std::conditional_t<bool(B1::value), B1, disjunction<Bn...>> {
//disjunction of mutliple elements is equal to the first element if the first element is true.
//otherwise its equal to the disjunction of the remaining elements.
};
Expand Down
35 changes: 35 additions & 0 deletions cpp/models/ode_seir/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class Model : public FlowModel<InfectionState, Populations<AgeGroup, InfectionSt
using Base = FlowModel<InfectionState, mio::Populations<AgeGroup, InfectionState>, Parameters, Flows>;

public:
Model(const Populations& pop, const ParameterSet& params)
: Base(pop, params)
{
}

Model(int num_agegroups)
: Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
{
Expand Down Expand Up @@ -201,6 +206,36 @@ class Model : public FlowModel<InfectionState, Populations<AgeGroup, InfectionSt
auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
return mio::success(static_cast<ScalarType>(result));
}

/**
* serialize this.
* @see mio::serialize
*/
template <class IOContext>
void serialize(IOContext& io) const
{
auto obj = io.create_object("Model");
obj.add_element("Parameters", parameters);
obj.add_element("Populations", populations);
}

/**
* deserialize an object of this class.
* @see mio::deserialize
*/
template <class IOContext>
static IOResult<Model> deserialize(IOContext& io)
{
auto obj = io.expect_object("Model");
auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
auto pop = obj.expect_element("Populations", Tag<Populations>{});
return apply(
io,
[](auto&& par_, auto&& pop_) {
return Model{pop_, par_};
},
par, pop);
}
};

} // namespace oseir
Expand Down
35 changes: 35 additions & 0 deletions cpp/models/ode_sir/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class Model : public CompartmentalModel<InfectionState, Populations<AgeGroup, In
using Base = CompartmentalModel<InfectionState, mio::Populations<AgeGroup, InfectionState>, Parameters>;

public:
Model(const Populations& pop, const ParameterSet& params)
: Base(pop, params)
{
}

Model(int num_agegroups)
: Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
{
Expand Down Expand Up @@ -80,6 +85,36 @@ class Model : public CompartmentalModel<InfectionState, Populations<AgeGroup, In

}
}

/**
* serialize this.
* @see mio::serialize
*/
template <class IOContext>
void serialize(IOContext& io) const
{
auto obj = io.create_object("Model");
obj.add_element("Parameters", parameters);
obj.add_element("Populations", populations);
}

/**
* deserialize an object of this class.
* @see mio::deserialize
*/
template <class IOContext>
static IOResult<Model> deserialize(IOContext& io)
{
auto obj = io.expect_object("Model");
auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
auto pop = obj.expect_element("Populations", Tag<Populations>{});
return apply(
io,
[](auto&& par_, auto&& pop_) {
return Model{pop_, par_};
},
par, pop);
}
};

} // namespace osir
Expand Down
44 changes: 22 additions & 22 deletions pycode/memilio-simulation/memilio/simulation/abm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace py = pybind11;

PYBIND11_MODULE(_simulation_abm, m)
{
pymio::iterable_enum<mio::abm::InfectionState>(m, "InfectionState", py::module_local{})
pymio::iterable_enum<mio::abm::InfectionState>(m, "InfectionState")
.value("Susceptible", mio::abm::InfectionState::Susceptible)
.value("Exposed", mio::abm::InfectionState::Exposed)
.value("InfectedNoSymptoms", mio::abm::InfectionState::InfectedNoSymptoms)
Expand Down Expand Up @@ -62,21 +62,21 @@ PYBIND11_MODULE(_simulation_abm, m)
.value("PublicTransport", mio::abm::LocationType::PublicTransport)
.value("TransportWithoutContact", mio::abm::LocationType::TransportWithoutContact);

py::class_<mio::abm::TestParameters>(m, "TestParameters")
pymio::bind_class<mio::abm::TestParameters, pymio::EnablePickling::Never>(m, "TestParameters")
.def(py::init<double, double>())
.def_readwrite("sensitivity", &mio::abm::TestParameters::sensitivity)
.def_readwrite("specificity", &mio::abm::TestParameters::specificity);

pymio::bind_CustomIndexArray<mio::UncertainValue, mio::abm::VirusVariant, mio::AgeGroup>(m, "_AgeParameterArray");
pymio::bind_Index<mio::abm::ExposureType>(m, "ExposureTypeIndex");
pymio::bind_ParameterSet<mio::abm::ParametersBase>(m, "ParametersBase");
py::class_<mio::abm::Parameters, mio::abm::ParametersBase>(m, "Parameters")
pymio::bind_ParameterSet<mio::abm::ParametersBase, pymio::EnablePickling::Never>(m, "ParametersBase");
pymio::bind_class<mio::abm::Parameters, pymio::EnablePickling::Never, mio::abm::ParametersBase>(m, "Parameters")
.def(py::init<int>())
.def("check_constraints", &mio::abm::Parameters::check_constraints);

pymio::bind_ParameterSet<mio::abm::LocalInfectionParameters>(m, "LocalInfectionParameters").def(py::init<size_t>());
pymio::bind_ParameterSet<mio::abm::LocalInfectionParameters, pymio::EnablePickling::Never>(m, "LocalInfectionParameters").def(py::init<size_t>());

py::class_<mio::abm::TimeSpan>(m, "TimeSpan")
pymio::bind_class<mio::abm::TimeSpan, pymio::EnablePickling::Never>(m, "TimeSpan")
.def(py::init<int>(), py::arg("seconds") = 0)
.def_property_readonly("seconds", &mio::abm::TimeSpan::seconds)
.def_property_readonly("hours", &mio::abm::TimeSpan::hours)
Expand All @@ -101,7 +101,7 @@ PYBIND11_MODULE(_simulation_abm, m)
m.def("hours", &mio::abm::hours);
m.def("days", py::overload_cast<int>(&mio::abm::days));

py::class_<mio::abm::TimePoint>(m, "TimePoint")
pymio::bind_class<mio::abm::TimePoint, pymio::EnablePickling::Never>(m, "TimePoint")
.def(py::init<int>(), py::arg("seconds") = 0)
.def_property_readonly("seconds", &mio::abm::TimePoint::seconds)
.def_property_readonly("days", &mio::abm::TimePoint::days)
Expand All @@ -121,7 +121,7 @@ PYBIND11_MODULE(_simulation_abm, m)
.def(py::self - mio::abm::TimeSpan{})
.def(py::self -= mio::abm::TimeSpan{});

py::class_<mio::abm::LocationId>(m, "LocationId")
pymio::bind_class<mio::abm::LocationId, pymio::EnablePickling::Never>(m, "LocationId")
.def(py::init([](uint32_t idx, mio::abm::LocationType type) {
return mio::abm::LocationId{idx, type};
}))
Expand All @@ -130,36 +130,36 @@ PYBIND11_MODULE(_simulation_abm, m)
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<mio::abm::Person>(m, "Person")
pymio::bind_class<mio::abm::Person, pymio::EnablePickling::Never>(m, "Person")
.def("set_assigned_location", py::overload_cast<mio::abm::LocationId>(&mio::abm::Person::set_assigned_location))
.def_property_readonly("location", py::overload_cast<>(&mio::abm::Person::get_location, py::const_))
.def_property_readonly("age", &mio::abm::Person::get_age)
.def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine);

py::class_<mio::abm::TestingCriteria>(m, "TestingCriteria")
pymio::bind_class<mio::abm::TestingCriteria, pymio::EnablePickling::Never>(m, "TestingCriteria")
.def(py::init<const std::vector<mio::AgeGroup>&, const std::vector<mio::abm::InfectionState>&>(),
py::arg("age_groups"), py::arg("infection_states"));

pymio::bind_class<mio::abm::GenericTest, pymio::EnablePickling::Never>(m, "GenericTest").def(py::init<>());
pymio::bind_class<mio::abm::AntigenTest, pymio::EnablePickling::Never, mio::abm::GenericTest>(m, "AntigenTest").def(py::init<>());
pymio::bind_class<mio::abm::PCRTest, pymio::EnablePickling::Never, mio::abm::GenericTest>(m, "PCRTest").def(py::init<>());

py::class_<mio::abm::GenericTest>(m, "GenericTest").def(py::init<>());
py::class_<mio::abm::AntigenTest, mio::abm::GenericTest>(m, "AntigenTest").def(py::init<>());
py::class_<mio::abm::PCRTest, mio::abm::GenericTest>(m, "PCRTest").def(py::init<>());

py::class_<mio::abm::TestingScheme>(m, "TestingScheme")
pymio::bind_class<mio::abm::TestingScheme, pymio::EnablePickling::Never>(m, "TestingScheme")
.def(py::init<const mio::abm::TestingCriteria&, mio::abm::TimeSpan, mio::abm::TimePoint, mio::abm::TimePoint,
const mio::abm::GenericTest&, double>(),
py::arg("testing_criteria"), py::arg("testing_min_time_since_last_test"), py::arg("start_date"),
py::arg("end_date"), py::arg("test_type"), py::arg("probability"))
.def_property_readonly("active", &mio::abm::TestingScheme::is_active);

py::class_<mio::abm::Vaccination>(m, "Vaccination")
pymio::bind_class<mio::abm::Vaccination, pymio::EnablePickling::Never>(m, "Vaccination")
.def(py::init<mio::abm::ExposureType, mio::abm::TimePoint>(), py::arg("exposure_type"), py::arg("time"))
.def_readwrite("exposure_type", &mio::abm::Vaccination::exposure_type)
.def_readwrite("time", &mio::abm::Vaccination::time);

py::class_<mio::abm::TestingStrategy>(m, "TestingStrategy")
pymio::bind_class<mio::abm::TestingStrategy, pymio::EnablePickling::Never>(m, "TestingStrategy")
.def(py::init<const std::unordered_map<mio::abm::LocationId, std::vector<mio::abm::TestingScheme>>&>());

py::class_<mio::abm::Location>(m, "Location")
pymio::bind_class<mio::abm::Location, pymio::EnablePickling::Never>(m, "Location")
.def_property_readonly("type", &mio::abm::Location::get_type)
.def_property_readonly("index", &mio::abm::Location::get_index)
.def_property("infection_parameters",
Expand All @@ -172,7 +172,7 @@ PYBIND11_MODULE(_simulation_abm, m)
pymio::bind_Range<decltype(std::declval<mio::abm::World>().get_locations())>(m, "_WorldLocationsRange");
pymio::bind_Range<decltype(std::declval<mio::abm::World>().get_persons())>(m, "_WorldPersonsRange");

py::class_<mio::abm::Trip>(m, "Trip")
pymio::bind_class<mio::abm::Trip, pymio::EnablePickling::Never>(m, "Trip")
.def(py::init<uint32_t, mio::abm::TimePoint, mio::abm::LocationId, mio::abm::LocationId,
std::vector<uint32_t>>(),
py::arg("person_id"), py::arg("time"), py::arg("destination"), py::arg("origin"),
Expand All @@ -183,13 +183,13 @@ PYBIND11_MODULE(_simulation_abm, m)
.def_readwrite("origin", &mio::abm::Trip::migration_origin)
.def_readwrite("cells", &mio::abm::Trip::cells);

py::class_<mio::abm::TripList>(m, "TripList")
pymio::bind_class<mio::abm::TripList, pymio::EnablePickling::Never>(m, "TripList")
.def(py::init<>())
.def("add_trip", &mio::abm::TripList::add_trip, py::arg("trip"), py::arg("weekend") = false)
.def("next_trip", &mio::abm::TripList::get_next_trip, py::arg("weekend") = false)
.def("num_trips", &mio::abm::TripList::num_trips, py::arg("weekend") = false);

py::class_<mio::abm::World>(m, "World")
pymio::bind_class<mio::abm::World, pymio::EnablePickling::Never>(m, "World")
.def(py::init<int32_t>())
.def("add_location", &mio::abm::World::add_location, py::arg("location_type"), py::arg("num_cells") = 1)
.def("add_person", &mio::abm::World::add_person, py::arg("location_id"), py::arg("age_group"),
Expand All @@ -213,7 +213,7 @@ PYBIND11_MODULE(_simulation_abm, m)
},
py::return_value_policy::reference_internal);

py::class_<mio::abm::Simulation>(m, "Simulation")
pymio::bind_class<mio::abm::Simulation, pymio::EnablePickling::Never>(m, "Simulation")
.def(py::init<mio::abm::TimePoint, size_t>())
.def("advance",
static_cast<void (mio::abm::Simulation::*)(mio::abm::TimePoint)>(&mio::abm::Simulation::advance),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define PYMIO_COMPARTMENTALMODEL_H

#include "memilio/compartments/compartmentalmodel.h"
#include "pybind_util.h"

#include "pybind11/pybind11.h"

Expand All @@ -30,22 +31,28 @@ namespace pymio
/*
* @brief bind a compartmental model for any Populations and Parameters class
*/
template <class InfectionState, class Populations, class Parameters>
template <class InfectionState, class Populations, class Parameters, EnablePickling F>
void bind_CompartmentalModel(pybind11::module_& m, std::string const& name)
{
using Model = mio::CompartmentalModel<InfectionState, Populations, Parameters>;
pybind11::class_<Model>(m, name.c_str())
bind_class<Model, F>(m, name.c_str())
.def(pybind11::init<Populations const&, Parameters const&>())
.def("apply_constraints", &Model::template apply_constraints)
.def("check_constraints", &Model::template check_constraints)
.def("get_initial_values", &Model::get_initial_values)
.def_property(
"populations", [](const Model& self) -> auto& { return self.populations; },
"populations",
[](const Model& self) -> auto& {
return self.populations;
},
[](Model& self, Populations& p) {
self.populations = p;
})
.def_property(
"parameters", [](const Model& self) -> auto& { return self.parameters; },
"parameters",
[](const Model& self) -> auto& {
return self.parameters;
},
[](Model& self, Parameters& p) {
self.parameters = p;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
#define PYMIO_FLOW_SIMULATION_H

#include "memilio/compartments/flow_simulation.h"
#include "pybind_util.h"

#include "pybind11/pybind11.h"

namespace pymio
{

template <class Model>
template <class Model, EnablePickling F>
void bind_Flow_Simulation(pybind11::module_& m)
{
pybind11::class_<mio::FlowSimulation<Model>>(m, "FlowSimulation")
bind_class<mio::FlowSimulation<Model>, F>(m, "FlowSimulation")
.def(pybind11::init<const Model&, double, double>(), pybind11::arg("model"), pybind11::arg("t0") = 0,
pybind11::arg("dt") = 0.1)
.def_property_readonly("result",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define PYMIO_SIMULATION_H

#include "memilio/compartments/simulation.h"
#include "pybind_util.h"

#include "pybind11/pybind11.h"

Expand All @@ -33,7 +34,7 @@ namespace pymio
template <class Simulation>
void bind_Simulation(pybind11::module_& m, std::string const& name)
{
pybind11::class_<Simulation>(m, name.c_str())
bind_class<Simulation, EnablePickling::IfAvailable>(m, name.c_str())
.def(pybind11::init<const typename Simulation::Model&, double, double>(), pybind11::arg("model"),
pybind11::arg("t0") = 0, pybind11::arg("dt") = 0.1)
.def_property_readonly("result", pybind11::overload_cast<>(&Simulation::get_result, pybind11::const_),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace pymio

void bind_damping_sampling(py::module_& m, std::string const& name)
{
pymio::pybind_pickle_class<mio::DampingSampling>(m, name.c_str())
bind_class<mio::DampingSampling, EnablePickling::Required>(m, name.c_str())
.def(py::init([](const mio::UncertainValue& value, int level, int type, double time,
const std::vector<size_t>& matrices, const Eigen::Ref<const Eigen::VectorXd>& groups) {
return mio::DampingSampling(value, mio::DampingLevel(level), mio::DampingType(type),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void bind_dynamicNPI_members(pybind11::module_& m, std::string const& name)
{
bind_Range<decltype(std::declval<mio::DynamicNPIs>().get_thresholds())>(m, "_ThresholdRange");
using C = mio::DynamicNPIs;
pybind11::class_<C>(m, name.c_str())
bind_class<C, EnablePickling::Required>(m, name.c_str())
.def(pybind11::init<>())
.def_property(
"interval",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void bind_Population(pybind11::module_& m, std::string const& name, mio::Tag<mio
catch (std::runtime_error& e) {
}

pybind11::class_<C, Base> c(m, name.c_str());
decltype(auto) c = bind_class<C, EnablePickling::Required, Base>(m, name.c_str());
c.def(pybind11::init([](mio::Index<Cats...> const& sizes, double val) {
return C(sizes, val);
}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "memilio/epidemiology/contact_matrix.h"
#include "memilio/epidemiology/uncertain_matrix.h"
#include "memilio/epidemiology/damping_sampling.h"
#include "pybind_util.h"

#include <pybind11/stl.h>

Expand All @@ -32,7 +33,7 @@ namespace pymio

void bind_uncertain_contact_matrix(py::module_& m, std::string const& name)
{
py::class_<mio::UncertainContactMatrix>(m, name.c_str())
bind_class<mio::UncertainContactMatrix, EnablePickling::Required>(m, name.c_str())
.def(py::init<>())
.def(py::init<const mio::ContactMatrixGroup&>())
.def_property(
Expand Down
Loading