Open
Description
Hi, BrainPy team
While modelling delayed synaptic connections, I found that the effective delay in `FullProjAlignPost*` is consistently a few steps shorter than the value passed as the `delay` argument.
I tested several delays with `dt = 1.0`. The observed PSC delay is roughly `relu(delay - 2.0)`, i.e. `max(delay - 2.0, 0)`.
This happens in both feedforward and recurrent connections.
Examples:
- with `delay = 2.0`, the PSC delay is 0 steps
- with `delay = 20.0`, the PSC delay is 18 steps
Reproducer code (modified from the tutorial example):
import numpy as np
import brainpy as bp
import brainpy.math as bm
class ExponCUBA(bp.Projection):
    """One-to-one delayed projection with an exponential synapse.

    Wraps ``bp.dyn.FullProjAlignPost``. The pieces are:

    * communication layer: ``bp.dnn.OneToOne(pre.num, g_max)``
    * synapse: ``bp.dyn.Expon(post.num, tau=tau)``
    * output: ``bp.dyn.CUBA()``

    ``delay`` is forwarded unchanged to ``FullProjAlignPost``. This is the
    argument whose effective value the issue reports as ~2 steps short.

    Note: ``prob`` is accepted for signature compatibility but is not used.
    The connectivity built here is always one-to-one.
    """

    def __init__(self, pre, post, prob, g_max, tau, delay):
        super().__init__()
        # Built in the same order as the original keyword arguments
        # (comm, syn, out) in case construction order matters.
        weights = bp.dnn.OneToOne(pre.num, g_max)
        synapse = bp.dyn.Expon(post.num, tau=tau)
        output = bp.dyn.CUBA()
        self.proj = bp.dyn.FullProjAlignPost(
            pre=pre,
            delay=delay,
            comm=weights,
            syn=synapse,
            out=output,
            post=post,
        )
class CUBA_Net(bp.DynSysGroup):
    """Two LIF populations (E and I) joined by a single delayed E->I projection.

    Args:
        scale: population-size multiplier (each population has
            ``int(3200 * scale)`` neurons); the E->I weight is divided by it.
        delay: forwarded unchanged to the E->I projection.
    """

    def __init__(self, scale=1.0, delay=None):
        super().__init__()
        size = int(3200 * scale)

        # Shared LIF (with refractoriness) parameters for both populations.
        neuron_params = {
            "V_rest": -49,
            "V_th": -50.,
            "V_reset": -60,
            "tau": 20.,
            "tau_ref": 5.,
            "V_initializer": bp.init.Normal(-55., 2.),
            "spk_reset": "hard",
        }
        self.E = bp.dyn.LifRef(size, **neuron_params)
        self.I = bp.dyn.LifRef(size, **neuron_params)

        # Excitatory synaptic weight (voltage), normalised by network scale.
        exc_weight = 1.62 / scale
        self.E2I = ExponCUBA(
            self.E, self.I, 0.02, g_max=exc_weight, tau=5., delay=delay
        )

    def update(self, inp):
        """Advance one step: projection, then E (driven by ``inp``), then I."""
        self.E2I()
        self.E(inp)
        self.I()
        spikes = (self.E.spike, self.I.spike)
        return spikes
# Unit time step: with dt = 1.0, delay values are expected to equal step counts.
bm.set_dt(1.0)
# 100 steps of random external drive in [600, 800) for the E population.
input_data = np.random.rand(100) * 200 + 600
def run_model(inputs, delay: float):
    """Build a fresh ``CUBA_Net`` with the given E->I delay and simulate it.

    Args:
        inputs: per-step external input for the network's ``update``;
            passed straight to ``DSRunner.run(inputs=...)``.
        delay: forwarded to the E->I projection's ``delay`` argument.

    Returns:
        ``(mon, ts)``. ``mon`` is the runner's monitor record; its ``"psc"``
        entry holds the E->I synapse variable ``g``. ``ts`` is the monitor
        time axis.
    """
    model = CUBA_Net(delay=delay)
    runner = bp.DSRunner(
        model,
        monitors={
            "psc": model.E2I.proj.syn.g,
        },
    )
    # Simulate the caller-supplied inputs. The original ignored this
    # parameter and always used the module-level ``input_data`` global.
    runner.run(inputs=inputs)
    out = runner.mon
    return out, out.ts
def visualize_field(
    out_0,
    out_1,
    ts,
    field: str,
    name=None,
    legend=("0", "1"),
    population=(1, 50, 53, 60, 85),
    title_format=lambda i, s: f"[{i}] = {s}",
    ref_line=None,
    ref_name=None,
):
    """Plot one monitored field of two runs, one subplot per selected neuron.

    Args:
        out_0, out_1: monitor records indexable as ``out[field][:, neuron]``.
            Either may be ``None``, but not both. ``out_0`` is drawn solid
            blue and labelled with ``legend[0]``. ``out_1`` is drawn
            dash-dot red and labelled with ``legend[1]``.
        ts: time axis shared by both records.
        field: key of the monitored variable to plot.
        name: y-axis label; defaults to ``field``.
        legend: two suffixes used to label ``out_0`` and ``out_1``.
        population: neuron indices to plot, one subplot each.
        title_format: ``(subplot_index, neuron_index) -> title``.
        ref_line: optional per-neuron constant drawn as a horizontal line.
        ref_name: label for ``ref_line``; defaults to ``f"{field}_th"``.
    """
    # matplotlib is already required here (bp.visualize figures, plt.show);
    # these names were used but never imported.
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MultipleLocator

    assert not (out_0 is None and out_1 is None)
    if name is None:
        name = field
    if ref_line is not None and ref_name is None:
        ref_name = f"{field}_th"

    fig, gs = bp.visualize.get_figure(len(population), 1, 1.5, 8)
    for i, s in enumerate(population):
        ax = fig.add_subplot(gs[i, 0])
        # Each record is labelled with its own legend entry. The original
        # swapped them and read the undefined names out_tf / out_bp.
        if out_0 is not None:
            ax.plot(ts, out_0[field][:, s], label=f"{field}_{legend[0]}", color="b")
        if out_1 is not None:
            ax.plot(
                ts, out_1[field][:, s], label=f"{field}_{legend[1]}", color="r", linestyle="-."
            )
        if ref_line is not None:
            ax.plot(ts, ref_line[s] * np.ones_like(ts), label=ref_name)
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel(f"{name}")
        ax.set_xlim(-0.1, ts[-1] + 0.1)
        ax.set_title(title_format(i, s))
        ax.xaxis.set_minor_locator(MultipleLocator(2))
        ax.legend(loc="upper right")
    plt.show()
# Baseline (no delay) first, then the delayed run. This order matches the
# original so the random initial voltages are drawn in the same sequence.
out0, ts = run_model(input_data, delay=0.0)
out20, _ = run_model(input_data, delay=20.0)

# Compare the E->I synaptic trace of the delayed run against the baseline.
visualize_field(
    out20,
    out0,
    ts,
    field="psc",
    name="post spike I",
    legend=["20", "0"],
    population=[2, 88],
)