Open
Description
I'm trying to train the synaptic weights so that the target neuron group fires at a target firing rate.
When `delay` is set to `0.`, the following code works; when `delay` is set to a non-zero value, it reports a JAX leak error.
My versions:
- brainpy==2.6.0.post20240803
- brainpylib==0.3.1
- jax==0.4.28
- jaxlib==0.4.28+cuda12.cudnn89
- python=3.9.19=h955ad1f_1
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import matplotlib.pyplot as plt
import numpy as np
class CustomExponCOBA(bp.Projection):
    """Exponential-synapse COBA projection built on ``bp.dyn.ProjAlignPreMg2``.

    The projection combines an ``Expon`` synapse sized to the presynaptic
    group (``pre.num`` units), the supplied ``comm`` layer, and a ``COBA``
    output. The resulting projection object is stored as ``self.proj``.

    Parameters
    ----------
    pre, post : neuron groups
        Source and target populations of the projection.
    delay : float
        Delay forwarded unchanged to ``ProjAlignPreMg2``.
    tau : float
        Time constant of the ``Expon`` synapse.
    E : float
        ``E`` argument of the ``COBA`` output (its reversal potential).
    comm : communication layer
        Passed through as the projection's ``comm`` (e.g. a ``bp.dnn`` layer).
    """

    def __init__(self, pre, post, delay, tau, E, comm):
        super().__init__()
        synapse = bp.dyn.Expon.desc(pre.num, tau=tau)
        output = bp.dyn.COBA(E=E)
        self.proj = bp.dyn.ProjAlignPreMg2(
            pre=pre, delay=delay, syn=synapse, comm=comm, out=output, post=post,
        )
class SNN(bp.DynamicalSystem):
    """Poisson input group projecting onto a ``LifRef`` group ``E``.

    Poisson neurons 0 and 1 are linked to ``E`` neurons 0 and 1 respectively
    (``IJConn``). Each link carries one trainable weight (``bm.TrainVar``),
    initialised to 1000. The projection is a ``CustomExponCOBA``.

    Parameters
    ----------
    delay : float, optional
        Delay passed to the projection. Defaults to ``0.``, the value that was
        previously hard-coded. Presumably in the same time unit as
        ``bm.get_dt()`` — TODO confirm.
        NOTE(review): the accompanying report states that training fails with
        a JAX leak error when this is non-zero; that is not addressed here.
    """

    def __init__(self, delay=0.):
        super().__init__()
        self.poisson = bp.dyn.PoissonGroup(size=11, freqs=999.)
        self.E = bp.dyn.LifRef(size=9, V_initializer=bp.init.Normal(-60., 2.))

        # Explicit (pre, post) index pairs: 0 -> 0 and 1 -> 1.
        pre_ids = np.array([0, 1])
        post_ids = np.array([0, 1])
        poisson2E_conn = bp.conn.IJConn(i=pre_ids, j=post_ids)
        # Size the connector from the groups themselves so the sizes cannot drift.
        poisson2E_conn = poisson2E_conn(pre_size=self.poisson.num, post_size=self.E.num)
        # One trainable weight per listed connection.
        poisson2E_comm = bp.dnn.EventCSRLinear(
            conn=poisson2E_conn,
            weight=bm.TrainVar(bm.ones(pre_ids.size) * 1000),
        )

        self.poisson2E = CustomExponCOBA(
            pre=self.poisson,
            post=self.E,
            delay=delay,
            tau=0.1,
            E=0.,
            comm=poisson2E_comm,
        )

    def update(self, *args, **kwargs):
        """Step the input group, the projection, then ``E``; return ``E``'s spikes.

        Positional and keyword arguments are accepted but ignored.
        """
        self.poisson()
        self.poisson2E()
        self.E()
        return self.E.spike
# Construct the network inside BrainPy's training environment.
with bm.training_environment():
    net = SNN()

# Number of parallel trials and simulated time steps per trial.
batch_size = 5
num_steps = 123
# Assumed layout (batch, time, 1), batch-first as DSTrainer is used below —
# TODO confirm. SNN.update ignores its arguments, so the values only fix the
# batch size and the number of simulated steps. Built from batch_size so the
# two can never disagree.
inputs = bm.ones((batch_size, num_steps, 1))
class Trainer:
    """Train ``net``'s trainable variables toward a target mean firing rate.

    Each ``f_train`` call:
    - resets the network;
    - runs it over the inputs with a ``bp.DSTrainer``;
    - computes the squared error between the mean firing rate and
      ``target_rate``;
    - applies one optimizer update.

    Parameters
    ----------
    net : bp.DynamicalSystem
        Network to train; its unique ``train_vars()`` are registered on ``opt``.
    opt : bp.optim optimizer
        Optimizer that applies the gradient updates.
    xs : array, optional
        Inputs given to ``DSTrainer.predict`` on every loss evaluation.
        Defaults to the module-level ``inputs`` (backward compatible).
    num_batch : int, optional
        Batch size used to reset ``net`` before each run. Defaults to
        ``xs.shape[0]`` so it always agrees with the inputs (assumes
        batch-first inputs — TODO confirm against DSTrainer's data layout).
    target_rate : float, optional
        Desired mean firing rate. Defaults to ``3.``, the previously
        hard-coded target.
    """

    def __init__(self, net, opt, xs=None, num_batch=None, target_rate=3.):
        self.net = net
        self.opt = opt
        # Fall back to the script-level inputs so Trainer(net=..., opt=...) still works.
        self.xs = inputs if xs is None else xs
        self.num_batch = self.xs.shape[0] if num_batch is None else num_batch
        self.target_rate = target_rate
        opt.register_train_vars(net.train_vars().unique())
        # return_value=True: f_grad() returns (grads, loss).
        self.f_grad = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)

    def f_loss(self):
        """Run one batched simulation; return the squared firing-rate error."""
        self.net.reset(self.num_batch)
        runner = bp.DSTrainer(self.net, progress_bar=False, numpy_mon_after_run=False)
        # reset_state=False: the network was reset explicitly just above.
        outs = runner.predict(self.xs, reset_state=False)
        # Mean spikes per step scaled by 1000 / dt; a rate in Hz if dt is in
        # ms — TODO confirm the unit of bm.get_dt().
        fr_mean = bm.mean(outs) * 1000. / bm.get_dt()
        return bm.square(fr_mean - self.target_rate)

    @bm.cls_jit
    def f_train(self):
        """Compute gradients, apply one optimizer step, and return the loss."""
        grads, loss = self.f_grad()
        self.opt.update(grads)
        return loss
trainer = Trainer(net=net, opt=bp.optim.Adam(lr=4e-3))

# Run 1000 optimisation steps, printing the loss after every 100th step.
num_iterations, report_every = 1000, 100
for step in range(1, num_iterations + 1):
    loss_value = trainer.f_train()
    if step % report_every == 0:
        print(f'Train {step} steps, loss {loss_value}')