Skip to content

Commit

Permalink
Minor updates (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon authored Sep 25, 2024
1 parent 6b29fab commit f86b895
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 132 deletions.
29 changes: 19 additions & 10 deletions emodel_generalisation/adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from functools import partial
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -44,8 +43,7 @@
from emodel_generalisation.model.modifiers import remove_soma
from emodel_generalisation.utils import FEATURE_FILTER
from emodel_generalisation.utils import get_scores

matplotlib.use("Agg")
from emodel_generalisation.utils import isolate


def get_scales(scales_params, with_unity=False):
Expand Down Expand Up @@ -85,7 +83,6 @@ def build_resistance_models(

# it seems dask does not quite work on this (to investigate, but multiprocessing is fast enough)
df = func(df, access_point, parallel_factory="multiprocessing")

models = {}
for emodel in emodels:
_df = df[df.emodel == emodel]
Expand All @@ -94,13 +91,16 @@ def build_resistance_models(
if len(rin[rin < 0]) == 0:
try:
coeffs, extra = Polynomial.fit(np.log10(scaler), np.log10(rin), 3, full=True)
if extra[0] < rcond_min:
models[emodel] = {
"resistance": {"polyfit_params": coeffs.convert().coef.tolist()},
"shape": exemplar_data[key],
}
if extra[0] > rcond_min:
print(f"resistance fit for {key} of {emodel} is not so good")
models[emodel] = {
"resistance": {"polyfit_params": coeffs.convert().coef.tolist()},
"shape": exemplar_data[key],
}
except (np.linalg.LinAlgError, TypeError):
print(f"fail to fit emodel {emodel}")
else:
print(f"resistance fit for {key} of {emodel} has negative rin")
return df[df.emodel.isin(models)], models


Expand Down Expand Up @@ -145,7 +145,7 @@ def _adapt_combo(combo, models, rhos, key="soma", min_scale=0.01, max_scale=10):
return {f"{key}_scaler": np.clip(scale, min_scale, max_scale)}


def _adapt_single_soma_ais(
def __adapt_single_soma_ais(
combo,
access_point=None,
models=None,
Expand Down Expand Up @@ -199,6 +199,15 @@ def _adapt_single_soma_ais(
return {k: combo[k] for k in ["soma_scaler", "ais_scaler"]}


def _adapt_single_soma_ais(*args, **kwargs):
timeout = kwargs.pop("timeout", 30 * 60)
res = isolate(__adapt_single_soma_ais, timeout=timeout)(*args, **kwargs)
if res is None:
print("timeout", args, kwargs)
return {k: 1.0 for k in ["soma_scaler", "ais_scaler"]}
return res


def make_evaluation_df(combos_df, emodels, exemplar_data, rhos=None):
"""Make a df to be evaluated."""
df = pd.DataFrame()
Expand Down
72 changes: 19 additions & 53 deletions emodel_generalisation/bluecellulab_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,19 @@
import logging
import os
from copy import copy
from multiprocessing.context import TimeoutError # pylint: disable=redefined-builtin
from pathlib import Path

import bluecellulab
import efel
from bluecellulab.simulation.neuron_globals import NeuronGlobals
from bluepyparallel.evaluator import evaluate
from bluepyparallel.parallel import NestedPool

from emodel_generalisation.utils import isolate

logger = logging.getLogger(__name__)
AXON_LOC = "self.axonal[1](0.5)._ref_v"


def isolate(func, timeout=None):
"""Isolate a generic function for independent NEURON instances.
It must be used in conjunction with NestedPool.
Example:
.. code-block:: python
def _to_be_isolated(morphology_path, point):
cell = nrnhines.get_NRN_cell(morphology_path)
return nrnhines.point_to_section_end(cell.icell.all, point)
def _isolated(morph_data):
return nrnhines.isolate(_to_be_isolated)(*morph_data)
with nrnhines.NestedPool(processes=n_workers) as pool:
result = pool.imap_unordered(_isolated, data)
Args:
func (function): function to isolate
Returns:
the isolated function
Note: it does not work as decorator.
"""

def func_isolated(*args, **kwargs):
with NestedPool(1, maxtasksperchild=1) as pool:
res = pool.apply_async(func, args, kwargs)
try:
out = res.get(timeout=timeout)
except TimeoutError: # pragma: no cover
out = None
return out

return func_isolated


def calculate_threshold_current(cell, config, holding_current):
"""Calculate threshold current"""
min_current_spike_count = run_spike_sim(
Expand Down Expand Up @@ -84,7 +43,7 @@ def calculate_threshold_current(cell, config, holding_current):
if max_current_spike_count < 1:
logger.debug("Cell is not firing at max current, we multiply by 2")
config["min_threshold_current"] = copy(config["max_threshold_current"])
config["max_threshold_current"] *= 2.0
config["max_threshold_current"] *= 1.2
return calculate_threshold_current(cell, config, holding_current)

return binsearch_threshold_current(
Expand All @@ -99,8 +58,7 @@ def calculate_threshold_current(cell, config, holding_current):
def binsearch_threshold_current(cell, config, holding_current, min_current, max_current):
"""Binary search for threshold currents"""
mid_current = (min_current + max_current) / 2

if abs(max_current - min_current) < config["threshold_current_precision"]:
if abs(max_current - min_current) < config.get("threshold_current_precision", 0.001):
spike_count = run_spike_sim(
cell,
config,
Expand Down Expand Up @@ -218,13 +176,21 @@ def calculate_rmp_and_rin(cell, config):
"stim_end": [config["rin"]["step_stop"]],
"stimulus_current": [config["rin"]["step_amp"]],
}
features = efel.getFeatureValues([trace], ["voltage_base", "ohmic_input_resistance_vb_ssse"])[0]
rmp = None
features = efel.getFeatureValues(
[trace], ["spike_count", "voltage_base", "ohmic_input_resistance_vb_ssse"]
)[0]

rmp = 0
if features["voltage_base"] is not None:
rmp = features["voltage_base"][0]
rin = None

rin = -1.0
if features["ohmic_input_resistance_vb_ssse"] is not None:
rin = features["ohmic_input_resistance_vb_ssse"][0]

if features["spike_count"] > 0:
logger.warning("SPIKES! %s, %s", rmp, rin)
return 0.0, -1.0
return rmp, rin


Expand Down Expand Up @@ -289,12 +255,12 @@ def _isolated_current_evaluation(*args, **kwargs):
res = isolate(_current_evaluation, timeout=timeout)(*args, **kwargs)
if res is None:
res = {
"resting_potential": None,
"input_resistance": None,
"resting_potential": 0.0,
"input_resistance": -1.0,
}
if not kwargs.get("only_rin", False):
res["holding_current"] = None
res["threshold_current"] = None
res["holding_current"] = 0.0
res["threshold_current"] = 0.0

return res

Expand Down
51 changes: 34 additions & 17 deletions emodel_generalisation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def compute_currents(
"step_stop": 2000.0,
"threshold_current_precision": 0.001,
"min_threshold_current": 0.0,
"max_threshold_current": 0.2,
"max_threshold_current": 0.1,
"spike_at_ais": False, # does not work with placeholder
"deterministic": True,
"celsius": 34.0,
Expand Down Expand Up @@ -206,12 +206,15 @@ def compute_currents(
)

failed_cells = unique_cells_df[
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] < 0)
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] <= 0)
].index
if len(failed_cells) > 0:
L.info("still %s failed cells (we drop):", len(failed_cells))
L.info("still %s failed cells (we set default values):", len(failed_cells))
L.info(unique_cells_df.loc[failed_cells])
unique_cells_df.loc[failed_cells, "mtype"] = None
unique_cells_df.loc[failed_cells, "@dynamics:holding_current"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:threshold_current"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:input_resistance"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:resting_potential"] = -80.0

cols = ["resting_potential", "input_resistance", "exception"]
if not only_rin:
Expand Down Expand Up @@ -427,7 +430,7 @@ def evaluate(
"name": pd.Series(dtype=str),
}
)
for gid, emodel in enumerate(cells_df.emodel.unique):
for gid, emodel in enumerate(cells_df.emodel.unique()):
morph = access_point.get_morphologies(emodel)
exemplar_df.loc[gid, "emodel"] = emodel
exemplar_df.loc[gid, "path"] = morph["path"]
Expand Down Expand Up @@ -755,8 +758,13 @@ def _get_resistance_models(exemplar_df, exemplar_data, scales_params):
"""We fit the scale/Rin relation for AIS and soma."""
models = {}
for emodel in exemplar_df.emodel:
Path(f"local/{emodel}").mkdir(parents=True, exist_ok=True)
models[emodel] = build_all_resistance_models(
access_point, [emodel], exemplar_data[emodel], scales_params
access_point,
[emodel],
exemplar_data[emodel],
scales_params,
fig_path=Path(f"local/{emodel}"),
)
return models

Expand All @@ -765,6 +773,11 @@ def _get_resistance_models(exemplar_df, exemplar_data, scales_params):
_get_resistance_models, exemplar_df.loc[placeholder_mask], exemplar_data, scales_params
)

# needed for internal reuse
for col in cells_df.columns:
if cells_df[col].dtype == "category":
cells_df[col] = cells_df[col].astype("object")

L.info("Adapting AIS and soma of all cells..")
cells_df["ais_scaler"] = 0.0
cells_df["soma_scaler"] = 0.0
Expand All @@ -777,7 +790,7 @@ def _adapt():
mask = cells_df["emodel"] == emodel

if emodel in exemplar_data and not exemplar_data[emodel]["placeholder"]:
L.info("Adapting a non placeholder model...")
L.info("Adapting the non placeholder model %s...", emodel)

if len(Morphology(cells_df[mask].head(1)["path"].tolist()[0]).root_sections) == 1:
raise ValueError(
Expand All @@ -792,16 +805,20 @@ def _adapt():
.set_index("emodel")
.to_dict()
)
cells_df.loc[mask] = adapt_soma_ais(
cells_df.loc[mask],
access_point,
resistance_models[emodel],
rhos,
parallel_factory=parallel_factory,
min_scale=min_scale,
max_scale=max_scale,
n_steps=2,
)
with Reuse(
local_dir / f"adapt_df_{emodel}.csv", disable=no_reuse, index_col=0
) as reuse:
cells_df.loc[mask] = reuse(
adapt_soma_ais,
cells_df.loc[mask],
access_point,
resistance_models[emodel],
rhos,
parallel_factory=parallel_factory,
min_scale=min_scale,
max_scale=max_scale,
n_steps=2,
).drop(columns=["exception"])

else:
if len(Morphology(cells_df[mask].head(1)["path"].tolist()[0]).root_sections) > 1:
Expand Down
22 changes: 17 additions & 5 deletions emodel_generalisation/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from functools import partial
from pathlib import Path

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -52,8 +52,6 @@

# pylint: disable=too-many-lines,too-many-locals

matplotlib.use("Agg")


class MarkovChain:
"""Class to setup and run a markov chain on emodel parameter space."""
Expand Down Expand Up @@ -1060,16 +1058,19 @@ def plot_corner(
cmap="gnuplot",
normalize=False,
highlights=None,
sort_params=True,
with_pearson=False,
):
"""Make a corner plot which consists of scatter plots of all pairs.
Args:
feature (str): name of feature for coloring heatmap
filename (str): name of figure for corner plot
"""
params = np.array(sorted(df.normalized_parameters.columns.to_list()))
params = np.array(df.normalized_parameters.columns.to_list())
_params = np.array([PARAM_LABELS.get(p, p) for p in params])
params = params[np.argsort(_params)]
if sort_params:
params = params[np.argsort(_params)]
n_params = len(params)

# get feature data
Expand All @@ -1096,6 +1097,10 @@ def plot_corner(
fig = plt.figure(figsize=(5 + 0.5 * n_params, 5 + 0.5 * n_params))
gs = fig.add_gridspec(n_params, n_params, hspace=0.1, wspace=0.1)
im = None
if with_pearson:
_cmap = plt.get_cmap("coolwarm")
norm = mpl.colors.Normalize(vmin=-0.8, vmax=0.8)
pearson_colors = mpl.cm.ScalarMappable(norm=norm, cmap=_cmap)
# pylint: disable=too-many-nested-blocks
for i, param1 in enumerate(params):
_param1 = PARAM_LABELS.get(param1, param1)
Expand All @@ -1112,6 +1117,13 @@ def plot_corner(
ax.set_frame_on(False)
elif j < i:
ax.set_frame_on(True)
if with_pearson:
pearson = pearsonr(
df[("normalized_parameters", param1)].to_numpy(),
df[("normalized_parameters", param2)].to_numpy(),
)
ax.spines[:].set_color(pearson_colors.to_rgba(pearson[0]))
ax.spines[:].set(lw=3.0)
im = _plot_2d_data(
ax, m[i][j], vmin, vmax, rev=feature is not None, cmap=cmap, normalize=normalize
)
Expand Down
Loading

0 comments on commit f86b895

Please sign in to comment.