Skip to content

Nudging over small opt to conv/deconv and arg "x_shape" cleanup (tests passed) #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ngclearn/components/synapses/convolution/convSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -36,15 +36,15 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
where `@` denotes convolution

batch_size: batch size dimension of this component
"""

# Define Functions
def __init__(self, name, shape, x_size, filter_init=None, bias_init=None, stride=1,
def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1,
padding=None, resist_scale=1., batch_size=1, **kwargs):
super().__init__(name, **kwargs)

Expand All @@ -53,6 +53,7 @@ def __init__(self, name, shape, x_size, filter_init=None, bias_init=None, stride

## Synapse meta-parameters
self.shape = shape ## shape of synaptic filter tensor
x_size, x_size = x_shape
self.x_size = x_size
self.Rscale = resist_scale ## post-transformation scale factor
self.padding = padding
Expand Down Expand Up @@ -150,9 +151,12 @@ def help(self): ## component help function
hyperparams = {
"shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
"x `number input channels` x `number output channels`",
"x_shape": "Shape of any single incoming/input feature map",
"weight_init": "Initialization conditions for synaptic filter (K) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level output scaling factor (R)"
"resist_scale": "Resistance level output scaling factor (R)",
"stride": "length / size of stride",
"padding": "pre-operator padding to use, i.e., `VALID` `SAME`"
}
info = {self.name: properties,
"compartments": compartment_props,
Expand Down
12 changes: 8 additions & 4 deletions ngclearn/components/synapses/convolution/deconvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -36,15 +36,15 @@ class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((W @.T Rscale) * in) + b
where `@.T` denotes deconvolution

batch_size: batch size dimension of this component
"""

# Define Functions
def __init__(self, name, shape, x_size, filter_init=None, bias_init=None, stride=1,
def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1,
padding=None, resist_scale=1., batch_size=1, **kwargs):
super().__init__(name, **kwargs)

Expand All @@ -53,6 +53,7 @@ def __init__(self, name, shape, x_size, filter_init=None, bias_init=None, stride

## Synapse meta-parameters
self.shape = shape ## shape of synaptic filter tensor
x_size, x_size = x_shape
self.x_size = x_size
self.Rscale = resist_scale ## post-transformation scale factor
self.padding = padding
Expand Down Expand Up @@ -138,9 +139,12 @@ def help(self): ## component help function
hyperparams = {
"shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
"x `number input channels` x `number output channels`",
"x_shape": "Shape of any single incoming/input feature map",
"weight_init": "Initialization conditions for synaptic filter (K) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level output scaling factor (R)"
"resist_scale": "Resistance level output scaling factor (R)",
"stride": "length / size of stride",
"padding": "pre-operator padding to use, i.e., `VALID` `SAME`"
}
info = {self.name: properties,
"compartments": compartment_props,
Expand Down
36 changes: 24 additions & 12 deletions ngclearn/components/synapses/convolution/hebbianConvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -49,7 +49,7 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
where `@` denotes convolution

Expand Down Expand Up @@ -81,11 +81,11 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
"""

# Define Functions
def __init__(self, name, shape, x_size, eta=0., filter_init=None, bias_init=None,
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
stride=1, padding=None, resist_scale=1., w_bound=0.,
is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
batch_size=1, **kwargs):
super().__init__(name, shape, x_size=x_size, filter_init=filter_init,
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
bias_init=bias_init, resist_scale=resist_scale, stride=stride,
padding=padding, batch_size=batch_size, **kwargs)

Expand All @@ -109,6 +109,15 @@ def __init__(self, name, shape, x_size, eta=0., filter_init=None, bias_init=None
## Shape error correction -- do shape correction inference for local updates
self._init(self.batch_size, self.x_size, self.shape, self.stride,
self.padding, self.pad_args, self.weights)
self.antiPad = None
k_size, k_size, n_in_chan, n_out_chan = self.shape
if padding == "SAME":
self.antiPad = _conv_same_transpose_padding(self.post.value.shape[1],
self.x_size, k_size, stride)
elif padding == "VALID":
self.antiPad = _conv_valid_transpose_padding(self.post.value.shape[1],
self.x_size, k_size, stride)

########################################################################

## set up outer optimization compartments
Expand Down Expand Up @@ -170,16 +179,16 @@ def evolve(self, opt_params, weights, biases, dWeights, dBiases):

@staticmethod
def _backtransmit(sign_value, x_size, shape, stride, padding, x_delta_shape,
pre, post, weights): ## action-backpropagating routine
antiPad, post, weights): ## action-backpropagating routine
## calc dInputs - adjustment w.r.t. input signal
k_size, k_size, n_in_chan, n_out_chan = shape
antiPad = None
if padding == "SAME":
antiPad = _conv_same_transpose_padding(post.shape[1], x_size,
k_size, stride)
elif padding == "VALID":
antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
k_size, stride)
# antiPad = None
# if padding == "SAME":
# antiPad = _conv_same_transpose_padding(post.shape[1], x_size,
# k_size, stride)
# elif padding == "VALID":
# antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
# k_size, stride)
dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape,
stride_size=stride, anti_padding=antiPad)
## flip sign of back-transmitted signal (if applicable)
Expand Down Expand Up @@ -234,9 +243,12 @@ def help(self): ## component help function
hyperparams = {
"shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
"x `number input channels` x `number output channels`",
"x_shape": "Shape of any single incoming/input feature map",
"weight_init": "Initialization conditions for synaptic filter (K) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level output scaling factor (R)",
"stride": "length / size of stride",
"padding": "pre-operator padding to use, i.e., `VALID` `SAME`",
"is_nonnegative": "Should filters be constrained to be non-negative post-updates?",
"sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
"w_bound": "Soft synaptic bound applied to filters post-update",
Expand Down
11 changes: 7 additions & 4 deletions ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -47,7 +47,7 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((W @.T Rscale) * in) + b
where `@.T` denotes deconvolution

Expand Down Expand Up @@ -79,10 +79,10 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
"""

# Define Functions
def __init__(self, name, shape, x_size, eta=0., filter_init=None, bias_init=None,
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False,
w_decay=0., sign_value=1., optim_type="sgd", batch_size=1, **kwargs):
super().__init__(name, shape, x_size=x_size, filter_init=filter_init,
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
bias_init=bias_init, resist_scale=resist_scale,
stride=stride, padding=padding, batch_size=batch_size,
**kwargs)
Expand Down Expand Up @@ -224,9 +224,12 @@ def help(self): ## component help function
hyperparams = {
"shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
"x `number input channels` x `number output channels`",
"x_shape": "Shape of any single incoming/input feature map",
"weight_init": "Initialization conditions for synaptic filter (K) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level output scaling factor (R)",
"stride": "length / size of stride",
"padding": "pre-operator padding to use, i.e., `VALID` `SAME`",
"is_nonnegative": "Should filters be constrained to be non-negative post-updates?",
"sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
"w_bound": "Soft synaptic bound applied to filters post-update",
Expand Down
4 changes: 2 additions & 2 deletions ngclearn/components/synapses/convolution/staticConvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class StaticConvSynapse(ConvSynapse):
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -31,7 +31,7 @@ class StaticConvSynapse(ConvSynapse):

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
where `@` denotes convolution

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class StaticDeconvSynapse(DeconvSynapse):
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -31,7 +31,7 @@ class StaticDeconvSynapse(DeconvSynapse):

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
where `@` denotes convolution

Expand Down
38 changes: 25 additions & 13 deletions ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
Args:
name: the string name of this cell

x_size: dimension of input signal (assuming a square input)
x_shape: 2d shape of input map signal (component currently assumes square input maps)

shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
with number `filter height x filter width x input channels x number output channels`);
Expand All @@ -55,7 +55,7 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable

padding: pre-operator padding to use -- "VALID" (none), "SAME"

resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
where `@` denotes convolution

Expand All @@ -69,10 +69,10 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
"""

# Define Functions
def __init__(self, name, shape, x_size, A_plus, A_minus, eta=0.,
def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
pretrace_target=0., filter_init=None, stride=1, padding=None,
resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs):
super().__init__(name, shape, x_size=x_size, filter_init=filter_init,
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
bias_init=None, resist_scale=resist_scale, stride=stride,
padding=padding, batch_size=batch_size, **kwargs)

Expand All @@ -97,6 +97,15 @@ def __init__(self, name, shape, x_size, A_plus, A_minus, eta=0.,
## Shape error correction -- do shape correction inference for local updates
self._init(self.batch_size, self.x_size, self.shape, self.stride,
self.padding, self.pad_args, self.weights)
k_size, k_size, n_in_chan, n_out_chan = self.shape
if padding == "SAME":
self.antiPad = _conv_same_transpose_padding(
self.postSpike.value.shape[1],
self.x_size, k_size, stride)
elif padding == "VALID":
self.antiPad = _conv_valid_transpose_padding(
self.postSpike.value.shape[1],
self.x_size, k_size, stride)
########################################################################

def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
Expand Down Expand Up @@ -147,17 +156,17 @@ def evolve(self, weights, dWeights):
self.dWeights.set(dWeights)

@staticmethod
def _backtransmit(x_size, shape, stride, padding, x_delta_shape,
preSpike, postSpike, weights): ## action-backpropagating routine
def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
postSpike, weights): ## action-backpropagating routine
## calc dInputs - adjustment w.r.t. input signal
k_size, k_size, n_in_chan, n_out_chan = shape
antiPad = None
if padding == "SAME":
antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
k_size, stride)
elif padding == "VALID":
antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
k_size, stride)
# antiPad = None
# if padding == "SAME":
# antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
# k_size, stride)
# elif padding == "VALID":
# antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
# k_size, stride)
dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape,
stride_size=stride, anti_padding=antiPad)
return dInputs
Expand Down Expand Up @@ -213,9 +222,12 @@ def help(self): ## component help function
hyperparams = {
"shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
"x `number input channels` x `number output channels`",
"x_shape": "Shape of any single incoming/input feature map",
"weight_init": "Initialization conditions for synaptic filter (K) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level output scaling factor (R)",
"stride": "length / size of stride",
"padding": "pre-operator padding to use, i.e., `VALID` `SAME`",
"A_plus": "Strength of long-term potentiation (LTP)",
"A_minus": "Strength of long-term depression (LTD)",
"eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
Expand Down
Loading