Skip to content

Monitor plot #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 69 additions & 16 deletions ngclearn/components/base_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from ngclearn import Component, Compartment
from ngclearn import numpy as np
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
get_current_path
from ngcsimlib.logger import warn, critical
import matplotlib.pyplot as plt


class Base_Monitor(Component):
Expand All @@ -21,7 +23,8 @@ class Base_Monitor(Component):
Using custom window length:
myMonitor.watch(myComponent.myCompartment, customWindowLength)

To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All
To get values out of the monitor either path to the stored value
directly, or pass in a compartment directly. All
paths are the same as their local path variable.

Using a compartment:
Expand All @@ -30,7 +33,8 @@ class Base_Monitor(Component):
Using a path:
myMonitor.get_store(myComponent.myCompartment.path).value

There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers
There can only be one monitor in existence at a time due to the way it
interacts with resolvers and the compilers
for ngclearn.

Args:
Expand All @@ -53,10 +57,10 @@ def build_advance(compartments):

"""
critical(
"build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or "
"build_advance() is not defined on this monitor, use either the "
"monitor found in ngclearn.components or "
"ngclearn.components.lava (If using lava)")


@staticmethod
def build_reset(compartments):
"""
Expand All @@ -66,6 +70,7 @@ def build_reset(compartments):

Returns: The method to reset the stored values.
"""

@staticmethod
def _reset(**kwargs):
return_vals = []
Expand Down Expand Up @@ -95,7 +100,8 @@ def __lshift__(self, other):

def watch(self, compartment, window_length):
"""
Sets the monitor to watch a specific compartment, for a specified window length.
Sets the monitor to watch a specific compartment, for a specified
window length.

Args:
compartment: the compartment object to monitor
Expand Down Expand Up @@ -150,7 +156,7 @@ def halt_all(self):
"""
for compartment in self._sources:
self.halt(compartment)

def _update_resolver(self):
output_compartments = []
compartments = []
Expand All @@ -162,13 +168,18 @@ def _update_resolver(self):
parameters = []

add_component_resolver(self.__class__.__name__, "advance_state",
(self.build_advance(compartments), output_compartments))
(self.build_advance(compartments),
output_compartments))
add_resolver_meta(self.__class__.__name__, "advance_state",
(args, parameters, compartments + [o for o in output_compartments], False))
(args, parameters,
compartments + [o for o in output_compartments],
False))

add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments))
add_component_resolver(self.__class__.__name__, "reset", (
self.build_reset(compartments), output_compartments))
add_resolver_meta(self.__class__.__name__, "reset",
(args, parameters, [o for o in output_compartments], False))
(args, parameters, [o for o in output_compartments],
False))

def _add_path(self, path):
_path = path.split("/")[1:]
Expand Down Expand Up @@ -210,7 +221,8 @@ def save(self, directory, **kwargs):
for key in self.compartments:
n = key.split("/")[-1]
_dict["sources"][key] = self.__dict__[n].value.shape
_dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape
_dict["stores"][key + "*store"] = self.__dict__[
n + "*store"].value.shape

with open(file_name, "w") as f:
json.dump(_dict, f)
Expand All @@ -221,9 +233,9 @@ def load(self, directory, **kwargs):
vals = json.load(f)

for comp_path, shape in vals["stores"].items():

compartment_path = comp_path.split("/")[-1]
new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1])
new_path = get_current_path() + "/" + "/".join(
compartment_path.split("*")[-3:-1])

cs, end = self._add_path(new_path)

Expand All @@ -233,8 +245,6 @@ def load(self, directory, **kwargs):
cs[end] = new_comp
setattr(self, compartment_path, new_comp)



for comp_path, shape in vals['sources'].items():
compartment_path = comp_path.split("/")[-1]
new_comp = Compartment(np.zeros(shape))
Expand All @@ -244,3 +254,46 @@ def load(self, directory, **kwargs):
self.compartments.append(new_comp.path)

self._update_resolver()

def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
vals = self.view(compartment)

if n is None:
n = vals.shape[2]
if title is None:
title = compartment.name.split("/")[0] + " " + compartment.display_name

if ylabel is None:
_ylabel = compartment.units
elif ylabel:
_ylabel = ylabel
else:
_ylabel = None

if xlabel is None:
_xlabel = "Time Steps"
elif xlabel:
_xlabel = xlabel
else:
_xlabel = None

if ax is None:
_ax = plt
_ax.title(title)
if _ylabel:
_ax.ylabel(_ylabel)
if _xlabel:
_ax.xlabel(_xlabel)
else:
_ax = ax
_ax.set_title(title)
if _ylabel:
_ax.set_ylabel(_ylabel)
if _xlabel:
_ax.set_xlabel(_xlabel)

if plot_func is None:
for k in range(n):
_ax.plot(vals[:, 0, k])
else:
plot_func(vals, ax=_ax)
Loading