Skip to content

Commit

Permalink
llvm/Mechanism: Reinit integrator_function in Mechanism reset if pres…
Browse files Browse the repository at this point in the history
…ent (#3112)

Fixes: ticket #106903

Signed-off-by: Jan Vesely <jan.vesely@rutgers.edu>
  • Loading branch information
jvesely authored Nov 14, 2024
1 parent edf7389 commit 59c9736
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
41 changes: 36 additions & 5 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3197,12 +3197,41 @@ def _gen_llvm_function_reset(self, ctx, builder, m_base_params, m_state, m_arg_i
reinit_in = builder.alloca(reinit_func.args[2].type.pointee, name="reinit_in")
reinit_out = builder.alloca(reinit_func.args[3].type.pointee, name="reinit_out")

reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder, self, "function", param_struct_ptr=m_base_params, state_struct_ptr=m_state)
reinit_params, builder = self._gen_llvm_param_ports_for_obj(
self.function, reinit_base_params, ctx, builder, m_base_params, m_state, m_arg_in)
reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder,
self,
"function",
param_struct_ptr=m_base_params,
state_struct_ptr=m_state)
reinit_params, builder = self._gen_llvm_param_ports_for_obj(self.function,
reinit_base_params,
ctx,
builder,
m_base_params,
m_state,
m_arg_in)

builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out])

if hasattr(self, "integrator_function") and getattr(self, "integrator_mode", False):
reinit_func = ctx.import_llvm_function(self.integrator_function, tags=tags)
reinit_in = builder.alloca(reinit_func.args[2].type.pointee, name="integrator_reinit_in")
reinit_out = builder.alloca(reinit_func.args[3].type.pointee, name="integrator_reinit_out")

reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder,
self,
"integrator_function",
param_struct_ptr=m_base_params,
state_struct_ptr=m_state)
reinit_params, builder = self._gen_llvm_param_ports_for_obj(self.integrator_function,
reinit_base_params,
ctx,
builder,
m_base_params,
m_state,
m_arg_in)

builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out])

return builder

def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
Expand All @@ -3212,9 +3241,11 @@ def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tag
Mechanisms need to support "is_finished" execution variant (used by scheduling conditions)
on top of the variants supported by Component.
"""

# Call parent "_gen_llvm_function", this should result in calling
# "_gen_llvm_function_body" below
if "is_finished" not in tags:
return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx,
tags=tags)
return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx, tags=tags)

# Keep all 4 standard arguments to ease invocation
args = [ctx.get_param_struct_type(self).as_pointer(),
Expand Down
19 changes: 19 additions & 0 deletions tests/mechanisms/test_mechanisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,22 @@ def test_reset_state_transfer_mechanism(self):
np.testing.assert_allclose(output_after_saving_state, output_after_reinitialization)
np.testing.assert_allclose(original_output, [np.array([[0.5]]), np.array([[0.75]])])
np.testing.assert_allclose(output_after_reinitialization, [np.array([[0.875]]), np.array([[0.9375]])])

@pytest.mark.usefixtures("comp_mode_no_llvm")
def test_reset_integrator_function(self, comp_mode):
"""This test checks that the Mechanism.integrator_function is reset when the mechanism is"""

threshold_mech = pnl.TransferMechanism(input_shapes=1,
default_variable=0,
integrator_function=pnl.SimpleIntegrator(rate=1, offset=-0.001),
function=pnl.Linear(intercept=0.06, slope=1),
integrator_mode=True,
execute_until_finished=True,
termination_threshold=10,
reset_stateful_function_when=pnl.AtTrialStart(),
termination_measure=pnl.TimeScale.TRIAL)
comp = pnl.Composition()
comp.add_node(threshold_mech)

results = comp.run(inputs=[[0.0], [0.0]], execution_mode=comp_mode)
np.testing.assert_allclose(results, [[0.05]])

0 comments on commit 59c9736

Please sign in to comment.