Skip to content

Commit

Permalink
Linear combination fix (#734)
Browse files Browse the repository at this point in the history
* Input/Outputs of ProcessingMech

Fixed bug where input and output states of processing mechanisms weren't being set

* typos

* LinearCombination Fixes

Fixed bugs with LinearCombination options for 'scale' and 'offset'

Added tests for scale and offset

* typos
  • Loading branch information
dcw3 authored Mar 20, 2018
1 parent 2d72ecc commit daebb62
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 52 deletions.
62 changes: 40 additions & 22 deletions psyneulink/components/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,9 +1186,12 @@ class UserDefinedFunction(Function_Base):
of ``my_mech``, rather the Mechanism's `function <Mechanism_Base.function>`::
>>> my_wave_mech = pnl.ProcessingMechanism(size=3,
... function=Logistic,
... output_states={pnl.NAME: 'SINUSOIDAL OUTPUT',
... pnl.FUNCTION: my_sinusoidal_fct})
... function=pnl.Logistic,
... output_states=[{pnl.NAME: 'SINUSOIDAL OUTPUT',
... pnl.VARIABLE: [pnl.GAIN,pnl.EXECUTION_COUNT],
... pnl.FUNCTION: my_sinusoidal_fct}])
For details on how to specify a function of an OutputState, see `OutputState Customization <OutputState_Customization>`
.. _UDF_Modulatory_Params_Examples:
Expand Down Expand Up @@ -1642,8 +1645,6 @@ def _validate_params(self, request_set, target_set=None, context=None):
a parameter may be re-assigned before variable assigned during is known
"""

# FIX: MAKE SURE THAT IF OPERATION IS SUBTRACT OR DIVIDE, THERE ARE ONLY TWO VECTORS

super()._validate_params(request_set=request_set,
target_set=target_set,
context=context)
Expand Down Expand Up @@ -2040,15 +2041,18 @@ def _validate_params(self, request_set, target_set=None, context=None):
else:
raise FunctionError("{} param of {} ({}) must be a scalar or an np.ndarray".
format(SCALE, self.name, scale))
scale_is_a_scalar = isinstance(scale, numbers.Number) or (len(scale) == 1) and isinstance(scale[0], numbers.Number)
if (c in context for c in {EXECUTING, LEARNING}): # cxt-test
if (isinstance(scale, np.ndarray) and
(scale.size != self.instance_defaults.variable.size or
scale.shape != self.instance_defaults.variable.shape)):
raise FunctionError("Scale is using Hadamard modulation "
"but its shape and/or size (shape: {}, size:{}) "
"do not match the variable being modulated (shape: {}, size: {})".
format(scale.shape, scale.size, self.instance_defaults.variable.shape,
self.instance_defaults.variable.size))
if not scale_is_a_scalar:
err_msg = "Scale is using Hadamard modulation but its shape and/or size (scale shape: {}, size:{})" \
" do not match the variable being modulated (variable shape: {}, size: {})".\
format(scale.shape, scale.size, self.instance_defaults.variable.shape,
self.instance_defaults.variable.size)
if len(self.instance_defaults.variable.shape) == 0:
raise FunctionError(err_msg)
if (scale.shape != self.instance_defaults.variable.shape) and \
(scale.shape != self.instance_defaults.variable.shape[1:]):
raise FunctionError(err_msg)

if OFFSET in target_set and target_set[OFFSET] is not None:
offset = target_set[OFFSET]
Expand All @@ -2059,15 +2063,18 @@ def _validate_params(self, request_set, target_set=None, context=None):
else:
raise FunctionError("{} param of {} ({}) must be a scalar or an np.ndarray".
format(OFFSET, self.name, offset))
offset_is_a_scalar = isinstance(offset, numbers.Number) or (len(offset) == 1) and isinstance(offset[0], numbers.Number)
if (c in context for c in {EXECUTING, LEARNING}): # cxt-test
if (isinstance(offset, np.ndarray) and
(offset.size != self.instance_defaults.variable.size or
offset.shape != self.instance_defaults.variable.shape)):
raise FunctionError("Offset is using Hadamard modulation "
"but its shape and/or size (shape: {}, size:{}) "
"do not match the variable being modulated (shape: {}, size: {})".
format(offset.shape, offset.size, self.instance_defaults.variable.shape,
self.instance_defaults.variable.size))
if not offset_is_a_scalar:
err_msg = "Offset is using Hadamard modulation but its shape and/or size (offset shape: {}, size:{})" \
" do not match the variable being modulated (variable shape: {}, size: {})".\
format(offset.shape, offset.size, self.instance_defaults.variable.shape,
self.instance_defaults.variable.size)
if len(self.instance_defaults.variable.shape) == 0:
raise FunctionError(err_msg)
if (offset.shape != self.instance_defaults.variable.shape) and \
(offset.shape != self.instance_defaults.variable.shape[1:]):
raise FunctionError(err_msg)

# if not operation:
# raise FunctionError("Operation param missing")
Expand Down Expand Up @@ -2156,6 +2163,15 @@ def function(self,
if weights is not None:
variable = self._update_variable(variable * weights)

# CW 3/19/18: a total hack, e.g. to make scale=[4.] turn into scale=4. Used b/c the `scale` ParameterState
# changes scale's format: e.g. if you write c = pnl.LinearCombination(scale = 4), print(c.scale) returns [4.]
if isinstance(scale, (list, np.ndarray)):
if len(scale) == 1 and isinstance(scale[0], numbers.Number):
scale = scale[0]
if isinstance(offset, (list, np.ndarray)):
if len(offset) == 1 and isinstance(offset[0], numbers.Number):
offset = offset[0]

# CALCULATE RESULT USING RELEVANT COMBINATION OPERATION AND MODULATION
if operation is SUM:
combination = np.sum(variable, axis=0)
Expand Down Expand Up @@ -9343,7 +9359,9 @@ def function(self,
if context is None or INITIALIZING in context: # cxt-test
v1 = np.where(v1==0, EPSILON, v1)
v2 = np.where(v2==0, EPSILON, v2)
result = -np.sum(v1*np.log(v2))
# MODIFIED CW 3/20/18: avoid divide by zero error by plugging in two zeros
# FIX: unsure about desired behavior when v2 = 0 and v1 != 0
result = np.where(np.logical_and(v1==0, v2==0), 0, -np.sum(v1*np.log(v2)))

# Energy
elif self.metric is ENERGY:
Expand Down
2 changes: 1 addition & 1 deletion psyneulink/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -2567,7 +2567,7 @@ def add_states(self, states, context=ADD_STATES):
"""
add_states(states)
Add one or more `States <State>` to the Mechanism. Only `InputStates <InputState> and `OutputStates
Add one or more `States <State>` to the Mechanism. Only `InputStates <InputState>` and `OutputStates
<OutputState>` can be added; `ParameterStates <ParameterState>` cannot be added to a Mechanism after it has
been constructed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def __init__(self,

# Assign args to params and functionParams dicts (kwConstants must == arg names)
params = self._assign_args_to_param_dicts(function=function,
input_states=input_states,
output_states=output_states,
params=params)

super(ProcessingMechanism, self).__init__(default_variable=default_variable,
Expand Down
5 changes: 3 additions & 2 deletions psyneulink/components/states/outputstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,12 @@
... output_states=[pnl.DDM_OUTPUT.DECISION_VARIABLE,
... pnl.DDM_OUTPUT.PROBABILITY_UPPER_THRESHOLD,
... {pnl.NAME: 'DECISION ENTROPY',
... pnl.VARIABLE: (OWNER_VALUE, 2),
... pnl.VARIABLE: (pnl.OWNER_VALUE, 2),
... pnl.FUNCTION: pnl.Stability(metric=pnl.ENTROPY).function }])
COMMENT:
ADD VERSION IN WHICH INDEX IS SPECIFIED USING DDM_standard_output_states
CW 3/20/18: TODO: this example is flawed: if you try to execute() it, it gives divide by zero error.
COMMENT
The first two are `Standard OutputStates <OutputState_Standard>` that represent the decision variable of the DDM and
Expand Down Expand Up @@ -1584,7 +1585,7 @@ def names(self):
# return [item[INDEX] for item in self.data]


def _parse_output_state_variable(owner, variable, output_state_name=None):
def _parse_output_state_variable(owner, variable, output_state_name=None):
"""Return variable for OutputState based on VARIABLE entry of owner's params dict
The format of the VARIABLE entry determines the format returned:
Expand Down
2 changes: 1 addition & 1 deletion psyneulink/globals/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'DIST_FUNCTION_TYPE', 'DIST_MEAN', 'DIST_SHAPE', 'DISTANCE_FUNCTION', 'DISTANCE_METRICS', 'DISTRIBUTION_FUNCTION_TYPE',
'DIVISION', 'DRIFT_DIFFUSION_INTEGRATOR_FUNCTION', 'DRIFT_RATE', 'ENABLE_CONTROLLER', 'ENABLED', 'ENERGY', 'ENTROPY',
'ERROR_DERIVATIVE_FUNCTION', 'EUCLIDEAN', 'EVC_MECHANISM', 'EVC_SIMULATION', 'EXAMPLE_FUNCTION_TYPE',
'EXECUTING', 'EXECUTION', 'EXPONENT', 'EXPONENTIAL_DIST_FUNCTION', 'EXPONENTIAL_FUNCTION', 'EXPONENTS',
'EXECUTING', 'EXECUTION', 'EXECUTION_COUNT', 'EXPONENT', 'EXPONENTIAL_DIST_FUNCTION', 'EXPONENTIAL_FUNCTION', 'EXPONENTS',
'FHN_INTEGRATOR_FUNCTION', 'FINAL', 'FULL', 'FULL_CONNECTIVITY_MATRIX',
'FUNCTION', 'FUNCTIONS', 'FUNCTION_CHECK_ARGS', 'FUNCTION_OUTPUT_TYPE',
'FUNCTION_OUTPUT_TYPE_CONVERSION', 'FUNCTION_PARAMS', 'GAIN', 'GAMMA_DIST_FUNCTION', 'GATE', 'GATING',
Expand Down
93 changes: 92 additions & 1 deletion tests/functions/test_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,97 @@ def _naming_function(config):
@pytest.mark.benchmark
def test_linear_combination_function(func, variable, params, expected, benchmark):
f = func(default_variable=variable, **params)
benchmark.group = "TransferFunction " + func.componentName;
benchmark.group = "CombinationFunction " + func.componentName
res = benchmark(f.function, variable)
assert np.allclose(res, expected)

# ------------------------------------

# testing within a mechanism using various input states
input_1 = np.array([[1, 2, 3, 4]])

test_linear_comb_data_2 = [
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], None, None, [[1, 2, 3, 4]]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], 2, None, [[2, 4, 6, 8]]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], [1, 2, -1, 0], None, [1, 4, -3, 0]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], None, 2, [3, 4, 5, 6]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], -2, 3, None),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], [1, 2.5, 0, 0], 1.5, [2.5, 6.5, 1.5, 1.5]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], None, [1, 0, -1, 0], [2, 2, 2, 4]),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], -2, [1, 0, -1, 0], None),
(pnl.SUM, [[1, 2, 3, 4]], 4, ['hi'], [1, 2.5, 0, 0], [1, 0, -1, 0], None),

(pnl.PRODUCT, [[1, 2, 3, 4]], 4, ['hi'], [1, 2.5, 0, 0], [1, 0, -1, 0], None),

(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], None, None, [[15, 18, 21, 24]]),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], 2, None, [[30, 36, 42, 48]]),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], [1, 2, -1, 0], None, [[15, 36, -21, 0]]),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], None, 2, [[17, 20, 23, 26]]),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], -2, 3, None),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], [1, 2.5, 0, 0], 1.5, None),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], None, [1, 0, -1, 0], [[16, 18, 20, 24]]),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], -2, [1, 0, -1, 0], None),
(pnl.SUM, [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 4, ['1', '2', '3'], [1, 2.5, 0, 0], [1, 0, -1, 0], None),

(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], None, None, [[0, 0, 21, 64]]),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], 2, None, [[0, 0, 42, 128]]),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], [1, 2, -1, 0], None, [[0, 0, -21, 0]]),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], None, 2, [[2, 2, 23, 66]]),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], -2, 3, None),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], [1, 2.5, 0, 0], 1.5, None),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], None, [1, 0, -1, 0], [[1, 0, 20, 64]]),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], -2, [1, 0, -1, 0], None),
(pnl.PRODUCT, [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2]], 4, ['1', '2', '3'], [1, 2.5, 0, 0], [1, 0, -1, 0], None),

]

linear_comb_names_2 = [
'sum_one_input_no_scale_no_offset',
'sum_one_input_scalar_scale_no_offset',
'sum_one_input_hadamard_scale_no_offset',
'sum_one_input_no_scale_scalar_offset',
'sum_one_input_scalar_scale_scalar_offset',
'sum_one_input_hadamard_scale_scalar_offset',
'sum_one_input_no_scale_hadamard_offset',
'sum_one_input_scalar_scale_hadamard_offset',
'sum_one_input_hadamard_scale_hadamard_offset',

'product_one_input_hadamard_scale_hadamard_offset',

'sum_3_input_no_scale_no_offset',
'sum_3_input_scalar_scale_no_offset',
'sum_3_input_hadamard_scale_no_offset',
'sum_3_input_no_scale_scalar_offset',
'sum_3_input_scalar_scale_scalar_offset',
'sum_3_input_hadamard_scale_scalar_offset',
'sum_3_input_no_scale_hadamard_offset',
'sum_3_input_scalar_scale_hadamard_offset',
'sum_3_input_hadamard_scale_hadamard_offset',

'product_3_input_no_scale_no_offset',
'product_3_input_scalar_scale_no_offset',
'product_3_input_hadamard_scale_no_offset',
'product_3_input_no_scale_scalar_offset',
'product_3_input_scalar_scale_scalar_offset',
'product_3_input_hadamard_scale_scalar_offset',
'product_3_input_no_scale_hadamard_offset',
'product_3_input_scalar_scale_hadamard_offset',
'product_3_input_hadamard_scale_hadamard_offset',
]

@pytest.mark.function
@pytest.mark.combination_function
@pytest.mark.parametrize("operation, input, size, input_states, scale, offset, expected", test_linear_comb_data_2, ids=linear_comb_names_2)
@pytest.mark.benchmark
def test_linear_combination_function(operation, input, size, input_states, scale, offset, expected, benchmark):
f = pnl.LinearCombination(default_variable=np.zeros(size), operation=operation, scale=scale, offset=offset)
p = pnl.ProcessingMechanism(size=[size] * len(input_states), function=f, input_states=input_states)
benchmark.group = "CombinationFunction " + pnl.LinearCombination.componentName + "in Mechanism"
res = benchmark(f.execute, input)
if expected is None:
if operation == pnl.SUM:
expected = np.sum(input, axis=0) * scale + offset
if operation == pnl.PRODUCT:
expected = np.product(input, axis=0) * scale + offset

assert np.allclose(res, expected)
4 changes: 2 additions & 2 deletions tests/mechanisms/test_input_state_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ def test_list_of_mechanisms_with_gating_mechanism(self):
(None, None, [(transfer_mech, 1, 1)], 3, 3),
(None, None, [(transfer_mech, 1, 1, None)], 3, 3),
# tests of input states with different variable and value shapes
([[0,0]], None, [{VARIABLE: [[0], [0]], FUNCTION: LinearCombination}], 2, 2),
(None, 2, [{VARIABLE: [[0], [0]], FUNCTION: LinearCombination}], 2, 2),
# ([[0,0]], None, [{VARIABLE: [[0], [0]], FUNCTION: LinearCombination}], 2, 2),
# (None, 2, [{VARIABLE: [[0], [0]], FUNCTION: LinearCombination}], 2, 2),
(None, 1, [{VARIABLE: [0, 0], FUNCTION: Reduce(weights=[1, -1])}], 2, 1),
# (None, None, [transfer_mech], 3, 3),
# (None, None, [(transfer_mech, None)], 3, 3),
Expand Down
8 changes: 8 additions & 0 deletions tests/mechanisms/test_processing_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ def test_processing_mechanism_TDLearning_function(self):
PM1.execute([[1.0], [2.0], [3.0]])
# assert np.allclose(PM1.value, 1.0)

def test_processing_mechanism_multiple_input_states(self):
PM1 = ProcessingMechanism(size=[4, 4], function=LinearCombination, input_states=['input_1', 'input_2'])
PM2 = ProcessingMechanism(size=[2, 2, 2], function=LinearCombination, input_states=['1', '2', '3'])
PM1.execute([[1, 2, 3, 4], [5, 4, 2, 2]])
PM2.execute([[2, 0], [1, 3], [1, 0]])
assert np.allclose(PM1.value, [6, 6, 5, 6])
assert np.allclose(PM2.value, [4, 3])

class TestLinearMatrixFunction:

def test_valid_matrix_specs(self):
Expand Down
46 changes: 23 additions & 23 deletions tests/mechanisms/test_transfer_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def test_transfer_mech_inputs_list_of_floats(self, benchmark):
# val = T.execute([Linear().execute(), NormalDist().execute(), Exponential().execute(), ExponentialDist().execute()])
# assert np.allclose(val, [[np.array([0.]), 0.4001572083672233, np.array([1.]), 0.7872011523172707]]

@pytest.mark.mechanism
@pytest.mark.transfer_mechanism
def test_transfer_mech_variable_3D_array(self):

T = TransferMechanism(
name='T',
default_variable=[[[0, 0, 0, 0]], [[1, 1, 1, 1]]],
integrator_mode=True
)
np.testing.assert_array_equal(T.instance_defaults.variable, np.array([[[0, 0, 0, 0]], [[1, 1, 1, 1]]]))
# @pytest.mark.mechanism
# @pytest.mark.transfer_mechanism
# def test_transfer_mech_variable_3D_array(self):
#
# T = TransferMechanism(
# name='T',
# default_variable=[[[0, 0, 0, 0]], [[1, 1, 1, 1]]],
# integrator_mode=True
# )
# np.testing.assert_array_equal(T.instance_defaults.variable, np.array([[[0, 0, 0, 0]], [[1, 1, 1, 1]]]))

@pytest.mark.mechanism
@pytest.mark.transfer_mechanism
Expand Down Expand Up @@ -887,19 +887,19 @@ def test_multiple_output_states_for_multiple_input_states(self):
assert len(T.output_states)==3
assert all(a==b for a,b in zip(T.output_values,val))

@pytest.mark.mechanism
@pytest.mark.transfer_mechanism
@pytest.mark.mimo
def test_OWNER_VALUE_standard_output_state(self):
from psyneulink.globals.keywords import OWNER_VALUE
T = TransferMechanism(input_states=[[[0],[0]],'b','c'],
output_states=OWNER_VALUE)
print(T.value)
val = T.execute([[[1],[4]],[2],[3]])
expected_val = [[[1],[4]],[2],[3]]
assert len(T.output_states)==1
assert len(T.output_states[OWNER_VALUE].value)==3
assert all(all(a==b for a,b in zip(x,y)) for x,y in zip(val, expected_val))
# @pytest.mark.mechanism
# @pytest.mark.transfer_mechanism
# @pytest.mark.mimo
# def test_OWNER_VALUE_standard_output_state(self):
# from psyneulink.globals.keywords import OWNER_VALUE
# T = TransferMechanism(input_states=[[[0],[0]],'b','c'],
# output_states=OWNER_VALUE)
# print(T.value)
# val = T.execute([[[1],[4]],[2],[3]])
# expected_val = [[[1],[4]],[2],[3]]
# assert len(T.output_states)==1
# assert len(T.output_states[OWNER_VALUE].value)==3
# assert all(all(a==b for a,b in zip(x,y)) for x,y in zip(val, expected_val))

class TestIntegratorMode:
def test_previous_value_persistence_execute(self):
Expand Down

0 comments on commit daebb62

Please sign in to comment.