
rename variables for masking #112

Merged · 3 commits · May 29, 2025
Changes from all commits
4 changes: 2 additions & 2 deletions ngclearn/components/synapses/denseSynapse.py
@@ -60,8 +60,8 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
             self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
         weights = initialize_params(subkeys[0], self.weight_init, shape)
         if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
-            mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
-            weights = weights * mask ## sparsify matrix
+            p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
+            weights = weights * p_mask ## sparsify matrix
 
         self.batch_size = 1
         ## Compartment setup
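For reference, a minimal standalone sketch of the Bernoulli-based sparsification this hunk renames (a hedged illustration, not ngclearn's `initialize_params` API; the `shape` and `p_conn` values here are made up for the example):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
subkeys = random.split(key, 2)
shape = (4, 6)   # illustrative (pre, post) dimensions
p_conn = 0.5     # probability that any given synapse exists

## dense uniform weights, mirroring the {"dist": "uniform", ...} kernel above
weights = random.uniform(subkeys[0], shape, minval=0.025, maxval=0.8)
if 0. < p_conn < 1.:  ## only non-zero and <1 probs trigger sparsification
    p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
    weights = weights * p_mask  ## zeroes roughly (1 - p_conn) of the entries
```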
36 changes: 18 additions & 18 deletions ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
@@ -8,7 +8,7 @@
 from ngcsimlib.compilers.process import transition
 
 @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
-def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
+def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1.,
                  prior_type=None, prior_lmbda=0.,
                  pre_wght=1., post_wght=1.):
     """
@@ -21,7 +21,7 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
 
     W: synaptic weight values (at time t)
 
-    w_mask: synaptic weight masking matrix (same shape as W)
+    mask: synaptic weight masking matrix (same shape as W)
 
     w_bound: maximum value to enforce over newly computed efficacies
 
@@ -64,21 +64,21 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
 
     dW = dW + prior_lmbda * dW_reg
 
-    if w_mask!=None:
-        dW = dW * w_mask
+    if mask!=None:
+        dW = dW * mask
 
     return dW * signVal, db * signVal
 
 @partial(jit, static_argnums=[1,2, 3])
-def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
+def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True):
     """
     Enforces constraints that the (synaptic) efficacies/values within matrix
     `W` must adhere to.
 
     Args:
         W: synaptic weight values (at time t)
 
-        w_mask: weight mask matrix
+        block_mask: weight mask matrix
 
         w_bound: maximum value to enforce over newly computed efficacies
 
@@ -94,8 +94,8 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
     else:
         _W = jnp.clip(_W, -w_bound, w_bound)
 
-    if w_mask!=None:
-        _W = _W * w_mask
+    if block_mask!=None:
+        _W = _W * block_mask
 
     return _W
 
@@ -138,7 +138,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
     bias_init: a kernel to drive initialization of biases for this synaptic cable
         (Default: None, which turns off/disables biases)
 
-    w_mask: weight mask matrix
+    block_mask: weight mask matrix
 
     w_bound: maximum weight to softly bound this cable's value matrix to; if
         set to 0, then no synaptic value bounding will be applied
@@ -186,10 +186,10 @@ class HebbianPatchedSynapse(PatchedSynapse):
     """
 
     def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
-                 w_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
+                 block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
                  optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                  resist_scale=1., batch_size=1, **kwargs):
-        super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale,
+        super().__init__(name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale,
                          p_conn, batch_size=batch_size, **kwargs)
 
         prior_type, prior_lmbda = prior
@@ -221,7 +221,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
         self.postVals = jnp.zeros((self.batch_size, self.shape[1]))
         self.pre = Compartment(self.preVals)
         self.post = Compartment(self.postVals)
-        self.w_mask = w_mask
+        self.block_mask = block_mask
         self.dWeights = Compartment(jnp.zeros(self.shape))
         self.dBiases = Compartment(jnp.zeros(self.shape[1]))
@@ -231,23 +231,23 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
                                      if bias_init else [self.weights.value]))
 
     @staticmethod
-    def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
+    def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                         post_wght, pre, post, weights):
         ## calculate synaptic update values
         dW, db = _calc_update(
-            pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative,
+            pre, post, weights, block_mask, w_bound, is_nonnegative=is_nonnegative,
             signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
             post_wght=post_wght)
 
         return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db
 
     @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
     @staticmethod
-    def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
+    def evolve(block_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
         ## calculate synaptic update values
         dWeights, dBiases = HebbianPatchedSynapse._compute_update(
-            w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
+            block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
             pre_wght, post_wght, pre, post, weights
         )
         ## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -257,7 +257,7 @@ def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_l
         # ignore db since no biases configured
         opt_params, [weights] = opt(opt_params, [weights], [dWeights])
         ## ensure synaptic efficacies adhere to constraints
-        weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative)
+        weights = _enforce_constraints(weights, block_mask, w_bound, is_nonnegative=is_nonnegative)
         return opt_params, weights, biases, dWeights, dBiases
 
     @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
@@ -313,7 +313,7 @@ def help(cls): ## component help function
             "post_wght": "Post-synaptic weighting coefficient (q_post)",
             "w_bound": "Soft synaptic bound applied to synapses post-update",
             "prior": "prior name and value for synaptic updating prior",
-            "w_mask": "weight mask matrix",
+            "block_mask": "weight mask matrix",
             "optim_type": "Choice of optimizer to adjust synaptic weights"
         }
         info = {cls.__name__: properties,
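To see why the rename matters at the call sites above, here is a hedged, self-contained sketch of the masked Hebbian step: the mask gates the update (as in `_calc_update`) and is re-applied after clipping (as in `_enforce_constraints`). The function name and toy values below are illustrative, not the component's actual API:

```python
import jax.numpy as jnp

def masked_hebbian_step(pre, post, W, block_mask, w_bound=1., eta=0.01,
                        is_nonnegative=True):
    ## Hebbian outer-product update, gated by the block mask
    dW = jnp.matmul(pre.T, post)
    if block_mask is not None:
        dW = dW * block_mask
    W = W + eta * dW
    ## enforce bounds, then re-apply the mask so masked-out synapses stay zero
    if is_nonnegative:
        W = jnp.clip(W, 0., w_bound)
    else:
        W = jnp.clip(W, -w_bound, w_bound)
    if block_mask is not None:
        W = W * block_mask
    return W

pre = jnp.ones((1, 4))       # toy pre-synaptic activity
post = jnp.ones((1, 4))      # toy post-synaptic activity
block_mask = jnp.eye(4)      # toy mask: only diagonal synapses are active
W = masked_hebbian_step(pre, post, jnp.zeros((4, 4)), block_mask)
```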
8 changes: 4 additions & 4 deletions ngclearn/components/synapses/patched/patchedSynapse.py
@@ -79,7 +79,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
     bias_init: a kernel to drive initialization of biases for this synaptic cable
         (Default: None, which turns off/disables biases)
 
-    w_mask: weight mask matrix
+    block_mask: weight mask matrix
 
     pre_wght: pre-synaptic weighting factor (Default: 1.)
 
@@ -92,7 +92,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
         this to < 1. will result in a sparser synaptic structure
     """
 
-    def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None,
+    def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None,
                  resist_scale=1., p_conn=1., batch_size=1, **kwargs):
         super().__init__(name, **kwargs)
 
@@ -112,7 +112,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), w_mask=None,
         weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
                                               weight_init=self.weight_init)
 
-        self.w_mask = jnp.where(weights!=0, 1, 0)
+        self.block_mask = jnp.where(weights!=0, 1, 0)
         self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models)
 
         self.shape = weights.shape
@@ -192,7 +192,7 @@ def help(cls): ## component help function
             "weight_init": "Initialization conditions for synaptic weight (W) values",
             "bias_init": "Initialization conditions for bias/base-rate (b) values",
             "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
-            "w_mask": "weight mask matrix",
+            "block_mask": "weight mask matrix",
             "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
         }
         info = {cls.__name__: properties,
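For illustration, a sketch of how the base class derives `block_mask` from the initialized block-diagonal weights — the `jnp.where` line matches the diff above, while the weight construction is an assumed stand-in for `create_multi_patch_synapses` (not shown in this PR):

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

## stand-in weights: two 2x3 sub-models arranged block-diagonally
sub_w = jnp.full((2, 3), 0.5)
weights = block_diag(sub_w, sub_w)   # shape (4, 6)

## mask is 1 wherever a synapse was initialized, 0 elsewhere; applying it
## during updates keeps the off-block entries permanently at zero
block_mask = jnp.where(weights != 0, 1, 0)
```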