Skip to content

Commit

Permalink
TVB-2687 Fix display of Simulation Operation Params in the overlay, b…
Browse files Browse the repository at this point in the history
…y including the model and monitors forms
  • Loading branch information
liadomide committed Jun 16, 2020
1 parent c1e4679 commit 1b91070
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 55 deletions.
32 changes: 16 additions & 16 deletions framework_tvb/tvb/adapters/simulator/monitor_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def get_form_for_monitor(monitor_class):

class MonitorForm(Form):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
super(MonitorForm, self).__init__(prefix)
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(MonitorForm, self).__init__(prefix, project_id)
self.project_id = project_id
self.period = ScalarField(Monitor.period, self)
self.variables_of_interest_indexes = variables_of_interest_indexes
self.variables_of_interest = MultiSelectField(List(of=str, label='Model Variables to watch',
choices=tuple(self.variables_of_interest_indexes.keys())),
choices=tuple(self.variables_of_interest_indexes.keys())),
self, name='variables_of_interest')

def fill_from_trait(self, trait):
Expand All @@ -112,44 +112,44 @@ def fill_trait(self, datatype):
super(MonitorForm, self).fill_trait(datatype)
datatype.variables_of_interest = numpy.array(list(self.variables_of_interest_indexes.values()))

#TODO: We should review the code here, we could probably reduce the number of classes that are used here
# TODO: We should review the code here, we could probably reduce the number of classes that are used here


class RawMonitorForm(Form):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
super(RawMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(RawMonitorForm, self).__init__(prefix, project_id)


class SubSampleMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(SubSampleMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)


class SpatialAverageMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(SpatialAverageMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)
self.spatial_mask = ArrayField(SpatialAverage.spatial_mask, self)
self.default_mask = ScalarField(SpatialAverage.default_mask, self)


class GlobalAverageMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(GlobalAverageMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)


class TemporalAverageMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(TemporalAverageMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)


class ProjectionMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(ProjectionMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)
self.region_mapping = DataTypeSelectField(RegionMappingIndex, self, name='region_mapping', required=True,
label=Projection.region_mapping.label,
Expand All @@ -159,7 +159,7 @@ def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):

class EEGMonitorForm(ProjectionMonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(EEGMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)

sensor_filter = FilterChain(fields=[FilterChain.datatype + '.sensors_type'], operations=["=="],
Expand All @@ -179,7 +179,7 @@ def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):

class MEGMonitorForm(ProjectionMonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(MEGMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)

sensor_filter = FilterChain(fields=[FilterChain.datatype + '.sensors_type'], operations=["=="],
Expand All @@ -197,7 +197,7 @@ def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):

class iEEGMonitorForm(ProjectionMonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(iEEGMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)

sensor_filter = FilterChain(fields=[FilterChain.datatype + '.sensors_type'], operations=["=="],
Expand All @@ -216,7 +216,7 @@ def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):

class BoldMonitorForm(MonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(BoldMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)
self.hrf_kernel_choices = get_ui_name_to_monitor_equation_dict()
default_hrf_kernel = list(self.hrf_kernel_choices.values())[0]
Expand All @@ -237,5 +237,5 @@ def fill_from_trait(self, trait):

class BoldRegionROIMonitorForm(BoldMonitorForm):

def __init__(self, variables_of_interest_indexes, prefix='', project_id=None):
def __init__(self, variables_of_interest_indexes={}, prefix='', project_id=None):
super(BoldRegionROIMonitorForm, self).__init__(variables_of_interest_indexes, prefix, project_id)
26 changes: 22 additions & 4 deletions framework_tvb/tvb/adapters/simulator/simulator_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
"""
import json
from tvb.adapters.simulator.model_forms import get_model_to_form_dict
from tvb.adapters.simulator.monitor_forms import get_monitor_to_form_dict
from tvb.adapters.simulator.simulator_fragments import *
from tvb.adapters.simulator.coupling_forms import get_ui_name_to_coupling_dict
from tvb.adapters.datatypes.db.simulation_history import SimulationHistoryIndex
Expand Down Expand Up @@ -131,10 +133,26 @@ def __init__(self):
def get_form_class(self):
return SimulatorAdapterForm

@staticmethod
def get_simulator_fragments():
return [SimulatorSurfaceFragment, SimulatorRMFragment, SimulatorStimulusFragment, SimulatorModelFragment,
SimulatorIntegratorFragment, SimulatorMonitorFragment, SimulatorFinalFragment]
def get_adapter_fragments(self, view_model):
# type (SimulatorAdapterModel) -> dict
forms = {None: [SimulatorSurfaceFragment, SimulatorRMFragment, SimulatorStimulusFragment,
SimulatorModelFragment, SimulatorIntegratorFragment, SimulatorMonitorFragment,
SimulatorFinalFragment]}

current_model_class = type(view_model.model)
all_model_forms = get_model_to_form_dict()
forms["model"] = [all_model_forms.get(current_model_class)]

all_monitor_forms = get_monitor_to_form_dict()
selected_monitor_forms = []
for monitor in view_model.monitors:
current_monitor_class = type(monitor)
selected_monitor_forms.append(all_monitor_forms.get(current_monitor_class))

forms["monitors"] = selected_monitor_forms
# Not sure if where we should in fact include the entire tree, or it will become too tedious.
# For now I think it is ok if we rename this section "Summary" and filter what is shown
return forms

def load_view_model(self, operation):
storage_path = self.file_handler.get_project_folder(operation.project, str(operation.id))
Expand Down
9 changes: 9 additions & 0 deletions framework_tvb/tvb/core/adapters/abcadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ def get_form(self):
def get_form_class(self):
return None

def get_adapter_fragments(self, view_model):
"""
The result will be used for introspecting and checking operation changed input
params from the defaults, to show in web gui.
:return: a list of ABCAdapterForm classes, in case the current Adapter GUI
will be composed of multiple sub-forms.
"""
return {}

def get_view_model_class(self):
return self.get_form_class().get_view_model()

Expand Down
69 changes: 35 additions & 34 deletions framework_tvb/tvb/core/adapters/inputs_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,31 @@
def _review_operation_inputs_for_adapter_model(form_fields, form_model, view_model):
changed_attr = {}
inputs_datatypes = []

for field in form_fields:
if not isinstance(field, TraitDataTypeSelectField) and not isinstance(field, DataTypeSelectField):
attr_default = getattr(form_model, field.name)
attr_vm = getattr(view_model, field.name)
if attr_vm != attr_default:
if isinstance(attr_default, float) or isinstance(attr_default, str):
changed_attr[field.label] = attr_vm
else:
changed_attr[field.label] = attr_vm.title
else:
if not hasattr(view_model, field.name):
continue

if isinstance(field, TraitDataTypeSelectField) or isinstance(field, DataTypeSelectField):

attr_vm = getattr(view_model, field.name)
data_type = ABCAdapter.load_entity_by_gid(attr_vm)
if attr_vm:
changed_attr[field.label] = data_type.display_name
inputs_datatypes.append(data_type)

else:
attr_default = getattr(form_model, field.name)
attr_vm = getattr(view_model, field.name)
if attr_vm != attr_default:
if isinstance(attr_vm, float) or isinstance(attr_vm, int) or isinstance(attr_vm, str):
changed_attr[field.label] = attr_vm
elif isinstance(attr_vm, tuple) or isinstance(attr_vm, list):
changed_attr[field.label] = ', '.join([str(sub_attr) for sub_attr in attr_vm])
else:
# All HasTraits instances will show as being different than default, even if the same,
changed_attr[field.label] = str(attr_vm)

return inputs_datatypes, changed_attr


Expand All @@ -63,32 +72,24 @@ def review_operation_inputs_from_adapter(adapter, operation):
form_model = adapter.get_view_model_class()()
form_fields = adapter.get_form_class()().fields

if 'SimulatorAdapter' in operation.algorithm.classname:
fragments = adapter.get_simulator_fragments()
inputs_datatypes, changed_attr = _review_operation_inputs_for_adapter_model(form_fields, form_model, view_model)
inputs_datatypes, changed_attr = _review_operation_inputs_for_adapter_model(form_fields, form_model, view_model)

fragments_dict = adapter.get_adapter_fragments(view_model)
# The Simulator, for example will have Fragments
for path, fragments in fragments_dict.items():
if path is None:
fragment_defaults = form_model
fragment_model = view_model
else:
fragment_defaults = getattr(form_model, path)
fragment_model = getattr(view_model, path)

for fragment in fragments:
fragment_fields = fragment().fields
for field in fragment_fields:
if hasattr(view_model, field.name):
if not isinstance(field, TraitDataTypeSelectField) and not isinstance(field, DataTypeSelectField):
attr_default = getattr(form_model, field.name)
attr_vm = getattr(view_model, field.name)
if attr_vm != attr_default:
if isinstance(attr_default, float) or isinstance(attr_default, str):
changed_attr[field.label] = attr_vm
else:
if not isinstance(attr_default, tuple):
changed_attr[field.label] = attr_vm.title
else:
for sub_attr in attr_default:
changed_attr[field.label] = sub_attr.title
else:
attr_vm = getattr(view_model, field.name)
data_type = ABCAdapter.load_entity_by_gid(attr_vm)
if attr_vm:
changed_attr[field.label] = data_type.display_name
inputs_datatypes.append(data_type)
else:
inputs_datatypes, changed_attr = _review_operation_inputs_for_adapter_model(form_fields, form_model, view_model)

part_dts, part_changed = _review_operation_inputs_for_adapter_model(fragment_fields,
fragment_defaults, fragment_model)
inputs_datatypes.extend(part_dts)
changed_attr.update(part_changed)

return inputs_datatypes, changed_attr
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
{% endif %}

{% if (loop.index0 == 0) and ((nodeFields | length) > 1) and 'operation' in nodeType %}
<legend>Changed Input Parameters</legend>
<legend>Summary Input Parameters</legend>
{% endif %}

<dl>
Expand Down

0 comments on commit 1b91070

Please sign in to comment.