Skip to content

Delay in FullProjAlignPost* is off by a few steps #709

Open
@alexfanqi

Description

@alexfanqi

Hi, BrainPy team

While modelling synaptic connections with delays, I found that the delay in `FullProjAlignPost*` is always off by a few steps compared with the value passed as the `delay` argument.

I tested several delays with `dt = 1.0`. The observed PSC delay is roughly `relu(delay - 2.0)`, i.e. `max(delay - 2, 0)` time steps.

This happens in both feedforward and recurrent connections.

Example:
With `delay = 2.0`, the PSC delay is 0.
image

With `delay = 20.0`, the PSC delay is 18.
image

Reproducer code (modified from the tutorial example):

import numpy as np
import brainpy as bp
import brainpy.math as bm
class ExponCUBA(bp.Projection):
    """Projection from ``pre`` to ``post`` built on ``bp.dyn.FullProjAlignPost``.

    The projection combines three pieces:
      - a ``bp.dnn.OneToOne`` communication sized by ``pre.num`` with weight ``g_max``;
      - a ``bp.dyn.Expon`` synapse sized by ``post.num`` with time constant ``tau``;
      - a ``bp.dyn.CUBA`` output.

    ``delay`` is forwarded unchanged to ``FullProjAlignPost``.

    NOTE(review): ``prob`` is accepted but never used. The communication is
    ``OneToOne``, so no random connectivity is drawn.
    """

    def __init__(self, pre, post, prob, g_max, tau, delay):
        super().__init__()
        # Construct the sub-components in the same order the original
        # keyword-argument expressions were evaluated.
        communication = bp.dnn.OneToOne(pre.num, g_max)
        synapse = bp.dyn.Expon(post.num, tau=tau)
        output = bp.dyn.CUBA()
        self.proj = bp.dyn.FullProjAlignPost(
            pre=pre,
            delay=delay,
            comm=communication,
            syn=synapse,
            out=output,
            post=post,
        )

class CUBA_Net(bp.DynSysGroup):
    """Two LIF populations (``E`` and ``I``) joined by one ``E2I`` projection.

    Args:
      scale: size multiplier. Each population has ``int(3200 * scale)``
        neurons, and the excitatory weight is ``1.62 / scale``.
      delay: delay forwarded to the ``E2I`` projection.
    """

    def __init__(self, scale=1.0, delay=None):
        super().__init__()

        size = int(3200 * scale)

        # Both populations are built from this single dict, so they share
        # the same initializer object.
        lif_params = dict(
            V_rest=-49, V_th=-50., V_reset=-60, tau=20., tau_ref=5.,
            V_initializer=bp.init.Normal(-55., 2.), spk_reset="hard",
        )
        self.E = bp.dyn.LifRef(size, **lif_params)
        self.I = bp.dyn.LifRef(size, **lif_params)

        # Excitatory synaptic weight (voltage), rescaled with network size.
        exc_weight = 1.62 / scale
        self.E2I = ExponCUBA(
            self.E, self.I, 0.02, g_max=exc_weight, tau=5., delay=delay
        )

    def update(self, inp):
        # Same call order as the original: projection, then E (driven by
        # ``inp``), then I.
        self.E2I()
        self.E(inp)
        self.I()
        return self.E.spike, self.I.spike

# External drive: 100 samples, uniformly distributed in [600, 800).
input_data = np.random.rand(100) * 200 + 600
# Set BrainPy's global integration time step.
bm.set_dt(1.0)
def run_model(inputs, delay: float, seed=None):
    """Build a ``CUBA_Net`` with the given E->I delay and simulate it.

    Args:
      inputs: external input passed to ``runner.run(inputs=...)``.
      delay: delay for the E->I projection.
      seed: optional seed applied with ``bm.random.seed`` before the model is
        built. Passing the same seed to runs with different delays lets their
        random initialization match, so the traces differ only by the delay.
        NOTE(review): this assumes the initializers draw from the global
        ``bm.random`` state; confirm.

    Returns:
      ``(mon, ts)``: the runner's monitor container (its ``"psc"`` entry
      records ``E2I.proj.syn.g``) and the monitor time points.
    """
    if seed is not None:
        bm.random.seed(seed)
    bp_model = CUBA_Net(delay=delay)
    runner = bp.DSRunner(
        bp_model,
        monitors={
            "psc": bp_model.E2I.proj.syn.g,
        },
    )
    # Feed the caller-supplied inputs, not a module-level global.
    runner.run(inputs=inputs)
    return runner.mon, runner.mon.ts

def visualize_field(
    out_0,
    out_1,
    ts,
    field: str,
    name=None,
    legend=("0", "1"),
    population=(1, 50, 53, 60, 85),
    title_format=lambda i, s: f"[{i}] = {s}",
    ref_line=None,
    ref_name=None,
):
    """Plot ``field`` for selected neurons from one or two monitor outputs.

    One subplot is drawn per neuron index in ``population``:
      - ``out_0`` is drawn in blue (solid) and labelled ``f"{field}_{legend[0]}"``;
      - ``out_1`` is drawn in red (dash-dot) and labelled ``f"{field}_{legend[1]}"``.

    Args:
      out_0, out_1: containers indexable by ``field`` that yield 2-D arrays
        indexed ``[time, neuron]``. Either may be ``None``, but not both.
      ts: x-axis time points.
      field: key of the monitored variable to plot.
      name: y-axis label; defaults to ``field``.
      legend: labels for ``out_0`` and ``out_1``, in that order.
      population: neuron indices to plot, one subplot each.
      title_format: callable ``(subplot_index, neuron_index) -> str``.
      ref_line: optional per-neuron values drawn as horizontal lines.
      ref_name: legend label for the reference line; defaults to
        ``f"{field}_th"``.

    NOTE(review): ``plt`` (matplotlib.pyplot) and ``MultipleLocator``
    (matplotlib.ticker) must be imported at module level. The reproducer as
    pasted does not import them.
    """
    assert not (out_0 is None and out_1 is None)
    if name is None:
        name = field
    if ref_line is not None and ref_name is None:
        ref_name = f"{field}_th"

    fig, gs = bp.visualize.get_figure(len(population), 1, 1.5, 8)
    for i, s in enumerate(population):
        ax = fig.add_subplot(gs[i, 0])
        # Each output is read from its own argument and paired with the
        # matching legend entry.
        if out_0 is not None:
            ax.plot(ts, out_0[field][:, s], label=f"{field}_{legend[0]}", color="b")
        if out_1 is not None:
            ax.plot(
                ts, out_1[field][:, s], label=f"{field}_{legend[1]}", color="r", linestyle="-."
            )
        if ref_line is not None:
            ax.plot(ts, ref_line[s] * np.ones_like(ts), label=ref_name)
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel(f"{name}")
        ax.set_xlim(-0.1, ts[-1] + 0.1)
        ax.set_title(title_format(i, s))
        ax.xaxis.set_minor_locator(MultipleLocator(2))
        ax.legend(loc="upper right")

    plt.show()

# Simulate the same network with delay = 0.0 and delay = 20.0, then overlay
# the monitored "psc" trace of both runs for two selected neurons.
out0, ts = run_model(input_data, 0.0)
out20, _ = run_model(input_data, 20.0)

visualize_field(
    out20,
    out0,
    ts,
    field="psc",
    name="post spike I",
    legend=["20", "0"],
    population=[2, 88],
)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions