Merged
24 commits
565383a
add PyTorch tensorboard log default directory and VS Code workspace
mahbodnr Jul 5, 2021
0b84910
solve initialization error for Connection class
mahbodnr Jul 14, 2021
9311e4e
Update topology.py
C-Earl Jul 14, 2021
4b90aef
Added wmin/wmax to Parameters, begin changes on conv2d
C-Earl Jul 15, 2021
25fec03
Typo in param descriptions fixed
C-Earl Jul 15, 2021
f8e5768
Added wmin/wmax tensor for Conv2dConnection
C-Earl Jul 16, 2021
e74a456
Added functionality for tensor wmin/wmax for MeanFieldConnection (plu…
C-Earl Jul 16, 2021
78596e3
Trialing replacement of all() with any(), plus beginning adding funct…
C-Earl Jul 17, 2021
40de0da
Revert "Trialing replacement of all() with any(), plus beginning addi…
C-Earl Jul 17, 2021
b1cfe75
Abstract class variables for wmin/wmax updated to be dtype float32
C-Earl Jul 17, 2021
f54b05a
Revert "Revert "Trialing replacement of all() with any(), plus beginn…
C-Earl Jul 17, 2021
46cdf36
add test weights method
mahbodnr Jul 18, 2021
1003feb
wmin/wmax tensor compatability for LocalConnection and small typo fix…
C-Earl Jul 20, 2021
fe6921f
update assert message + minor changes
mahbodnr Jul 20, 2021
90a97a0
Fixed all all()'s to any()'s
C-Earl Jul 20, 2021
8b10eff
Merge branch 'tensor-wmin-and-wmax' of https://github.com/mahbodnr/bi…
C-Earl Jul 20, 2021
51f8c55
Initial fixes for wmin/wmax tensor support in learning.py
C-Earl Jul 20, 2021
2a7259a
add other learning rules and connections
mahbodnr Jul 20, 2021
1533e01
- Fixed weight size parameters for MeanFieldConnection
C-Earl Jul 21, 2021
65d7016
Added a check for a tensor containing both finite and infinite for wm…
C-Earl Jul 21, 2021
415911c
Delete tmp_test_connections.py
mahbodnr Jul 21, 2021
c270a86
Additional comment adjustments/deletions
C-Earl Jul 21, 2021
d073a55
Merge branch 'tensor-wmin-and-wmax' of https://github.com/mahbodnr/bi…
C-Earl Jul 21, 2021
03610a6
black formater
Hananel-Hazan Jul 25, 2021
6 changes: 6 additions & 0 deletions .gitignore
@@ -122,10 +122,16 @@ examples/saved_checkpoints
# PyCharm project folder.
.idea/

+ # VS Code workspace
+ .vscode/

# macOS
.DS_Store

figures/

# Analyzer log default directory.
logs/

+ # PyTorch tensorboard log default directory.
+ runs/
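For context on the new entries: torch.utils.tensorboard's SummaryWriter logs to ./runs/ by default, which is why runs/ is ignored. A minimal sketch (the tag and value are illustrative):

from torch.utils.tensorboard import SummaryWriter

# With no log_dir argument, SummaryWriter creates ./runs/<timestamp>_<hostname>/.
writer = SummaryWriter()
writer.add_scalar("loss/train", 0.5, global_step=0)
writer.close()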
45 changes: 23 additions & 22 deletions bindsnet/learning/learning.py
@@ -27,7 +27,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -59,9 +59,9 @@ def __init__(

if (self.nu == torch.zeros(2)).all() and not isinstance(self, NoOp):
warnings.warn(
f"nu is set to [0., 0.] for {type(self).__name__} learning rule. " +
"It will disable the learning process."
)
f"nu is set to [0., 0.] for {type(self).__name__} learning rule. "
+ "It will disable the learning process."
)

# Parameter update reduction across minibatch dimension.
if reduction is None:
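As a usage sketch of what the warning above guards against: nu is the pair of learning rates applied on pre- and post-synaptic spike events, so [0., 0.] silently disables plasticity. The classes below are bindsnet's, but the sizes and rates are made up for illustration:

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre

source, target = Input(n=100, traces=True), LIFNodes(n=50, traces=True)
conn = Connection(
    source=source,
    target=target,
    update_rule=PostPre,
    nu=(1e-4, 1e-2),  # (pre-synaptic, post-synaptic) learning rates
)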
@@ -86,7 +86,8 @@ def update(self) -> None:

# Bound weights.
if (
- self.connection.wmin != -np.inf or self.connection.wmax != np.inf
+ (self.connection.wmin != -np.inf).any()
+ or (self.connection.wmax != np.inf).any()
) and not isinstance(self, NoOp):
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
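The move from scalar comparison to (...).any() is what lets wmin/wmax be tensors: the comparison then yields a boolean tensor, and weights are clamped whenever any element has a finite bound. A standalone sketch of the same logic, assuming a PyTorch version (1.9+) whose clamp_ accepts tensor bounds:

import numpy as np
import torch

w = torch.randn(4, 4)
wmin = torch.full((4, 4), -1.0)  # per-synapse lower bounds
wmax = torch.full((4, 4), 1.0)   # per-synapse upper bounds

# Clamp only when at least one bound is finite, mirroring the check above.
if (wmin != -np.inf).any() or (wmax != np.inf).any():
    w.clamp_(wmin, wmax)  # element-wise bounds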

@@ -103,7 +104,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -120,7 +121,7 @@
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

def update(self, **kwargs) -> None:
@@ -144,7 +145,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -162,7 +163,7 @@
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

assert (
@@ -260,7 +261,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -278,13 +279,13 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

assert self.source.traces, "Pre-synaptic nodes must record spike traces."
- assert (
- connection.wmin != -np.inf and connection.wmax != np.inf
- ), "Connection must define finite wmin and wmax."
+ assert (connection.wmin != -np.inf).any() and (
+ connection.wmax != np.inf
+ ).any(), "Connection must define finite wmin and wmax."

self.wmin = connection.wmin
self.wmax = connection.wmax
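Why this rule insists on finite bounds: weight-dependent STDP scales each update by the distance to the bounds, so an infinite wmin or wmax would make that scaling meaningless. A schematic sketch of the dependence, not the exact bindsnet update (rates and shapes are illustrative):

import torch

w = torch.rand(100, 50)
wmin, wmax = torch.zeros_like(w), torch.ones_like(w)

# Soft bounds: updates shrink as weights approach wmin/wmax.
potentiation = 1e-2 * (wmax - w)  # scaled by headroom to the upper bound
depression = 1e-3 * (w - wmin)    # scaled by distance from the lower bound
w = (w + potentiation - depression).clamp(wmin, wmax)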
@@ -398,7 +399,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -416,7 +417,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

assert (
@@ -503,7 +504,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -527,7 +528,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

if isinstance(connection, (Connection, LocalConnection)):
@@ -697,7 +698,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -722,7 +723,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

if isinstance(connection, (Connection, LocalConnection)):
@@ -901,7 +902,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
- **kwargs
+ **kwargs,
) -> None:
# language=rst
"""
@@ -926,7 +927,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
- **kwargs
+ **kwargs,
)

# Trace is needed for computing epsilon.
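Taken together, the learning.py changes allow wmin/wmax to be per-synapse tensors rather than scalars. A hedged end-to-end sketch of the new usage (shapes and values are illustrative):

import torch
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

source, target = Input(n=100, traces=True), LIFNodes(n=50, traces=True)
conn = Connection(
    source=source,
    target=target,
    wmin=torch.zeros(100, 50),             # one lower bound per synapse
    wmax=0.5 + 0.5 * torch.rand(100, 50),  # one upper bound per synapse
)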
7 changes: 4 additions & 3 deletions bindsnet/network/network.py
@@ -231,7 +231,7 @@ def _get_inputs(self, layers: Iterable = None) -> Dict[str, torch.Tensor]:
self.batch_size,
target.res_window_size,
*target.shape,
- device=target.s.device
+ device=target.s.device,
)
else:
inputs[c[1]] = torch.zeros(
@@ -308,8 +308,9 @@ def run(
plt.show()
"""
# Check input type
- assert type(inputs) == dict, ("'inputs' must be a dict of names of layers " +
- f"(str) and relevant input tensors. Got {type(inputs).__name__} instead."
+ assert type(inputs) == dict, (
+ "'inputs' must be a dict of names of layers "
+ + f"(str) and relevant input tensors. Got {type(inputs).__name__} instead."
)
# Parse keyword arguments.
clamps = kwargs.get("clamp", {})
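The reworded assertion in run() documents the expected input format: a dict mapping input-layer names to spike tensors. A minimal hedged sketch (the layer name "X" is illustrative; the [time, batch, n_neurons] shape follows bindsnet's usual convention):

import torch
from bindsnet.network import Network
from bindsnet.network.nodes import Input

network = Network()
network.add_layer(Input(n=100), name="X")

time = 250
inputs = {"X": torch.bernoulli(0.1 * torch.ones(time, 1, 100)).byte()}
network.run(inputs=inputs, time=time)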