Skip to content

Commit

Permalink
fixed minor bug in deprecation for poiss/bern
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 25, 2024
1 parent cdea291 commit efa61a5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class BernoulliCell(JaxComponent):
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
"""

@deprecate_args(target_freq="max_freq")
@deprecate_args(max_freq="target_freq")
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

Expand Down
18 changes: 16 additions & 2 deletions ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import numpy as jnp, random, jit, scipy
from functools import partial
from ngcsimlib.deprecators import deprecate_args

from ngcsimlib.logger import info, warn

class PoissonCell(JaxComponent):
"""
Expand All @@ -29,7 +29,7 @@ class PoissonCell(JaxComponent):
"""

# Define Functions
@deprecate_args(target_freq="max_freq")
@deprecate_args(max_freq="target_freq")
def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
**kwargs):
super().__init__(name, **kwargs)
Expand All @@ -56,6 +56,20 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
maxval=1.))

def validate(self, dt, **validation_kwargs):
## check for unstable combinations of dt and target-frequency meta-params
valid = super().validate(**validation_kwargs)
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
f"{self.name} will be unable to make as many temporal events as "
f"requested! ({events_per_timestep} events/timestep) Unstable "
f"combination of dt = {dt} and target_freq = {self.target_freq} "
f"being used!"
)
return valid

@staticmethod
def _advance_state(t, dt, target_freq, key, inputs, targets, tols):
ms_per_second = 1000 # ms/s
Expand Down

0 comments on commit efa61a5

Please sign in to comment.