Skip to content

Commit fe6109c

Browse files
HenrZuAgathaSchmidtjubickermknaranja
authored
402 Surrogate model implementation with multiple age groups and dampings (#562)
Co-authored-by: Agatha Schmidt <agatha.schmidt@hotmail.de> Co-authored-by: jubicker <113909589+jubicker@users.noreply.github.com> Co-authored-by: Martin J. Kühn <62713180+mknaranja@users.noreply.github.com>
1 parent c18828e commit fe6109c

File tree

9 files changed

+1239
-6
lines changed

9 files changed

+1239
-6
lines changed

pycode/memilio-surrogatemodel/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ The package currently provides the following modules:
2424
- `models`: models for different specific tasks
2525
Currently we have the following models:
2626
- `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection not stratified by age groups.
27+
- `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection stratified by age groups and including one damping.
2728

2829
Each model folder contains the following files:
2930
- `data_generation`: data generated from expert model simulation.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# SECIR model with multiple age groups and one damping
2+
3+
This model is an application of the SECIR model implemented in https://github.com/DLR-SC/memilio/tree/main/cpp/models/ode_secir/ stratified by age groups using one damping to represent a change in the contact matrice.
4+
The example is based on https://github.com/DLR-SC/memilio/tree/main/pycode/examples/simulation/secir_groups.py which uses python bindings to run the underlying C++ code.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
3+
#
4+
# Authors: Agatha Schmidt, Henrik Zunker
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
21+
"""
22+
A surrogate model for a SECIR model allowing for asymptomatic as well as symptomatic infection stratified by age groups.
23+
"""
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
3+
#
4+
# Authors: Agatha Schmidt, Henrik Zunker
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
import copy
21+
import os
22+
import pickle
23+
import random
24+
import json
25+
from datetime import date
26+
27+
import numpy as np
28+
import tensorflow as tf
29+
from progress.bar import Bar
30+
from sklearn.preprocessing import FunctionTransformer
31+
32+
from memilio.simulation import (AgeGroup, Damping, LogLevel, set_log_level)
33+
from memilio.simulation.secir import (Index_InfectionState,
34+
InfectionState, Model,
35+
interpolate_simulation_result, simulate)
36+
37+
38+
def interpolate_age_groups(data_entry):
39+
"""! Interpolates the age groups from the population data into the age groups used in the simulation.
40+
We assume that the people in the age groups are uniformly distributed.
41+
@param data_entry Data entry containing the population data.
42+
@return List containing the population in each age group used in the simulation.
43+
"""
44+
age_groups = {
45+
"A00-A04": data_entry['<3 years'] + data_entry['3-5 years'] * 2 / 3,
46+
"A05-A14": data_entry['3-5 years'] * 1 / 3 + data_entry['6-14 years'],
47+
"A15-A34": data_entry['15-17 years'] + data_entry['18-24 years'] + data_entry['25-29 years'] + data_entry['30-39 years'] * 1 / 2,
48+
"A35-A59": data_entry['30-39 years'] * 1 / 2 + data_entry['40-49 years'] + data_entry['50-64 years'] * 2 / 3,
49+
"A60-A79": data_entry['50-64 years'] * 1 / 3 + data_entry['65-74 years'] + data_entry['>74 years'] * 1 / 5,
50+
"A80+": data_entry['>74 years'] * 4 / 5
51+
}
52+
return [age_groups[key] for key in age_groups]
53+
54+
55+
def remove_confirmed_compartments(result_array):
56+
"""! Removes the confirmed compartments which are not used in the data generation.
57+
@param result_array Array containing the simulation results.
58+
@return Array containing the simulation results without the confirmed compartments.
59+
"""
60+
num_groups = int(result_array.shape[1] / 10)
61+
delete_indices = [index for i in range(
62+
num_groups) for index in (3+10*i, 5+10*i)]
63+
return np.delete(result_array, delete_indices, axis=1)
64+
65+
66+
def transform_data(data, transformer, num_runs):
67+
"""! Transforms the data by a logarithmic normalization.
68+
Reshaping is necessary, because the transformer needs an array with dimension <= 2.
69+
@param data Data to be transformed.
70+
@param transformer Transformer used for the transformation.
71+
@return Transformed data.
72+
"""
73+
data = np.asarray(data).transpose(2, 0, 1).reshape(48, -1)
74+
scaled_data = transformer.transform(data)
75+
return tf.convert_to_tensor(scaled_data.transpose().reshape(num_runs, -1, 48))
76+
77+
78+
def run_secir_groups_simulation(days, damping_day, populations):
79+
"""! Uses an ODE SECIR model allowing for asymptomatic infection with 6 different age groups. The model is not stratified by region.
80+
Virus-specific parameters are fixed and initial number of persons in the particular infection states are chosen randomly from defined ranges.
81+
@param Days Describes how many days we simulate within a single run.
82+
@param damping_day The day when damping is applied.
83+
@param populations List containing the population in each age group.
84+
@return List containing the populations in each compartment used to initialize the run.
85+
"""
86+
set_log_level(LogLevel.Off)
87+
88+
start_day = 1
89+
start_month = 1
90+
start_year = 2019
91+
dt = 0.1
92+
93+
# Define age Groups
94+
groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80+']
95+
num_groups = len(groups)
96+
97+
# Initialize Parameters
98+
model = Model(num_groups)
99+
100+
# Set parameters
101+
for i in range(num_groups):
102+
# Compartment transition duration
103+
model.parameters.IncubationTime[AgeGroup(i)] = 5.2
104+
model.parameters.TimeInfectedSymptoms[AgeGroup(i)] = 6.
105+
model.parameters.SerialInterval[AgeGroup(i)] = 4.2
106+
model.parameters.TimeInfectedSevere[AgeGroup(i)] = 12.
107+
model.parameters.TimeInfectedCritical[AgeGroup(i)] = 8.
108+
109+
# Initial number of people in each compartment with random numbers
110+
model.populations[AgeGroup(i), Index_InfectionState(
111+
InfectionState.Exposed)] = random.uniform(
112+
0.00025, 0.0005) * populations[i]
113+
model.populations[AgeGroup(i), Index_InfectionState(
114+
InfectionState.InfectedNoSymptoms)] = random.uniform(
115+
0.0001, 0.00035) * populations[i]
116+
model.populations[AgeGroup(i), Index_InfectionState(
117+
InfectionState.InfectedNoSymptomsConfirmed)] = 0
118+
model.populations[AgeGroup(i), Index_InfectionState(
119+
InfectionState.InfectedSymptoms)] = random.uniform(
120+
0.00007, 0.0001) * populations[i]
121+
model.populations[AgeGroup(i), Index_InfectionState(
122+
InfectionState.InfectedSymptomsConfirmed)] = 0
123+
model.populations[AgeGroup(i), Index_InfectionState(
124+
InfectionState.InfectedSevere)] = random.uniform(
125+
0.00003, 0.00006) * populations[i]
126+
model.populations[AgeGroup(i), Index_InfectionState(
127+
InfectionState.InfectedCritical)] = random.uniform(
128+
0.00001, 0.00002) * populations[i]
129+
model.populations[AgeGroup(i), Index_InfectionState(
130+
InfectionState.Recovered)] = random.uniform(
131+
0.002, 0.008) * populations[i]
132+
model.populations[AgeGroup(i),
133+
Index_InfectionState(InfectionState.Dead)] = 0
134+
model.populations.set_difference_from_group_total_AgeGroup(
135+
(AgeGroup(i), Index_InfectionState(InfectionState.Susceptible)), populations[i])
136+
137+
# Compartment transition propabilities
138+
model.parameters.RelativeTransmissionNoSymptoms[AgeGroup(i)] = 0.5
139+
model.parameters.TransmissionProbabilityOnContact[AgeGroup(i)] = 0.1
140+
model.parameters.RecoveredPerInfectedNoSymptoms[AgeGroup(i)] = 0.09
141+
model.parameters.RiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.25
142+
model.parameters.SeverePerInfectedSymptoms[AgeGroup(i)] = 0.2
143+
model.parameters.CriticalPerSevere[AgeGroup(i)] = 0.25
144+
model.parameters.DeathsPerCritical[AgeGroup(i)] = 0.3
145+
# twice the value of RiskOfInfectionFromSymptomatic
146+
model.parameters.MaxRiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.5
147+
148+
# StartDay is the n-th day of the year
149+
model.parameters.StartDay = (
150+
date(start_year, start_month, start_day) - date(start_year, 1, 1)).days
151+
152+
# Load baseline and minimum contact matrix and assign them to the model
153+
baseline = getBaselineMatrix()
154+
minimum = getMinimumMatrix()
155+
156+
model.parameters.ContactPatterns.cont_freq_mat[0].baseline = baseline
157+
model.parameters.ContactPatterns.cont_freq_mat[0].minimum = minimum
158+
159+
# Generate a damping matrix and assign it to the model
160+
damping = np.ones((num_groups, num_groups)
161+
) * np.float16(random.uniform(0, 0.5))
162+
163+
model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping(
164+
coeffs=(damping), t=damping_day, level=0, type=0))
165+
166+
damped_contact_matrix = model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at(
167+
damping_day+1)
168+
169+
# Apply mathematical constraints to parameters
170+
model.apply_constraints()
171+
172+
# Run Simulation
173+
result = simulate(0, days, dt, model)
174+
175+
# Interpolate simulation result on days time scale
176+
result = interpolate_simulation_result(result)
177+
178+
result_array = remove_confirmed_compartments(
179+
np.transpose(result.as_ndarray()[1:, :]))
180+
181+
# Omit first column, as the time points are not of interest here.
182+
dataset_entries = copy.deepcopy(result_array)
183+
184+
return dataset_entries.tolist(), damped_contact_matrix
185+
186+
187+
def generate_data(
188+
num_runs, path_out, path_population, input_width, label_width,
189+
normalize=True, save_data=True):
190+
"""! Generate data sets of num_runs many equation-based model simulations and transforms the computed results by a log(1+x) transformation.
191+
Divides the results in input and label data sets and returns them as a dictionary of two TensorFlow Stacks.
192+
In general, we have 8 different compartments and 6 age groups. If we choose,
193+
input_width = 5 and label_width = 20, the dataset has
194+
- input with dimension 5 x 8 x 6
195+
- labels with dimension 20 x 8 x 6
196+
@param num_runs Number of times, the function run_secir_groups_simulation is called.
197+
@param path_out Path, where the dataset is saved to.
198+
@param path_population Path, where we try to read the population data.
199+
@param input_width Int value that defines the number of time series used for the input.
200+
@param label_width Int value that defines the size of the labels.
201+
@param normalize [Default: true] Option to transform dataset by logarithmic normalization.
202+
@param save_data [Default: true] Option to save the dataset.
203+
@return Data dictionary of input and label data sets.
204+
"""
205+
data = {
206+
"inputs": [],
207+
"labels": [],
208+
"contact_matrix": [],
209+
"damping_day": []
210+
}
211+
212+
# The number of days is the same as the sum of input and label width.
213+
# Since the first day of the input is day 0, we still need to subtract 1.
214+
days = input_width + label_width - 1
215+
216+
# Load population data
217+
population = get_population(path_population)
218+
219+
# show progess in terminal for longer runs
220+
# Due to the random structure, there's currently no need to shuffle the data
221+
bar = Bar('Number of Runs done', max=num_runs)
222+
for _ in range(0, num_runs):
223+
224+
# Generate a random damping day
225+
damping_day = random.randrange(
226+
input_width, input_width+label_width)
227+
228+
data_run, damped_contact_matrix = run_secir_groups_simulation(
229+
days, damping_day, population[random.randint(0, len(population) - 1)])
230+
data['inputs'].append(data_run[:input_width])
231+
data['labels'].append(data_run[input_width:])
232+
data['contact_matrix'].append(np.array(damped_contact_matrix))
233+
data['damping_day'].append(damping_day)
234+
bar.next()
235+
bar.finish()
236+
237+
if normalize:
238+
# logarithmic normalization
239+
transformer = FunctionTransformer(np.log1p, validate=True)
240+
241+
# transform inputs and labels
242+
data['inputs'] = transform_data(data['inputs'], transformer, num_runs)
243+
data['labels'] = transform_data(data['labels'], transformer, num_runs)
244+
else:
245+
data['inputs'] = tf.convert_to_tensor(data['inputs'])
246+
data['labels'] = tf.convert_to_tensor(data['labels'])
247+
248+
if save_data:
249+
# check if data directory exists. If necessary, create it.
250+
if not os.path.isdir(path_out):
251+
os.mkdir(path_out)
252+
253+
# save dict to json file
254+
with open(os.path.join(path_out, 'data_secir_groups.pickle'), 'wb') as f:
255+
pickle.dump(data, f)
256+
return data
257+
258+
259+
def getBaselineMatrix():
260+
"""! loads the baselinematrix
261+
"""
262+
263+
baseline_contact_matrix0 = os.path.join(
264+
"./data/contacts/baseline_home.txt")
265+
baseline_contact_matrix1 = os.path.join(
266+
"./data/contacts/baseline_school_pf_eig.txt")
267+
baseline_contact_matrix2 = os.path.join(
268+
"./data/contacts/baseline_work.txt")
269+
baseline_contact_matrix3 = os.path.join(
270+
"./data/contacts/baseline_other.txt")
271+
272+
baseline = np.loadtxt(baseline_contact_matrix0) \
273+
+ np.loadtxt(baseline_contact_matrix1) + \
274+
np.loadtxt(baseline_contact_matrix2) + \
275+
np.loadtxt(baseline_contact_matrix3)
276+
277+
return baseline
278+
279+
280+
def getMinimumMatrix():
281+
"""! loads the minimum matrix
282+
"""
283+
284+
minimum_contact_matrix0 = os.path.join(
285+
"./data/contacts/minimum_home.txt")
286+
minimum_contact_matrix1 = os.path.join(
287+
"./data/contacts/minimum_school_pf_eig.txt")
288+
minimum_contact_matrix2 = os.path.join(
289+
"./data/contacts/minimum_work.txt")
290+
minimum_contact_matrix3 = os.path.join(
291+
"./data/contacts/minimum_other.txt")
292+
293+
minimum = np.loadtxt(minimum_contact_matrix0) \
294+
+ np.loadtxt(minimum_contact_matrix1) + \
295+
np.loadtxt(minimum_contact_matrix2) + \
296+
np.loadtxt(minimum_contact_matrix3)
297+
298+
return minimum
299+
300+
301+
def get_population(path):
302+
"""! read population data in list from dataset
303+
@param path Path to the dataset containing the population data
304+
"""
305+
306+
with open(path) as f:
307+
data = json.load(f)
308+
population = []
309+
for data_entry in data:
310+
population.append(interpolate_age_groups(data_entry))
311+
return population
312+
313+
314+
if __name__ == "__main__":
315+
# Store data relative to current file two levels higher.
316+
path = os.path.dirname(os.path.realpath(__file__))
317+
path_output = os.path.join(os.path.dirname(os.path.realpath(
318+
os.path.dirname(os.path.realpath(path)))), 'data')
319+
320+
path_population = os.path.abspath(
321+
r"data//pydata//Germany//county_population.json")
322+
323+
input_width = 5
324+
label_width = 30
325+
num_runs = 10000
326+
data = generate_data(num_runs, path_output, path_population, input_width,
327+
label_width)

0 commit comments

Comments
 (0)