Skip to content

Commit 5d18b8a

Browse files
committed
pass instances
1 parent fb16f87 commit 5d18b8a

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

src/irace/__init__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from rpy2.robjects.vectors import DataFrame, BoolVector, FloatVector, IntVector, StrVector, ListVector, IntArray, Matrix, ListSexpVector,FloatSexpVector,IntSexpVector,StrSexpVector,BoolSexpVector
1616
from rpy2.robjects.functions import SignatureTranslatedFunction
1717
from rpy2.rinterface import RRuntimeWarning
18+
import json
19+
1820

1921
rpy2conversion = ro.conversion.get_conversion()
2022
irace_converter = ro.default_converter + numpy2ri.converter + pandas2ri.converter
@@ -83,6 +85,7 @@ def tmp_r_target_runner(experiment, scenario):
8385
py_experiment['configuration'] = OrderedDict(
8486
(k,v) for k,v in py_experiment['configuration'].items() if not pd.isna(v)
8587
)
88+
py_experiment['instance'] = context['py_instances'][int(py_experiment['id.instance']) - 1]
8689
try:
8790
ret = context['py_target_runner'](py_experiment, py_scenario)
8891
except:
@@ -94,7 +97,6 @@ def tmp_r_target_runner(experiment, scenario):
9497
def check_windows(scenario):
9598
if scenario.get('parallel', 1) != 1 and os.name == 'nt':
9699
raise NotImplementedError('Parallel running on windows is not supported yet. Follow https://github.com/auto-optimization/iracepy/issues/16 for updates. Alternatively, use Linux or MacOS or the irace R package directly.')
97-
98100
class irace:
99101
# Import irace R package
100102
try:
@@ -107,12 +109,20 @@ class irace:
107109

108110
def __init__(self, scenario, parameters_table, target_runner):
109111
self.scenario = scenario
112+
self.instances = scenario.get('instances', None)
113+
self.context = {}
110114
if 'instances' in scenario:
111-
self.scenario['instances'] = np.asarray(scenario['instances'])
115+
self.context.update({
116+
'py_instances': self.scenario['instances'],
117+
})
118+
self.scenario['instances'] = StrVector(list(map(lambda x: json.dumps(x, skipkeys=True, default=self.scenario.get('instanceObjectSerializer', lambda x: '<not serializable>')), self.scenario['instances'])))
119+
self.scenario.pop('instanceObjectSerializer', None)
112120
with localconverter(irace_converter_hack):
113121
self.parameters = self._pkg.readParameters(text = parameters_table, digits = self.scenario.get('digits', 4))
114-
self.context = {'py_target_runner' : target_runner,
115-
'py_scenario': self.scenario }
122+
self.context.update({
123+
'py_target_runner' : target_runner,
124+
'py_scenario': self.scenario,
125+
})
116126
check_windows(scenario)
117127

118128
def read_configurations(self, filename=None, text=None):

tests/test_data_passable.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from irace import irace
33
import pandas as pd
44
from multiprocessing import Queue
5+
import pytest
6+
import os
57

68
q = Queue()
79

@@ -12,6 +14,10 @@ def target_runner(experiment, scenario):
1214
else:
1315
return dict(cost=1)
1416

17+
def target_runner2(experiment, scenario):
18+
if experiment['id.instance'] == 1:
19+
experiment['instance'].put(1335)
20+
return dict(cost=1)
1521

1622
params = '''
1723
one "" c ('0', '1')
@@ -34,4 +40,65 @@ def test():
3440
tuner = irace(scenario, params, target_runner)
3541
best_conf = tuner.run()
3642
assert q.get() == 124
37-
43+
44+
def test_instances():
45+
q = Queue()
46+
scenario = dict(
47+
instances = [q],
48+
maxExperiments = 180,
49+
debugLevel = 0,
50+
parallel = 1,
51+
logFile = "",
52+
seed = 123
53+
)
54+
tuner = irace(scenario, params, target_runner2)
55+
best_conf = tuner.run()
56+
assert q.get() == 1335
57+
58+
@pytest.mark.skipif(os.name == 'nt',
59+
reason="Parallel on Windows not supported")
60+
def test_instances2():
61+
q = Queue()
62+
scenario = dict(
63+
instances = [q],
64+
maxExperiments = 180,
65+
debugLevel = 0,
66+
parallel = 2,
67+
logFile = "",
68+
seed = 123
69+
)
70+
tuner = irace(scenario, params, target_runner2)
71+
best_conf = tuner.run()
72+
assert q.get() == 1335
73+
74+
def test_default_serializer():
75+
q = Queue()
76+
scenario = dict(
77+
instances = [q],
78+
maxExperiments = 180,
79+
debugLevel = 0,
80+
parallel = 1,
81+
logFile = "",
82+
seed = 123,
83+
instanceObjectSerializer = lambda x: 'hello world'
84+
)
85+
tuner = irace(scenario, params, target_runner2)
86+
best_conf = tuner.run()
87+
assert q.get() == 1335
88+
89+
@pytest.mark.skipif(os.name == 'nt',
90+
reason="Parallel on Windows not supported")
91+
def test_default_serializer():
92+
q = Queue()
93+
scenario = dict(
94+
instances = [q],
95+
maxExperiments = 180,
96+
debugLevel = 0,
97+
parallel = 2,
98+
logFile = "",
99+
seed = 123,
100+
instanceObjectSerializer = lambda x: 'hello world'
101+
)
102+
tuner = irace(scenario, params, target_runner2)
103+
best_conf = tuner.run()
104+
assert q.get() == 1335

0 commit comments

Comments
 (0)