Skip to content

Commit cbfb05f

Browse files
author
Frankie Patten-Elliott
committed
Code for generating 24_10_24 protocol
1 parent a3b1d88 commit cbfb05f

10 files changed

+550
-46
lines changed

methods/classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, model, protocol, times, win, conc, param_dict):
5959
self._conc = conc
6060
# TODO currently hardcoded to get number of pulses
6161
if model.split("-")[0] == 'sis':
62-
self.n_pulses = 10
62+
self.n_pulses = 20
6363
elif times[-1] != 14999.5:
6464
self.n_pulses = int(np.floor(250000/times[-1]))
6565
else:

methods/funcs.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ def generate_data(herg_model, drug_vals, prot, sd, max_time, bounds, m_sel, conc
1818
# TODO currently hardcoded to get number of sweeps
1919
if max_time != 15e3 and herg_model != '2024_Joey_sis_25C':
2020
swps = int(np.floor(250000/max_time))
21-
else:
21+
elif max_time != 15350:
2222
swps = sweeps
23+
else:
24+
swps = 20
2325

2426
# define protocol
2527
protocol = myokit.load_protocol(prot)
@@ -161,11 +163,12 @@ def model_outputs(model_pars, herg_model, prot, times, concs, alt_protocol = Non
161163
# get herg parameters
162164
if herg_model == '2019_37C':
163165
herg_vals = [2.07e-3, 7.17e-2, 3.44e-5, 6.18e-2, 4.18e-1, 2.58e-2, 4.75e-2, 2.51e-2, 33.3]
164-
elif herg_model == 'kemp':
166+
elif herg_model == 'kemp' or herg_model == '2024_Joey_sis_25C':
165167
herg_vals = []
166168
model_out = {}
167169
# get no. of sweeps such that the total length is approximately 250s
168-
swps = int(np.floor(250000/(times[-1]+alt_times[-1])))
170+
#swps = int(np.floor(250000/(times[-1]+alt_times[-1])))
171+
swps = 5
169172
win = (times >= wins[0]) & (times < wins[1])
170173
if alt_protocol is not None:
171174
win_alt = (alt_times >= wins_alt[0]) & (alt_times < wins_alt[1])
@@ -181,39 +184,44 @@ def model_outputs(model_pars, herg_model, prot, times, concs, alt_protocol = Non
181184
analytical=True)
182185
# set hERG model parameters
183186
model.set_fix_parameters({p:v for p, v in zip(herg_pars, herg_vals)})
184-
else:
187+
elif herg_model == 'kemp':
185188
model = Model(f'kemp-m{m}',
186189
prot,
187190
parameters=['binding'],
188191
analytical=True)
192+
elif herg_model == '2024_Joey_sis_25C':
193+
model = Model(f'sis-m{m}',
194+
prot,
195+
parameters=['binding'],
196+
analytical=True)
189197
# fix kt if necessary
190198
if m in ['12', '13']:
191199
model.fix_kt()
192200
try:
193201
# loop to simulate and append proportion open model output for no. of sweeps
194202
model_milnes = []
195-
control = []
196-
model.set_dose(0)
197-
before = model.simulate(binding_params, times)[win]
198-
before_alt = model.simulate(binding_params, alt_times)[win_alt]
203+
#control = []
204+
#model.set_dose(0)
205+
#before = model.simulate(binding_params, times)[win]
206+
#before_alt = model.simulate(binding_params, alt_times)[win_alt]
199207
model.set_dose(conc)
200208
after = model.simulate(binding_params, times)[win]
201-
model_milnes = np.append(model_milnes, after/before)
202-
control = np.append(control, before)
209+
model_milnes = np.append(model_milnes, after)
210+
#control = np.append(control, before)
203211
for i in range(swps*2-1):
204212
if (alt_protocol is not None) & ((i % 2) == 0):
205213
model.change_protocol(alt_protocol)
206214
after = model.simulate(binding_params, alt_times, reset=False)[win_alt]
207-
model_milnes = np.append(model_milnes, after/before_alt)
208-
control = np.append(control, before_alt)
215+
model_milnes = np.append(model_milnes, after)
216+
#control = np.append(control, before_alt)
209217
model.change_protocol(prot)
210218
else:
211219
after = model.simulate(binding_params, times, reset=False)[win]
212-
model_milnes = np.append(model_milnes, after/before)
213-
control = np.append(control, before)
220+
model_milnes = np.append(model_milnes, after)
221+
#control = np.append(control, before)
214222
model_out[m][conc] = model_milnes
215-
save_control = control
223+
#save_control = control
216224
except:
217225
model_out[m][conc] = np.ones(times.shape) * float('inf')
218-
save_control = np.ones(times.shape) * float('inf')
219-
return model_out, save_control, swps
226+
#save_control = np.ones(times.shape) * float('inf')
227+
return model_out, swps
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[[protocol]]
2+
# Milnes protocol
3+
# Level Start Length Period Multiplier
4+
-80 0.0 1350.0 25350 10
5+
0 1350.0 10000.0 25350 10
6+
-80 11350.0 14000.0 25350 10

src/fit_models.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def parse_list_of_lists(s):
3333
parser.add_argument('-t', type=float, default=15e3, help='Max time')
3434
parser.add_argument('-b', type=parse_list_of_lists, default="[[1e3, 11e3]]", help='Protocol window(s) of interest')
3535
parser.add_argument('-c', type = str, help='Drug compound string')
36+
parser.add_argument('-d', action='store_true', help='Enable dual fitting of Milnes and optimal protocol data')
3637
args = parser.parse_args()
3738

3839
def get_pars(model_num):
@@ -43,12 +44,21 @@ def get_pars(model_num):
4344
if herg_model != 'kemp' and herg_model != '2024_Joey_sis_25C':
4445
model = classes.ConcatMilnesModel(f'm{model_num}', protocol, times,
4546
win, conc, param_dict)
47+
if args.d:
48+
model_m = classes.ConcatMilnesModel(f'm{model_num}', 'protocols/Milnes_Phil_Trans.mmt', times_m,
49+
win_m, conc, param_dict)
4650
elif herg_model == 'kemp':
4751
model = classes.ConcatMilnesModel(f'kemp-m{model_num}', protocol, times,
4852
win, conc, param_dict)
53+
if args.d:
54+
model_m = classes.ConcatMilnesModel(f'kemp-m{model_num}', 'protocols/Milnes_Phil_Trans.mmt', times_m,
55+
win_m, conc, param_dict)
4956
elif herg_model == '2024_Joey_sis_25C':
5057
model = classes.ConcatMilnesModel(f'sis-m{model_num}', protocol, times,
5158
win, conc, param_dict)
59+
if args.d:
60+
model_m = classes.ConcatMilnesModel(f'sis-m{model_num}', 'protocols/Milnes_Phil_Trans.mmt', times_m,
61+
win_m, conc, param_dict)
5262
# Load data
5363
u = np.loadtxt(
5464
f'{outdir}/fb_synthetic_conc_{conc}.csv',
@@ -57,10 +67,26 @@ def get_pars(model_num):
5767
)
5868
concat_time = u[:, 0]
5969
concat_milnes = u[:, 1]
70+
if args.d:
71+
# Load data
72+
u_m = np.loadtxt(
73+
f'{outdir.rsplit("/",1)[0]}/fb_synthetic_conc_{conc}.csv',
74+
delimiter=',',
75+
skiprows=1
76+
)
77+
concat_time_m = u_m[:, 0]
78+
concat_milnes_m = u_m[:, 1]
6079
# Create single output problem
6180
problem = pints.SingleOutputProblem(model, concat_time, concat_milnes)
6281
likelihoods.append(classes.NormalRatioLogLikelihood(problem, mu_y))
63-
f = pints.SumOfIndependentLogPDFs(likelihoods)
82+
if args.d:
83+
problem_m = pints.SingleOutputProblem(model_m, concat_time_m, concat_milnes_m)
84+
likelihoods.append(classes.NormalRatioLogLikelihood(problem_m, mu_y_m))
85+
86+
if len(likelihoods) > 1:
87+
f = pints.SumOfIndependentLogPDFs(likelihoods)
88+
else:
89+
f = likelihoods[0]
6490
bounds = boundaries.Boundaries(model_num, fix_hill=False, likelihood=True)
6591

6692
# Fix random seed for reproducibility
@@ -114,8 +140,12 @@ def get_pars(model_num):
114140
concs = [30, 100, 300]
115141
elif args.c == 'quinidine':
116142
concs = [150, 500, 1500]
143+
elif args.c == 'terfenadine':
144+
concs = [30, 100, 300]
117145
elif args.c == 'verapamil':
118146
concs = [100, 300, 1000]
147+
elif args.c == 'DMSO':
148+
concs = [1]
119149
if herg_model != 'kemp' and herg_model != '2024_Joey_sis_25C':
120150
with open(f'methods/models/params/{p_path}', newline='') as csvfile:
121151
p_reader = csv.reader(csvfile)
@@ -133,9 +163,21 @@ def get_pars(model_num):
133163
win = np.zeros_like(conditions[0], dtype=bool)
134164
for condition in conditions:
135165
win |= condition
166+
if args.d:
167+
times_m = np.arange(0, 15e3, steps)
168+
conditions_m = []
169+
for b in [[1e3, 11e3]]:
170+
conditions_m.append(((times_m >= b[0]) & (times_m < b[-1])))
171+
win_m = np.zeros_like(conditions_m[0], dtype=bool)
172+
for condition in conditions_m:
173+
win_m |= condition
136174
# read fitted splines
137175
dfy = pd.read_csv(f"{outdir}/synth_Y_fit.csv")
138176
mu_y = np.array(dfy['0'])
177+
if args.d:
178+
# read fitted splines
179+
dfy_m = pd.read_csv(f"{outdir.rsplit('/',1)[0]}/synth_Y_fit.csv")
180+
mu_y_m = np.array(dfy_m['0'])
139181
# fit model
140182
pars, sc = get_pars(model_num)
141183
print(f'{model_num}: {pars}, {sc}')

src/fit_spline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def main(input, output, lambda_, t_steps, data, max_t):
3030
t_swp = t_steps[j]
3131
t_total = 0
3232
swps = int(np.floor(250000/max_t))*6
33+
elif data == 'milnes real':
34+
t_swp = 10000
35+
swps = 20
3336
else:
3437
lens = [3340, 3330, 3340, 3330, 3330]
3538
t_total = 0
@@ -39,7 +42,7 @@ def main(input, output, lambda_, t_steps, data, max_t):
3942

4043
# Fit splines
4144
for i in range(swps):
42-
if data == 'milnes':
45+
if data == 'milnes' or data == 'milnes real':
4346
df_rep = df_all[(df_all['t'] >= t_swp * i) & (df_all['t'] < t_swp * (i + 1))][['t', 'x']]
4447
knots = np.arange(t_swp * i, t_swp * (i + 1) + t_swp/2, t_swp/2)
4548
elif data == 'opt':

src/optimise_protocol.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
parser.add_argument('-e', type=str, default='joey_sis', help='hERG model parameters')
2323
parser.add_argument('-o', type=str, required=True, help='output folder for synthetic data')
2424
parser.add_argument('-c', type=str, help='drug compound')
25+
parser.add_argument('-r', action='store_true', help='set true for real data')
2526
args = parser.parse_args()
2627

2728
def get_opt_prot(model_pars, herg, v_steps, t_steps, p0, CMAES_pop = 10, max_iter = 5, alt_protocol = None):
@@ -49,32 +50,32 @@ def __call__(self, p):
4950
prot = funcs.create_protocol(self.v_st, self.t_st)
5051
# generate model output under the new protocol
5152
if alt_protocol is not None:
52-
model_out, cont, swps = funcs.model_outputs(model_pars, herg, prot, times = np.arange(0, np.sum(self.t_st), 10), concs = concs,
53+
model_out, swps = funcs.model_outputs(model_pars, herg, prot, times = np.arange(0, np.sum(self.t_st), 10), concs = concs,
5354
alt_protocol = prot_alt, alt_times = np.arange(0, np.sum(self.t_st_alt), 10), wins = [1e3, np.sum(self.t_st[1:4])],
5455
wins_alt = [1e3, np.sum(self.t_st_alt[1:4])])
5556
else:
56-
model_out, cont, swps = funcs.model_outputs(model_pars, herg, prot, times = np.arange(0, np.sum(self.t_st), 10), concs = concs, wins = [1e3, np.sum(self.t_st[1:4])])
57+
model_out, swps = funcs.model_outputs(model_pars, herg, prot, times = np.arange(0, np.sum(self.t_st), 10), concs = concs, wins = [1e3, np.sum(self.t_st[1:4])])
5758
### loop through models and concentrations to get traces
5859
all_traces = []
5960
for m in model_out:
6061
model_trace = []
6162
for c in concs:
6263
model_trace = np.append(model_trace, model_out[m][c])
6364
all_traces.append(model_trace)
64-
conts = []
65-
for c in concs:
66-
conts = np.append(conts, cont)
67-
#ssq = [sum((m-n)**2) for i,m in enumerate(all_traces) for j,n in enumerate(all_traces) if i < j]
65+
#conts = []
66+
#for c in concs:
67+
# conts = np.append(conts, cont)
68+
ssq = [sum((m-n)**2) for i,m in enumerate(all_traces) for j,n in enumerate(all_traces) if i < j]
6869
# calculate expected likelihood ratio (assuming normal ratio data)
69-
lhoods = []
70-
for i,m in enumerate(all_traces):
71-
for j,n in enumerate(all_traces):
72-
if i < j:
73-
top = np.exp(-conts**2*((m**2+1)/(2*sd**2))) + np.sqrt(np.pi/2)*(conts/sd)*np.sqrt(1+m**2)*special.erf((conts/(np.sqrt(2)*sd))*np.sqrt(1+m**2))
74-
bottom = np.exp(-conts**2*((n**2+1)/(2*sd**2))) + np.sqrt(np.pi/2)*(conts/sd)*((1+m*n)/np.sqrt(1+m**2))*special.erf((conts*(1+m*n))/(np.sqrt(2)*sd*np.sqrt(1+m**2)))*np.exp(-(conts**2*(m-n)**2)/(2*sd**2*(1+m**2)))
75-
lhoods.append(-2*np.sum(np.log(top/bottom)))
76-
#out = -np.median(ssq)
77-
out = np.median(lhoods)
70+
#lhoods = []
71+
#for i,m in enumerate(all_traces):
72+
# for j,n in enumerate(all_traces):
73+
# if i < j:
74+
# top = np.exp(-conts**2*((m**2+1)/(2*sd**2))) + np.sqrt(np.pi/2)*(conts/sd)*np.sqrt(1+m**2)*special.erf((conts/(np.sqrt(2)*sd))*np.sqrt(1+m**2))
75+
# bottom = np.exp(-conts**2*((n**2+1)/(2*sd**2))) + np.sqrt(np.pi/2)*(conts/sd)*((1+m*n)/np.sqrt(1+m**2))*special.erf((conts*(1+m*n))/(np.sqrt(2)*sd*np.sqrt(1+m**2)))*np.exp(-(conts**2*(m-n)**2)/(2*sd**2*(1+m**2)))
76+
# lhoods.append(-2*np.sum(np.log(top/bottom)))
77+
out = -np.median(ssq)
78+
#out = np.median(lhoods)
7879
return out
7980

8081
def n_parameters(self):
@@ -99,14 +100,14 @@ def n_parameters(self):
99100
transformation = pints.RectangularBoundariesTransformation(boundaries)
100101

101102
# define initial standard deviation around voltages and times during optimisation
102-
step_v = [5]*v_steps.count(np.nan)
103-
step_t = [50]*t_steps.count(np.nan)
103+
step_v = [10]*v_steps.count(np.nan)
104+
step_t = [500]*t_steps.count(np.nan)
104105

105106
if alt_protocol is not None:
106107
design = DrugBind(p0+alt_protocol)
107108
else:
108109
design = DrugBind(p0)
109-
110+
110111
# Fix random seed for reproducibility
111112
np.random.seed(101)
112113

@@ -117,12 +118,14 @@ def n_parameters(self):
117118
q0 = boundaries.sample()[0]
118119
try:
119120
temp = design(q0)
121+
print(f'temp score: {temp}')
120122
if temp < score:
121123
score = temp
122124
q0save = q0
123125
except Exception as e:
124126
print(f"An error occurred: {e}")
125127

128+
print(f'init_score: {score}')
126129
# Define optimiser
127130
optimiser = pints.CMAES
128131
if alt_protocol is not None:
@@ -211,6 +214,14 @@ def main(model_nums, max_time, bounds, herg, output_folder):
211214
writer.writerow(t_steps)
212215

213216
if __name__ == "__main__":
214-
concs = parameters.drug_concs[args.c]
217+
if args.r:
218+
if args.c == 'verapamil':
219+
concs=[100,300,1000]
220+
elif args.c == 'bepridil':
221+
concs=[30,100,300]
222+
elif args.c == 'terfenadine':
223+
concs=[100,300,1000]
224+
else:
225+
concs = parameters.drug_concs[args.c]
215226
m_list = ast.literal_eval(args.m)
216227
main(m_list, args.t, args.b, args.e, args.o)

0 commit comments

Comments
 (0)