Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcpni committed Jan 9, 2022
1 parent 3d27c00 commit 4810379
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions tests/composition/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,13 +764,16 @@ class TestControlMechanisms:
'has \'state_features\' specified ([\'EXT[OutputPort-0]\']) that are missing from the '
'Composition or any nested within it."'
]
state_feature_specs = ['state_feat_specified', 'misplaced_shadow', 'ext_shadow', 'ext_output_port']

state_feature_specs = ['legal_feature', 'misplaced_shadow', 'ext_shadow', 'ext_output_port']

state_feature_args = [
(state_feature_specs[0], messages[0], UserWarning),
(state_feature_specs[1], messages[1], pnl.CompositionError),
(state_feature_specs[2], messages[2], pnl.OptimizationControlMechanismError),
(state_feature_specs[3], messages[3], pnl.OptimizationControlMechanismError)
]

@pytest.mark.control
@pytest.mark.parametrize('state_feature_args', state_feature_args, ids=[x for x in state_feature_specs])
def test_ocm_state_input_ports_warnings_and_errors(self, state_feature_args):
Expand All @@ -785,23 +788,19 @@ def test_ocm_state_input_ports_warnings_and_errors(self, state_feature_args):
ocomp = pnl.Composition(pathways=[icomp], name='OUTER COMP')
ocomp.add_linear_processing_pathway([oa,oc])
ocomp.add_linear_processing_pathway([ob,oc])
state_features_dict = {'state_feat_specified':ia.input_port,
state_features_dict = {'legal_feature':ia.input_port,
'misplaced_shadow':ib.input_port,
'ext_shadow':ext.input_port,
'ext_output_port':ext.output_port}
state_features = state_features_dict[state_feature_args[0]]
message = state_feature_args[1]
ocm = pnl.OptimizationControlMechanism(
state_features=state_features,
objective_mechanism=[ic,ib],
function=pnl.GridSearch(),
control_signals=[pnl.ControlSignal(modulates=(pnl.SLOPE,ia),
allocation_samples=[10, 20, 30]),
pnl.ControlSignal(modulates=(pnl.INTERCEPT,oc),
allocation_samples=[10, 20, 30]),
]
)
assert True
ocm = pnl.OptimizationControlMechanism(state_features=state_features,
objective_mechanism=[ic,ib],
function=pnl.GridSearch(),
control_signals=[pnl.ControlSignal(modulates=(pnl.SLOPE,ia),
allocation_samples=[10, 20, 30]),
pnl.ControlSignal(modulates=(pnl.INTERCEPT,oc),
allocation_samples=[10, 20, 30])])
if state_feature_args[2] is UserWarning:
with pytest.warns(UserWarning) as warning:
ocomp.add_controller(ocm)
Expand All @@ -813,10 +812,6 @@ def test_ocm_state_input_ports_warnings_and_errors(self, state_feature_args):
ocomp.run()
assert message in str(error.value)





@pytest.mark.control
def test_ocm_state_and_state_dict(self):
ia = pnl.ProcessingMechanism(name='IA')
Expand Down

0 comments on commit 4810379

Please sign in to comment.