Skip to content

add-sigma-to-gaussianErrorCell #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions ngclearn/components/neurons/graded/gaussianErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import numpy as jnp, jit
from ngclearn.utils import tensorstats

def _run_cell(dt, targ, mu):
def _run_cell(dt, targ, mu, sigma):
"""
Moves cell dynamics one step forward.

Expand All @@ -14,13 +14,15 @@ def _run_cell(dt, targ, mu):

mu: prediction value

sigma: prediction variance

Returns:
derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
"""
return _run_gaussian_cell(dt, targ, mu)
return _run_gaussian_cell(dt, targ, mu, sigma)

@jit
def _run_gaussian_cell(dt, targ, mu):
def _run_gaussian_cell(dt, targ, mu, sigma):
"""
Moves Gaussian cell dynamics one step forward. Specifically, this
routine emulates the error unit behavior of the local cost functional:
Expand All @@ -35,13 +37,16 @@ def _run_gaussian_cell(dt, targ, mu):

mu: prediction value

sigma: prediction variance

Returns:
derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
"""
dmu = (targ - mu) # e (error unit)
dmu = (targ - mu)/sigma # e (error unit)
dtarg = -dmu # reverse of e
L = -jnp.sum(jnp.square(dmu)) * 0.5
return dmu, dtarg, L
dsigma = 1. # no derivative is calculated at this time for sigma
L = -jnp.sum(jnp.square(dmu)) * 0.5 / sigma
return dmu, dtarg, dsigma, L

class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
"""
Expand All @@ -66,8 +71,10 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
tau_m: (Unused -- currently cell is a fixed-point model)

leakRate: (Unused -- currently cell is a fixed-point model)

sigma: prediction covariance matrix (𝚺) in multivariate gaussian distribution
"""
def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
def __init__(self, name, n_units, batch_size=1, sigma=1, shape=None, **kwargs):
super().__init__(name, **kwargs)

## Layer Size Setup
Expand All @@ -79,6 +86,7 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
self.shape = shape
self.n_units = n_units
self.batch_size = batch_size
self.sigma = sigma

## Convolution shape setup
self.width = self.height = n_units
Expand All @@ -94,11 +102,12 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
self.mask = Compartment(restVals + 1.0)

@staticmethod
def _advance_state(dt, mu, dmu, target, dtarget, modulator, mask):
def _advance_state(dt, mu, dmu, target, dtarget, sigma, modulator, mask):
## compute Gaussian error cell output
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
dmu, dtarget, dsigma, L = _run_cell(dt, target * mask, mu * mask, sigma)
dmu = dmu * modulator * mask
dtarget = dtarget * modulator * mask
dsigma = dsigma * 0 + 1. # no derivative is calculated at this time for sigma
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
return dmu, dtarget, L, mask

Expand Down Expand Up @@ -153,11 +162,12 @@ def help(cls): ## component help function
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"batch_size": "Batch size dimension of this component"
"batch_size": "Batch size dimension of this component",
"sigma": "External input variance value (currently fixed and not learnable)"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "Gaussian(x=target; mu, sigma=1)",
"dynamics": "Gaussian(x=target; mu, sigma)",
"hyperparameters": hyperparams}
return info

Expand Down