Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates #51

Merged
merged 22 commits into from
Sep 25, 2024
29 changes: 19 additions & 10 deletions emodel_generalisation/adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from functools import partial
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -44,8 +43,7 @@
from emodel_generalisation.model.modifiers import remove_soma
from emodel_generalisation.utils import FEATURE_FILTER
from emodel_generalisation.utils import get_scores

matplotlib.use("Agg")
from emodel_generalisation.utils import isolate


def get_scales(scales_params, with_unity=False):
Expand Down Expand Up @@ -85,7 +83,6 @@ def build_resistance_models(

# it seems dask does not quite work on this (to investigate, but multiprocessing is fast enough)
df = func(df, access_point, parallel_factory="multiprocessing")

models = {}
for emodel in emodels:
_df = df[df.emodel == emodel]
Expand All @@ -94,13 +91,16 @@ def build_resistance_models(
if len(rin[rin < 0]) == 0:
try:
coeffs, extra = Polynomial.fit(np.log10(scaler), np.log10(rin), 3, full=True)
if extra[0] < rcond_min:
models[emodel] = {
"resistance": {"polyfit_params": coeffs.convert().coef.tolist()},
"shape": exemplar_data[key],
}
if extra[0] > rcond_min:
print(f"resistance fit for {key} of {emodel} is not so good")
models[emodel] = {
"resistance": {"polyfit_params": coeffs.convert().coef.tolist()},
"shape": exemplar_data[key],
}
except (np.linalg.LinAlgError, TypeError):
print(f"fail to fit emodel {emodel}")
else:
print(f"resistance fit for {key} of {emodel} has negative rin")
return df[df.emodel.isin(models)], models


Expand Down Expand Up @@ -145,7 +145,7 @@ def _adapt_combo(combo, models, rhos, key="soma", min_scale=0.01, max_scale=10):
return {f"{key}_scaler": np.clip(scale, min_scale, max_scale)}


def _adapt_single_soma_ais(
def __adapt_single_soma_ais(
combo,
access_point=None,
models=None,
Expand Down Expand Up @@ -199,6 +199,15 @@ def _adapt_single_soma_ais(
return {k: combo[k] for k in ["soma_scaler", "ais_scaler"]}


def _adapt_single_soma_ais(*args, **kwargs):
timeout = kwargs.pop("timeout", 30 * 60)
res = isolate(__adapt_single_soma_ais, timeout=timeout)(*args, **kwargs)
if res is None:
print("timeout", args, kwargs)
return {k: 1.0 for k in ["soma_scaler", "ais_scaler"]}
return res


def make_evaluation_df(combos_df, emodels, exemplar_data, rhos=None):
"""Make a df to be evaluated."""
df = pd.DataFrame()
Expand Down
72 changes: 19 additions & 53 deletions emodel_generalisation/bluecellulab_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,19 @@
import logging
import os
from copy import copy
from multiprocessing.context import TimeoutError # pylint: disable=redefined-builtin
from pathlib import Path

import bluecellulab
import efel
from bluecellulab.simulation.neuron_globals import NeuronGlobals
from bluepyparallel.evaluator import evaluate
from bluepyparallel.parallel import NestedPool

from emodel_generalisation.utils import isolate

logger = logging.getLogger(__name__)
AXON_LOC = "self.axonal[1](0.5)._ref_v"


def isolate(func, timeout=None):
"""Isolate a generic function for independent NEURON instances.

It must be used in conjunction with NestedPool.

Example:

.. code-block:: python

def _to_be_isolated(morphology_path, point):
cell = nrnhines.get_NRN_cell(morphology_path)
return nrnhines.point_to_section_end(cell.icell.all, point)

def _isolated(morph_data):
return nrnhines.isolate(_to_be_isolated)(*morph_data)

with nrnhines.NestedPool(processes=n_workers) as pool:
result = pool.imap_unordered(_isolated, data)


Args:
func (function): function to isolate

Returns:
the isolated function

Note: it does not work as decorator.
"""

def func_isolated(*args, **kwargs):
with NestedPool(1, maxtasksperchild=1) as pool:
res = pool.apply_async(func, args, kwargs)
try:
out = res.get(timeout=timeout)
except TimeoutError: # pragma: no cover
out = None
return out

return func_isolated


def calculate_threshold_current(cell, config, holding_current):
"""Calculate threshold current"""
min_current_spike_count = run_spike_sim(
Expand Down Expand Up @@ -84,7 +43,7 @@ def calculate_threshold_current(cell, config, holding_current):
if max_current_spike_count < 1:
logger.debug("Cell is not firing at max current, we multiply by 2")
config["min_threshold_current"] = copy(config["max_threshold_current"])
config["max_threshold_current"] *= 2.0
config["max_threshold_current"] *= 1.2
return calculate_threshold_current(cell, config, holding_current)

return binsearch_threshold_current(
Expand All @@ -99,8 +58,7 @@ def calculate_threshold_current(cell, config, holding_current):
def binsearch_threshold_current(cell, config, holding_current, min_current, max_current):
"""Binary search for threshold currents"""
mid_current = (min_current + max_current) / 2

if abs(max_current - min_current) < config["threshold_current_precision"]:
if abs(max_current - min_current) < config.get("threshold_current_precision", 0.001):
spike_count = run_spike_sim(
cell,
config,
Expand Down Expand Up @@ -218,13 +176,21 @@ def calculate_rmp_and_rin(cell, config):
"stim_end": [config["rin"]["step_stop"]],
"stimulus_current": [config["rin"]["step_amp"]],
}
features = efel.getFeatureValues([trace], ["voltage_base", "ohmic_input_resistance_vb_ssse"])[0]
rmp = None
features = efel.getFeatureValues(
[trace], ["spike_count", "voltage_base", "ohmic_input_resistance_vb_ssse"]
)[0]

rmp = 0
if features["voltage_base"] is not None:
rmp = features["voltage_base"][0]
rin = None

rin = -1.0
if features["ohmic_input_resistance_vb_ssse"] is not None:
rin = features["ohmic_input_resistance_vb_ssse"][0]

if features["spike_count"] > 0:
logger.warning("SPIKES! %s, %s", rmp, rin)
return 0.0, -1.0
return rmp, rin


Expand Down Expand Up @@ -289,12 +255,12 @@ def _isolated_current_evaluation(*args, **kwargs):
res = isolate(_current_evaluation, timeout=timeout)(*args, **kwargs)
if res is None:
res = {
"resting_potential": None,
"input_resistance": None,
"resting_potential": 0.0,
"input_resistance": -1.0,
}
if not kwargs.get("only_rin", False):
res["holding_current"] = None
res["threshold_current"] = None
res["holding_current"] = 0.0
res["threshold_current"] = 0.0

return res

Expand Down
51 changes: 34 additions & 17 deletions emodel_generalisation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def compute_currents(
"step_stop": 2000.0,
"threshold_current_precision": 0.001,
"min_threshold_current": 0.0,
"max_threshold_current": 0.2,
"max_threshold_current": 0.1,
"spike_at_ais": False, # does not work with placeholder
"deterministic": True,
"celsius": 34.0,
Expand Down Expand Up @@ -206,12 +206,15 @@ def compute_currents(
)

failed_cells = unique_cells_df[
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] < 0)
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] <= 0)
].index
if len(failed_cells) > 0:
L.info("still %s failed cells (we drop):", len(failed_cells))
L.info("still %s failed cells (we set default values):", len(failed_cells))
L.info(unique_cells_df.loc[failed_cells])
unique_cells_df.loc[failed_cells, "mtype"] = None
unique_cells_df.loc[failed_cells, "@dynamics:holding_current"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:threshold_current"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:input_resistance"] = 0.0
unique_cells_df.loc[failed_cells, "@dynamics:resting_potential"] = -80.0

cols = ["resting_potential", "input_resistance", "exception"]
if not only_rin:
Expand Down Expand Up @@ -427,7 +430,7 @@ def evaluate(
"name": pd.Series(dtype=str),
}
)
for gid, emodel in enumerate(cells_df.emodel.unique):
for gid, emodel in enumerate(cells_df.emodel.unique()):
morph = access_point.get_morphologies(emodel)
exemplar_df.loc[gid, "emodel"] = emodel
exemplar_df.loc[gid, "path"] = morph["path"]
Expand Down Expand Up @@ -755,8 +758,13 @@ def _get_resistance_models(exemplar_df, exemplar_data, scales_params):
"""We fit the scale/Rin relation for AIS and soma."""
models = {}
for emodel in exemplar_df.emodel:
Path(f"local/{emodel}").mkdir(parents=True, exist_ok=True)
models[emodel] = build_all_resistance_models(
access_point, [emodel], exemplar_data[emodel], scales_params
access_point,
[emodel],
exemplar_data[emodel],
scales_params,
fig_path=Path(f"local/{emodel}"),
)
return models

Expand All @@ -765,6 +773,11 @@ def _get_resistance_models(exemplar_df, exemplar_data, scales_params):
_get_resistance_models, exemplar_df.loc[placeholder_mask], exemplar_data, scales_params
)

# needed for internal reuse
for col in cells_df.columns:
if cells_df[col].dtype == "category":
cells_df[col] = cells_df[col].astype("object")

L.info("Adapting AIS and soma of all cells..")
cells_df["ais_scaler"] = 0.0
cells_df["soma_scaler"] = 0.0
Expand All @@ -777,7 +790,7 @@ def _adapt():
mask = cells_df["emodel"] == emodel

if emodel in exemplar_data and not exemplar_data[emodel]["placeholder"]:
L.info("Adapting a non placeholder model...")
L.info("Adapting the non placeholder model %s...", emodel)

if len(Morphology(cells_df[mask].head(1)["path"].tolist()[0]).root_sections) == 1:
raise ValueError(
Expand All @@ -792,16 +805,20 @@ def _adapt():
.set_index("emodel")
.to_dict()
)
cells_df.loc[mask] = adapt_soma_ais(
cells_df.loc[mask],
access_point,
resistance_models[emodel],
rhos,
parallel_factory=parallel_factory,
min_scale=min_scale,
max_scale=max_scale,
n_steps=2,
)
with Reuse(
local_dir / f"adapt_df_{emodel}.csv", disable=no_reuse, index_col=0
) as reuse:
cells_df.loc[mask] = reuse(
adapt_soma_ais,
cells_df.loc[mask],
access_point,
resistance_models[emodel],
rhos,
parallel_factory=parallel_factory,
min_scale=min_scale,
max_scale=max_scale,
n_steps=2,
).drop(columns=["exception"])

else:
if len(Morphology(cells_df[mask].head(1)["path"].tolist()[0]).root_sections) > 1:
Expand Down
22 changes: 17 additions & 5 deletions emodel_generalisation/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from functools import partial
from pathlib import Path

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -52,8 +52,6 @@

# pylint: disable=too-many-lines,too-many-locals

matplotlib.use("Agg")


class MarkovChain:
"""Class to setup and run a markov chain on emodel parameter space."""
Expand Down Expand Up @@ -1060,16 +1058,19 @@ def plot_corner(
cmap="gnuplot",
normalize=False,
highlights=None,
sort_params=True,
with_pearson=False,
):
"""Make a corner plot which consists of scatter plots of all pairs.

Args:
feature (str): name of feature for coloring heatmap
filename (str): name of figure for corner plot
"""
params = np.array(sorted(df.normalized_parameters.columns.to_list()))
params = np.array(df.normalized_parameters.columns.to_list())
_params = np.array([PARAM_LABELS.get(p, p) for p in params])
params = params[np.argsort(_params)]
if sort_params:
params = params[np.argsort(_params)]
n_params = len(params)

# get feature data
Expand All @@ -1096,6 +1097,10 @@ def plot_corner(
fig = plt.figure(figsize=(5 + 0.5 * n_params, 5 + 0.5 * n_params))
gs = fig.add_gridspec(n_params, n_params, hspace=0.1, wspace=0.1)
im = None
if with_pearson:
_cmap = plt.get_cmap("coolwarm")
norm = mpl.colors.Normalize(vmin=-0.8, vmax=0.8)
pearson_colors = mpl.cm.ScalarMappable(norm=norm, cmap=_cmap)
# pylint: disable=too-many-nested-blocks
for i, param1 in enumerate(params):
_param1 = PARAM_LABELS.get(param1, param1)
Expand All @@ -1112,6 +1117,13 @@ def plot_corner(
ax.set_frame_on(False)
elif j < i:
ax.set_frame_on(True)
if with_pearson:
pearson = pearsonr(
df[("normalized_parameters", param1)].to_numpy(),
df[("normalized_parameters", param2)].to_numpy(),
)
ax.spines[:].set_color(pearson_colors.to_rgba(pearson[0]))
ax.spines[:].set(lw=3.0)
im = _plot_2d_data(
ax, m[i][j], vmin, vmax, rev=feature is not None, cmap=cmap, normalize=normalize
)
Expand Down
Loading
Loading