Skip to content

Commit

Permalink
extend SteadyState to do multiple simultaneous fits
Browse files Browse the repository at this point in the history
- imagefile can be a dict of geometry_image:concentration_image filename pairs
- each geometry is simulated and the objective function calculated from the supplied concentration image
- the total objective function is the sum of these objective functions

also

- any concentration_image pixels outside of the model geometry are set to zero
- concentration images are now normalised such that largest pixel value is 1
  • Loading branch information
lkeegan committed Feb 26, 2021
1 parent 7450700 commit 6147a08
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 63 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="sme_contrib",
version="0.0.11",
version="0.0.12",
author="Liam Keegan",
author_email="liam@keegan.ch",
description="Useful modules for use with sme (Spatial Model Editor)",
Expand Down
152 changes: 90 additions & 62 deletions sme_contrib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,25 @@ def _ss_plot_line(x, y, title, ax=None):
return ax


def _img_as_normalized_nparray(imagefile, mask):
# todo: check image dimensions match mask dimensions
img = Image.open(imagefile)
if len(img.getbands()) > 1:
# convert RGB or RGBA image to 8-bit grayscale
img = img.convert("L")
a = np.asarray(img, dtype=np.float64)
a[~mask] = 0.0
a_max = np.amax(a)
return a / a_max


def _get_geometry_mask(model):
mask = np.array(model.compartments[0].geometry_mask)
for compartment in model.compartments:
mask = mask | np.array(compartment.geometry_mask)
return mask


class SteadyState:
"""Steady state parameter fitting
Expand All @@ -232,6 +251,9 @@ class SteadyState:
Args:
modelfile(str): The sbml file containing the model
imagefile(str): The image file containing the target concentration
Optionally this can instead be a dict of geometryimagefilename:targetconcentrationfilename,
in which case the model is simulataneously fitted to the target conentration image
steady state for each geometry image.
species(List of str): The species to compare to the target concentration
function_to_apply_params: A function that sets the parameters in the model.
This should be a function with signature ``f(model, params)``, and
Expand Down Expand Up @@ -269,37 +291,29 @@ def __init__(
):
self.filename = modelfile
self.species = species
self.set_target_image(imagefile)
self.geom_img_files = []
self.target_concs = []
if isinstance(imagefile, str):
# single target image, use existing model geometry
self.geom_img_files.append("")
m = sme.open_sbml_file(self.filename)
mask = _get_geometry_mask(m)
self.target_concs.append(_img_as_normalized_nparray(imagefile, mask))
else:
# multiple geometry images, each with target concentration image
for geom_img_file, imgfile in imagefile.items():
self.geom_img_files.append(geom_img_file)
m = sme.open_sbml_file(self.filename)
m.import_geometry_from_image(geom_img_file)
mask = _get_geometry_mask(m)
self.target_concs.append(_img_as_normalized_nparray(imgfile, mask))
self.simulation_time = simulation_time
self.steady_state_time = steady_state_time
self.apply_params = function_to_apply_params
self.lower_bounds = lower_bounds
self.upper_bounds = upper_bounds
self.timeout_seconds = timeout_seconds

def set_target_image(self, imagefile):
"""Set a new target concentration image
Most image formats are supported. If the image has multiple
channels, e.g. RGB or RGBA, it is first converted to grayscale.
Note:
This doesn't (yet) check that the image has the same dimensions
as the model geometry, nor does it mask the pixels that lie
outside of the model compartments to zero
Args:
imagefile(str): The image filename
"""
# todo: mask image to only pixels in model
# todo: check image dimensions match model geometry image
img = Image.open(imagefile)
if len(img.getbands()) > 1:
# convert RGB or RGBA image to 8-bit grayscale
img = img.convert("L")
self.target_conc = np.asarray(img, dtype=np.float64)
self.target_conc_max = np.amax(self.target_conc)

def _get_conc(self, result):
c = np.array(result.species_concentration[self.species[0]])
for i in range(1, len(self.species)):
Expand All @@ -315,33 +329,40 @@ def _get_dcdt(self, result):
def _rescale(self, result):
c = self._get_conc(result)
dcdt = self._get_dcdt(result)
scale_factor = self.target_conc_max / np.amax(c)
scale_factor = 1.0 / np.amax(c)
return (scale_factor * c, scale_factor * dcdt)

def _obj_func(self, params, verbose=False):
m = sme.open_sbml_file(self.filename)
self.apply_params(m, params)
results = m.simulate(
simulation_time=self.simulation_time,
image_interval=self.simulation_time,
timeout_seconds=self.timeout_seconds,
throw_on_timeout=False,
)
if len(results) == 1:
# simulation fail or timeout
print(
f"simulation timeout with timeout {self.timeout_seconds}s, params: {params}"
obj_sum = 0
model_concs = []
conc_norm = 0
dcdt_norm = 0
for geom_img_file, target_conc in zip(self.geom_img_files, self.target_concs):
if geom_img_file:
m.import_geometry_from_image(geom_img_file)
results = m.simulate(
simulation_time=self.simulation_time,
image_interval=self.simulation_time,
timeout_seconds=self.timeout_seconds,
throw_on_timeout=False,
)
conc_norm = abs_diff(0, self.target_conc)
if verbose:
return (conc_norm, conc_norm, 0)
return conc_norm
c, dcdt = self._rescale(results[-1])
conc_norm = abs_diff(c, self.target_conc)
dcdt_norm = abs_diff(self.steady_state_time * dcdt, 0)
if len(results) == 1:
# simulation fail or timeout: no result
print(
f"simulation timeout with timeout {self.timeout_seconds}s, params: {params}"
)
conc_norm = conc_norm + abs_diff(0, target_conc)
dcdt_norm = dcdt_norm + abs_diff(0, target_conc)
c, dcdt = self._rescale(results[-1])
conc_norm = conc_norm + abs_diff(c, target_conc)
dcdt_norm = dcdt_norm + abs_diff(self.steady_state_time * dcdt, 0)
model_concs.append(c)
obj_sum = obj_sum + conc_norm + dcdt_norm
if verbose:
return (conc_norm, dcdt_norm, c)
return conc_norm + dcdt_norm
return (conc_norm, dcdt_norm, model_concs)
return obj_sum

def find(self, particles=20, iterations=20, processes=None):
"""Find parameters that result in a steady state concentration close to the target image
Expand Down Expand Up @@ -376,7 +397,7 @@ def find(self, particles=20, iterations=20, processes=None):
)
self.cost_history = optimizer.cost_history
self.cost_history_pbest = optimizer.mean_pbest_history
self.conc_norm, self.dcdt_norm, self.model_conc = self._obj_func(
self.conc_norm, self.dcdt_norm, self.model_concs = self._obj_func(
params, verbose=True
)
self.params = params
Expand All @@ -385,7 +406,7 @@ def find(self, particles=20, iterations=20, processes=None):
def hessian(self, rel_eps=0.1, processes=None):
return hessian(self._obj_func, self.params, rel_eps, processes)

def plot_target_concentration(self, ax=None, cmap=None):
def plot_target_concentration(self, index=0, ax=None, cmap=None):
"""Plot the target concentration as a 2d heat map
Args:
Expand All @@ -395,9 +416,11 @@ def plot_target_concentration(self, ax=None, cmap=None):
Returns:
matplotlib.axes._subplots.AxesSubplot: The axes the plot was drawn on
"""
return _ss_plot_image(self.target_conc, "Target Concentration", ax, cmap)
return _ss_plot_image(
self.target_concs[index], "Target Concentration", ax, cmap
)

def plot_model_concentration(self, ax=None, cmap=None):
def plot_model_concentration(self, index=0, ax=None, cmap=None):
"""Plot the model concentration as a 2d heat map
The model concentration is normalized such that the maximum pixel intensity
Expand All @@ -410,7 +433,7 @@ def plot_model_concentration(self, ax=None, cmap=None):
Returns:
matplotlib.axes._subplots.AxesSubplot: The axes the plot was drawn on
"""
return _ss_plot_image(self.model_conc, "Model Concentration", ax, cmap)
return _ss_plot_image(self.model_concs[index], "Model Concentration", ax, cmap)

def plot_cost_history(self, ax=None):
"""Plot the cost history
Expand Down Expand Up @@ -460,15 +483,19 @@ def plot_timeseries(self, simulation_time, image_interval_time, ax=None):
"""
m = sme.open_sbml_file(self.filename)
self.apply_params(m, self.params)
results = m.simulate(
simulation_time=simulation_time, image_interval=image_interval_time
)
concs = []
times = []
for result in results:
concs.append(np.sum(self._get_conc(result)))
times.append(result.time_point)
return _ss_plot_line(times, concs, "Concentration time series", ax)
for geom_img_file in self.geom_img_files:
if geom_img_file:
m.import_geometry_from_image(geom_img_file)
results = m.simulate(
simulation_time=simulation_time, image_interval=image_interval_time
)
concs = []
times = []
for result in results:
concs.append(np.sum(self._get_conc(result)))
times.append(result.time_point)
ax = _ss_plot_line(times, concs, "Concentration time series", ax)
return ax

def plot_all(self, cmap=None):
"""Generate all plots
Expand All @@ -485,10 +512,11 @@ def plot_all(self, cmap=None):
self.plot_timeseries(self.simulation_time, self.simulation_time / 100.0, ax3)
plt.show()

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(18, 12))
self.plot_target_concentration(ax1, cmap)
self.plot_model_concentration(ax2, cmap)
plt.show()
for index, geom_img_file in enumerate(self.geom_img_files):
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(18, 12))
self.plot_target_concentration(index, ax1, cmap)
self.plot_model_concentration(index, ax2, cmap)
plt.show()

def get_model(self):
"""Returns the model with best parameters applied
Expand Down

0 comments on commit 6147a08

Please sign in to comment.