From 59c9736bc171c055b10bab0b9cbbdd37415acdc3 Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Wed, 13 Nov 2024 19:48:35 -0500 Subject: [PATCH] llvm/Mechanism: Reinit integrator_function in Mechanism reset if present (#3112) Fixes: ticket #106903 Signed-off-by: Jan Vesely --- .../core/components/mechanisms/mechanism.py | 41 ++++++++++++++++--- tests/mechanisms/test_mechanisms.py | 19 +++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/psyneulink/core/components/mechanisms/mechanism.py b/psyneulink/core/components/mechanisms/mechanism.py index e261ce5eb3..aae1969640 100644 --- a/psyneulink/core/components/mechanisms/mechanism.py +++ b/psyneulink/core/components/mechanisms/mechanism.py @@ -3197,12 +3197,41 @@ def _gen_llvm_function_reset(self, ctx, builder, m_base_params, m_state, m_arg_i reinit_in = builder.alloca(reinit_func.args[2].type.pointee, name="reinit_in") reinit_out = builder.alloca(reinit_func.args[3].type.pointee, name="reinit_out") - reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder, self, "function", param_struct_ptr=m_base_params, state_struct_ptr=m_state) - reinit_params, builder = self._gen_llvm_param_ports_for_obj( - self.function, reinit_base_params, ctx, builder, m_base_params, m_state, m_arg_in) + reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder, + self, + "function", + param_struct_ptr=m_base_params, + state_struct_ptr=m_state) + reinit_params, builder = self._gen_llvm_param_ports_for_obj(self.function, + reinit_base_params, + ctx, + builder, + m_base_params, + m_state, + m_arg_in) builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out]) + if hasattr(self, "integrator_function") and getattr(self, "integrator_mode", False): + reinit_func = ctx.import_llvm_function(self.integrator_function, tags=tags) + reinit_in = builder.alloca(reinit_func.args[2].type.pointee, name="integrator_reinit_in") + reinit_out = builder.alloca(reinit_func.args[3].type.pointee, name="integrator_reinit_out") + + reinit_base_params, reinit_state = ctx.get_param_or_state_ptr(builder, + self, + "integrator_function", + param_struct_ptr=m_base_params, + state_struct_ptr=m_state) + reinit_params, builder = self._gen_llvm_param_ports_for_obj(self.integrator_function, + reinit_base_params, + ctx, + builder, + m_base_params, + m_state, + m_arg_in) + + builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out]) + return builder def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tags:frozenset): @@ -3212,9 +3241,11 @@ def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tag Mechanisms need to support "is_finished" execution variant (used by scheduling conditions) on top of the variants supported by Component. """ + + # Call parent "_gen_llvm_function", this should result in calling + # "_gen_llvm_function_body" below if "is_finished" not in tags: - return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx, - tags=tags) + return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx, tags=tags) # Keep all 4 standard arguments to ease invocation args = [ctx.get_param_struct_type(self).as_pointer(), diff --git a/tests/mechanisms/test_mechanisms.py b/tests/mechanisms/test_mechanisms.py index f07b2810aa..c326b5c1a3 100644 --- a/tests/mechanisms/test_mechanisms.py +++ b/tests/mechanisms/test_mechanisms.py @@ -280,3 +280,22 @@ def test_reset_state_transfer_mechanism(self): np.testing.assert_allclose(output_after_saving_state, output_after_reinitialization) np.testing.assert_allclose(original_output, [np.array([[0.5]]), np.array([[0.75]])]) np.testing.assert_allclose(output_after_reinitialization, [np.array([[0.875]]), np.array([[0.9375]])]) + + @pytest.mark.usefixtures("comp_mode_no_llvm") + def test_reset_integrator_function(self, comp_mode): + """This test checks that the Mechanism.integrator_function is reset when the mechanism is""" + + threshold_mech = pnl.TransferMechanism(input_shapes=1, + default_variable=0, + integrator_function=pnl.SimpleIntegrator(rate=1, offset=-0.001), + function=pnl.Linear(intercept=0.06, slope=1), + integrator_mode=True, + execute_until_finished=True, + termination_threshold=10, + reset_stateful_function_when=pnl.AtTrialStart(), + termination_measure=pnl.TimeScale.TRIAL) + comp = pnl.Composition() + comp.add_node(threshold_mech) + + results = comp.run(inputs=[[0.0], [0.0]], execution_mode=comp_mode) + np.testing.assert_allclose(results, [[0.05]])