
Delayed synapse causes a JAX leak #747

Open
@ZhenyuanJin

Description


I'm trying to train the synaptic weights so that the target neuron group fires at a target firing rate.
When delay is set to 0., the following code works; when delay is set to a non-zero value, it reports a JAX leak.
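
If the error here is JAX's leaked-tracer error, a quick way to get a more detailed traceback is to enable tracer-leak checking before running the script below. This is only a debugging sketch; jax_check_tracer_leaks and jax.checking_leaks are standard JAX debugging switches, not something used in the original report.

import jax

# report leaked tracers with a full stack trace on every trace
jax.config.update('jax_check_tracer_leaks', True)

# or scope the check to a single call of the training step defined at the end of the script:
# with jax.checking_leaks():
#     trainer.f_train()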

My versions:

- brainpy==2.6.0.post20240803
- brainpylib==0.3.1
- jax==0.4.28
- jaxlib==0.4.28+cuda12.cudnn89
- python=3.9.19=h955ad1f_1

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import matplotlib.pyplot as plt
import numpy as np

class CustomExponCOBA(bp.Projection):
    def __init__(self, pre, post, delay, tau, E, comm):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
            pre=pre,
            delay=delay,
            syn=bp.dyn.Expon.desc(pre.num, tau=tau),
            comm=comm,
            out=bp.dyn.COBA(E=E),
            post=post,
        )

class SNN(bp.DynamicalSystem):
    def __init__(self):
        super().__init__()

        self.poisson = bp.dyn.PoissonGroup(size=11, freqs=999.)
        self.E = bp.dyn.LifRef(size=9, V_initializer=bp.init.Normal(-60., 2.))
        # connect Poisson neurons 0 and 1 to LIF neurons 0 and 1, with one trainable weight per connection
        poisson2E_conn = bp.conn.IJConn(i=np.array([0, 1]), j=np.array([0, 1]))
        poisson2E_conn = poisson2E_conn(pre_size=11, post_size=9)
        poisson2E_comm = bp.dnn.EventCSRLinear(conn=poisson2E_conn, weight=bm.TrainVar(bm.ones(2) * 1000.))
        delay = 0.     # this setting works
        # delay = 0.1  # any non-zero delay triggers the reported JAX leak
        self.poisson2E = CustomExponCOBA(
            pre=self.poisson,
            post=self.E,
            delay=delay,
            tau=0.1,
            E=0.,
            comm=poisson2E_comm,
        )

    def update(self, *args, **kwargs):
        self.poisson()
        self.poisson2E()
        self.E()
        return self.E.spike

with bm.training_environment():
    net = SNN()

batch_size = 5
inputs = bm.ones((batch_size, 123, 1))  # dummy (batch, time, feature) input; SNN.update ignores it

class Trainer:
    def __init__(self, net, opt):
        self.net = net
        self.opt = opt
        opt.register_train_vars(net.train_vars().unique())
        self.f_grad = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)

    def f_loss(self):
        self.net.reset(batch_size)
        # a fresh DSTrainer is constructed inside the function that bm.grad traces
        runner = bp.DSTrainer(self.net, progress_bar=False, numpy_mon_after_run=False)
        outs = runner.predict(inputs, reset_state=False)
        # mean spike probability per step, converted to a firing rate in Hz (dt is in ms)
        fr_mean = bm.mean(outs) * 1000. / bm.get_dt()
        return bm.square(fr_mean - 3.)

    @bm.cls_jit
    def f_train(self):
        grads, loss = self.f_grad()
        self.opt.update(grads)
        return loss

trainer = Trainer(net=net, opt=bp.optim.Adam(lr=4e-3))
for i in range(1000):
    l = trainer.f_train()
    if (i + 1) % 100 == 0:
        print(f'Train {i + 1} steps, loss {l}')
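
A further sketch, untested and not from the report: the loss above constructs a bp.DSTrainer inside the function that bm.grad differentiates. Rewriting the loss with bm.for_loop over the network's step_run method (assuming both are available in this brainpy version) removes the runner from the traced function, which may help isolate whether the leak comes from the delay handling or from the runner's internal state. The function name below is made up for illustration.

def f_loss_for_loop():
    # drop-in alternative to Trainer.f_loss that avoids building a DSTrainer
    # inside the differentiated function
    net.reset(batch_size)
    indices = bm.arange(inputs.shape[1])           # one index per time step
    outs = bm.for_loop(net.step_run, indices)      # stacked SNN.update outputs: (time, batch, 9)
    fr_mean = bm.mean(outs) * 1000. / bm.get_dt()  # mean firing rate in Hz
    return bm.square(fr_mean - 3.)

Since SNN.update ignores its inputs (the Poisson group generates its own spikes), iterating over time indices is equivalent to feeding the dummy inputs through runner.predict.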

Labels: bug (Something isn't working)
