Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,14 @@ def _execution_wrapper(self,
:raises NoExperimentsLeftError: If there are no experiments left to be executed.
:raises DatabaseConnectionError: If an error occurred during the connection to the database.
"""
_, keyfield_values = self.dbconnector.get_experiment_configuration(random_order)
experiment_id, keyfield_values = self.dbconnector.get_experiment_configuration(random_order)

result_field_names = utils.get_result_field_names(self.config)
custom_fields = dict(self.config.items('CUSTOM')) if self.has_section('CUSTOM') else None
table_name = self.get_config_value('PY_EXPERIMENTER', 'table')

result_processor = ResultProcessor(self.config, self.database_credential_file_path, table_name=table_name,
condition=keyfield_values, result_fields=result_field_names)
result_fields=result_field_names, experiment_id = experiment_id)
result_processor._set_name(self.name)
result_processor._set_machine(socket.gethostname())

Expand Down
17 changes: 7 additions & 10 deletions py_experimenter/result_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ class ResultProcessor:
database.
"""

def __init__(self, _config: dict, credential_path, table_name: str, condition: dict, result_fields: List[str]):
def __init__(self, _config: dict, credential_path, table_name: str, result_fields: List[str], experiment_id: int):
self._table_name = table_name
self._where = ' AND '.join([f"{str(key)}='{str(value)}'" for key, value in condition.items()])
self._result_fields = result_fields
self._config = _config
self._timestamp_on_result_fields = utils.timestamps_for_result_fields(self._config)
self._experiment_id_condition = f'ID = {experiment_id}'

if _config['PY_EXPERIMENTER']['provider'] == 'sqlite':
self._dbconnector = DatabaseConnectorLITE(_config)
Expand All @@ -54,7 +54,7 @@ def process_results(self, results: dict) -> None:

keys = self._dbconnector.escape_sql_chars(*list(results.keys()))
values = self._dbconnector.escape_sql_chars(*list(results.values()))
self._dbconnector._update_database(keys=keys, values=values, where=self._where)
self._dbconnector._update_database(keys=keys, values=values, where=self._experiment_id_condition)

@staticmethod
def _add_timestamps_to_results(results: dict, time: datetime) -> List[Tuple[str, object]]:
Expand All @@ -69,19 +69,16 @@ def _change_status(self, status):
time = time.strftime("%m/%d/%Y, %H:%M:%S")

if status == 'done' or status == 'error':
self._dbconnector._update_database(keys=['status', 'end_date'], values=[status, time], where=self._where)
self._dbconnector._update_database(keys=['status', 'end_date'], values=[status, time], where=self._experiment_id_condition)

def _write_error(self, error_msg):
self._dbconnector._update_database(keys=['error'], values=[error_msg], where=self._where)
self._dbconnector._update_database(keys=['error'], values=[error_msg], where=self._experiment_id_condition)

def _set_machine(self, machine_id):
self._dbconnector._update_database(keys=['machine'], values=[machine_id], where=self._where)
self._dbconnector._update_database(keys=['machine'], values=[machine_id], where=self._experiment_id_condition)

def _set_name(self, name):
self._dbconnector._update_database(keys=['name'], values=[name], where=self._where)

def _not_executed_yet(self) -> bool:
return self._dbconnector.not_executed_yet(where=self._where)
self._dbconnector._update_database(keys=['name'], values=[name], where=self._experiment_id_condition)

def _valid_result_fields(self, result_fields):
return set(result_fields).issubset(set(self._result_fields))
22 changes: 10 additions & 12 deletions test/test_result_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,27 @@
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing')
@pytest.mark.parametrize(
'config, table_name, condition, result_fields, expected_provider',
'config, table_name, result_fields, expected_provider',
[
(
utils.load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg')),
'test_table',
{'test': 'condition'},
['result_field_1', 'result_field_2'],
DatabaseConnectorMYSQL
),
(
utils.load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'sqlite_test_file.cfg')),
'test_table',
{'test': 'condition'},
['result_field_1', 'result_field_2'],
DatabaseConnectorLITE
),
]
)
def test_init(create_database_if_not_existing_mock, test_connection_mysql, test_connection_sqlite, config, table_name, condition, result_fields, expected_provider):
def test_init(create_database_if_not_existing_mock, test_connection_mysql, test_connection_sqlite, config, table_name, result_fields, expected_provider):
create_database_if_not_existing_mock.return_value = None
test_connection_mysql.return_value = None
test_connection_sqlite.return_value = None
result_processor = ResultProcessor(config, CREDENTIAL_PATH, table_name, condition, result_fields)
result_processor = ResultProcessor(config, CREDENTIAL_PATH, table_name, result_fields, 0)

assert table_name == result_processor._table_name
assert result_fields == result_processor._result_fields
Expand All @@ -60,7 +58,7 @@ def test_init_raises_error(mock_fn):
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing')
@pytest.mark.parametrize(
'result_fields, results, error, errorstring',
'result_fields, results, error, errorstring, experiment_id',
[
(
[
Expand All @@ -71,18 +69,18 @@ def test_init_raises_error(mock_fn):
'result_field_2': 'result_field_2_value',
},
InvalidResultFieldError,
f"Invalid result keys: {{'result_field_2'}}"
f"Invalid result keys: {{'result_field_2'}}",
0
),
]
)
def test_process_results_raises_error(create_database_mock, test_connection_mock, result_fields, results, error, errorstring):
def test_process_results_raises_error(create_database_mock, test_connection_mock, result_fields, results, error, errorstring, experiment_id):
create_database_mock.return_value = None
test_connection_mock.return_value = None
table_name = 'test_table'
condition = {'test': 'condition'}
config = utils.load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'))

result_processor = ResultProcessor(config, CREDENTIAL_PATH, table_name, condition, result_fields)
result_processor = ResultProcessor(config, CREDENTIAL_PATH, table_name, result_fields, experiment_id)

with pytest.raises(error, match=errorstring):
result_processor.process_results(results)
Expand All @@ -102,8 +100,8 @@ def test_valid_result_fields(create_database_if_not_existing_mock, test_connecti
create_database_if_not_existing_mock.return_value = None
test_connection_mock.return_value = None
mock_config = utils.load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'))
assert subset_boolean == ResultProcessor(mock_config, CREDENTIAL_PATH, 'test_table_name', {
'test_condition_key': 'test_condition_value'}, used_result_fields)._valid_result_fields(existing_result_fields)
assert subset_boolean == ResultProcessor(mock_config, CREDENTIAL_PATH, 'test_table_name',
used_result_fields, 0)._valid_result_fields(existing_result_fields)


@pytest.mark.parametrize(
Expand Down