Commit 855c57e

[ADD] Forkserver as default multiprocessing strategy (#223)
* First push of forkserver
* [Fix] Missing file
* [FIX] mypy
* [Fix] rename choice to init
* [Fix] Unit test
* [Fix] bugs in examples
* [Fix] ensemble builder
* Update autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/__init__.py
* Update autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/__init__.py
* Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py
* Update autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/__init__.py
* Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/__init__.py
* Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/__init__.py
* Update autoPyTorch/pipeline/components/setup/network_head/__init__.py
* Update autoPyTorch/pipeline/components/setup/network_initializer/__init__.py
* Update autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
* [FIX] improve doc-strings
* Fix rebase

Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com>
1 parent 1e08fc9 commit 855c57e

File tree: 54 files changed, +2910 −2728 lines changed

autoPyTorch/api/base_task.py

Lines changed: 18 additions & 1 deletion
@@ -54,7 +54,9 @@
     setup_logger,
     start_log_server,
 )
+from autoPyTorch.utils.parallel import preload_modules
 from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements
+from autoPyTorch.utils.single_thread_client import SingleThreadedClient
 from autoPyTorch.utils.stopwatch import StopWatch


@@ -190,7 +192,16 @@ def __init__(

         self.stop_logging_server = None  # type: Optional[multiprocessing.synchronize.Event]

+        # Single core, local runs should use fork
+        # to prevent the __main__ requirements in
+        # examples. Nevertheless, multi-process runs
+        # have spawn as requirement to reduce the
+        # possibility of a deadlock
         self._dask_client = None
+        self._multiprocessing_context = 'forkserver'
+        if self.n_jobs == 1:
+            self._multiprocessing_context = 'fork'
+            self._dask_client = SingleThreadedClient()

         self.search_space_updates = search_space_updates
         if search_space_updates is not None:
@@ -300,7 +311,8 @@ def _get_logger(self, name: str) -> PicklableClientLogger:
         # under the above logging configuration setting
         # We need to specify the logger_name so that received records
         # are treated under the logger_name ROOT logger setting
-        context = multiprocessing.get_context('spawn')
+        context = multiprocessing.get_context(self._multiprocessing_context)
+        preload_modules(context)
         self.stop_logging_server = context.Event()
         port = context.Value('l')  # be safe by using a long
         port.value = -1
@@ -505,6 +517,7 @@ def _do_dummy_prediction(self) -> None:
         stats = Stats(scenario_mock)
         stats.start_timing()
         ta = ExecuteTaFuncWithQueue(
+            pynisher_context=self._multiprocessing_context,
             backend=self._backend,
             seed=self.seed,
             metric=self._metric,
@@ -599,6 +612,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
         stats = Stats(scenario_mock)
         stats.start_timing()
         ta = ExecuteTaFuncWithQueue(
+            pynisher_context=self._multiprocessing_context,
             backend=self._backend,
             seed=self.seed,
             metric=self._metric,
@@ -929,6 +943,7 @@ def _search(
             random_state=self.seed,
             precision=precision,
             logger_port=self._logger_port,
+            pynisher_context=self._multiprocessing_context,
         )
         self._stopwatch.stop_task(ensemble_task_name)

@@ -969,6 +984,7 @@ def _search(
             start_num_run=self._backend.get_next_num_run(peek=True),
             search_space_updates=self.search_space_updates,
             portfolio_selection=portfolio_selection,
+            pynisher_context=self._multiprocessing_context,
         )
         try:
             run_history, self.trajectory, budget_type = \
@@ -1299,5 +1315,6 @@ def _print_debug_info_to_log(self) -> None:
         self._logger.debug(' System: %s', platform.system())
         self._logger.debug(' Machine: %s', platform.machine())
         self._logger.debug(' Platform: %s', platform.platform())
+        self._logger.debug(' multiprocessing_context: %s', str(self._multiprocessing_context))
         for key, value in vars(self).items():
             self._logger.debug(f"\t{key}->{value}")
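Why fork when n_jobs == 1? The new comment in __init__ captures the trade-off: 'spawn' and 'forkserver' both require the main module to be importable without side effects (the child, or the fork server, re-imports it), so a user script driving AutoPyTorch from unguarded top-level code, as the examples do, would re-execute itself; plain 'fork' inherits the parent's memory instead, but risks deadlocks when many workers fork a parent that holds locks. A standalone sketch of the mechanics, not part of the commit (POSIX only, since 'fork' and 'forkserver' are unavailable on Windows):

# Demo: 'fork' and 'forkserver' both run the same module-level target,
# but 'forkserver' (like 'spawn') re-imports __main__, which is why
# the guard below is mandatory for those start methods.
import multiprocessing


def job(q) -> None:
    q.put('done')


if __name__ == '__main__':
    print(multiprocessing.get_all_start_methods())  # e.g. ['fork', 'spawn', 'forkserver']
    for method in ('fork', 'forkserver'):
        ctx = multiprocessing.get_context(method)
        q = ctx.Queue()
        p = ctx.Process(target=job, args=(q,))
        p.start()
        print(method, q.get())
        p.join()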

autoPyTorch/ensemble/ensemble_builder.py

Lines changed: 8 additions & 3 deletions
@@ -36,6 +36,7 @@
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_loss, calculate_score
 from autoPyTorch.utils.logging_ import get_named_client_logger
+from autoPyTorch.utils.parallel import preload_modules

 Y_ENSEMBLE = 0
 Y_TEST = 1
@@ -64,6 +65,7 @@ def __init__(
         ensemble_memory_limit: Optional[int],
         random_state: int,
         logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+        pynisher_context: str = 'fork',
     ):
         """ SMAC callback to handle ensemble building
         Args:
@@ -111,6 +113,8 @@ def __init__(
                 read at most n new prediction files in each iteration
             logger_port: int
                 port in where to publish a msg
+            pynisher_context: str
+                The multiprocessing context for pynisher. One of spawn/fork/forkserver.

         Returns:
             List[Tuple[int, float, float, float]]:
@@ -135,6 +139,7 @@ def __init__(
         self.ensemble_memory_limit = ensemble_memory_limit
         self.random_state = random_state
         self.logger_port = logger_port
+        self.pynisher_context = pynisher_context

         # Store something similar to SMAC's runhistory
         self.history = []  # type: List[Dict[str, float]]
@@ -160,7 +165,6 @@ def __call__(
     def build_ensemble(
         self,
         dask_client: dask.distributed.Client,
-        pynisher_context: str = 'spawn',
         unit_test: bool = False
     ) -> None:

@@ -236,7 +240,7 @@ def build_ensemble(
             iteration=self.iteration,
             return_predictions=False,
             priority=100,
-            pynisher_context=pynisher_context,
+            pynisher_context=self.pynisher_context,
             logger_port=self.logger_port,
             unit_test=unit_test,
         ))
@@ -585,11 +589,11 @@ def __init__(
     def run(
         self,
         iteration: int,
+        pynisher_context: str,
         time_left: Optional[float] = None,
         end_at: Optional[float] = None,
         time_buffer: int = 5,
         return_predictions: bool = False,
-        pynisher_context: str = 'spawn',  # only change for unit testing!
     ) -> Tuple[
         List[Dict[str, float]],
         int,
@@ -655,6 +659,7 @@ def run(
             if wall_time_in_s < 1:
                 break
             context = multiprocessing.get_context(pynisher_context)
+            preload_modules(context)

             safe_ensemble_script = pynisher.enforce_limits(
                 wall_time_in_s=wall_time_in_s,
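The pynisher.enforce_limits call that follows preload_modules(context) is truncated in this view; presumably the same context object is handed to pynisher so the resource-limited child uses the caller's start method. A minimal sketch of that pattern, assuming pynisher 0.6's context keyword; fit_and_return_loss is a hypothetical stand-in for the builder's self.main:

# Run a function in a resource-limited child under an explicit
# multiprocessing context, as the ensemble builder does.
import multiprocessing

import pynisher


def fit_and_return_loss() -> float:
    # hypothetical stand-in for EnsembleBuilder's self.main
    return 0.25


if __name__ == '__main__':
    context = multiprocessing.get_context('forkserver')
    safe_fn = pynisher.enforce_limits(
        wall_time_in_s=60,   # wall-clock budget for the child process
        mem_in_mb=1024,      # memory budget for the child process
        context=context,     # start method now injected by the caller
    )(fit_and_return_loss)

    print(safe_fn())         # 0.25, or None if a limit was exceeded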

autoPyTorch/evaluation/tae.py

Lines changed: 25 additions & 23 deletions
@@ -29,6 +29,7 @@
 from autoPyTorch.utils.common import replace_string_bool_to_bool
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
+from autoPyTorch.utils.parallel import preload_modules


 def fit_predict_try_except_decorator(
@@ -92,29 +93,29 @@ class ExecuteTaFuncWithQueue(AbstractTAFunc):
     """

     def __init__(
-            self,
-            backend: Backend,
-            seed: int,
-            metric: autoPyTorchMetric,
-            cost_for_crash: float,
-            abort_on_first_run_crash: bool,
-            pipeline_config: typing.Optional[typing.Dict[str, typing.Any]] = None,
-            initial_num_run: int = 1,
-            stats: typing.Optional[Stats] = None,
-            run_obj: str = 'quality',
-            par_factor: int = 1,
-            output_y_hat_optimization: bool = True,
-            include: typing.Optional[typing.Dict[str, typing.Any]] = None,
-            exclude: typing.Optional[typing.Dict[str, typing.Any]] = None,
-            memory_limit: typing.Optional[int] = None,
-            disable_file_output: bool = False,
-            init_params: typing.Dict[str, typing.Any] = None,
-            budget_type: str = None,
-            ta: typing.Optional[typing.Callable] = None,
-            logger_port: int = None,
-            all_supported_metrics: bool = True,
-            pynisher_context: str = 'spawn',
-            search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
+        self,
+        backend: Backend,
+        seed: int,
+        metric: autoPyTorchMetric,
+        cost_for_crash: float,
+        abort_on_first_run_crash: bool,
+        pynisher_context: str,
+        pipeline_config: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        initial_num_run: int = 1,
+        stats: typing.Optional[Stats] = None,
+        run_obj: str = 'quality',
+        par_factor: int = 1,
+        output_y_hat_optimization: bool = True,
+        include: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        exclude: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        memory_limit: typing.Optional[int] = None,
+        disable_file_output: bool = False,
+        init_params: typing.Dict[str, typing.Any] = None,
+        budget_type: str = None,
+        ta: typing.Optional[typing.Callable] = None,
+        logger_port: int = None,
+        all_supported_metrics: bool = True,
+        search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
     ):

         eval_function = autoPyTorch.evaluation.train_evaluator.eval_function
@@ -249,6 +250,7 @@ def run(
     ) -> typing.Tuple[StatusType, float, float, typing.Dict[str, typing.Any]]:

         context = multiprocessing.get_context(self.pynisher_context)
+        preload_modules(context)
         queue: multiprocessing.queues.Queue = context.Queue()

         if not (instance_specific is None or instance_specific == '0'):
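preload_modules comes from a new autoPyTorch.utils.parallel module whose body is not shown in this commit view. Its call sites, always immediately after multiprocessing.get_context, suggest it hints the fork server to import the heavy scientific stack once so every worker forks with those modules already warm. A plausible sketch under that assumption; the package list is an illustrative guess, while BaseContext.set_forkserver_preload is real stdlib API that only takes effect when a forkserver is actually used:

import multiprocessing
import sys

# Heavy packages worth importing once in the forkserver process.
# Illustrative guess, not the commit's actual list.
_PRELOAD_ROOTS = ('smac', 'autoPyTorch', 'numpy', 'scipy', 'pandas',
                  'pynisher', 'sklearn', 'ConfigSpace', 'torch')


def preload_modules(context: multiprocessing.context.BaseContext) -> None:
    # Collect modules already imported in this process that belong to
    # one of the heavy packages above; skip logging-related modules,
    # whose handlers should not leak into workers.
    preload = [
        name for name in sys.modules.keys()
        if name.split('.')[0] in _PRELOAD_ROOTS and 'logging' not in name
    ]
    # A hint to the fork server; a no-op under 'fork' or 'spawn'.
    context.set_forkserver_preload(preload)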

autoPyTorch/optimizer/smbo.py

Lines changed: 7 additions & 2 deletions
@@ -109,7 +109,8 @@ def __init__(self,
                 ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
                 logger_port: typing.Optional[int] = None,
                 search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None,
-                 portfolio_selection: typing.Optional[str] = None
+                 portfolio_selection: typing.Optional[str] = None,
+                 pynisher_context: str = 'spawn',
                 ):
        """
        Interface to SMAC. This method calls the SMAC optimize method, and allows
@@ -156,6 +157,8 @@ def __init__(self,
                Additional arguments to the smac scenario
            get_smac_object_callback (typing.Optional[typing.Callable]):
                Allows to create a user specified SMAC object
+            pynisher_context (str):
+                A string indicating the multiprocessing context to use
            ensemble_callback (typing.Optional[EnsembleBuilderManager]):
                A callback used in this scenario to start ensemble building subtasks
            portfolio_selection (str), (default=None):
@@ -204,6 +207,7 @@ def __init__(self,
        self.disable_file_output = disable_file_output
        self.smac_scenario_args = smac_scenario_args
        self.get_smac_object_callback = get_smac_object_callback
+        self.pynisher_context = pynisher_context

        self.ensemble_callback = ensemble_callback

@@ -274,7 +278,8 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None
            logger_port=self.logger_port,
            all_supported_metrics=self.all_supported_metrics,
            pipeline_config=self.pipeline_config,
-            search_space_updates=self.search_space_updates
+            search_space_updates=self.search_space_updates,
+            pynisher_context=self.pynisher_context,
        )
        ta = ExecuteTaFuncWithQueue
        self.logger.info("Created TA")
autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/__init__.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import os
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional
+
+import ConfigSpace.hyperparameters as CSH
+from ConfigSpace.configuration_space import ConfigurationSpace
+
+from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
+from autoPyTorch.pipeline.components.base_component import (
+    ThirdPartyComponents,
+    autoPyTorchComponent,
+    find_components,
+)
+from autoPyTorch.pipeline.components.preprocessing.image_preprocessing.normalise.base_normalizer import BaseNormalizer
+
+
+normalise_directory = os.path.split(__file__)[0]
+_normalizers = find_components(__package__,
+                               normalise_directory,
+                               BaseNormalizer)
+
+_addons = ThirdPartyComponents(BaseNormalizer)
+
+
+def add_normalizer(normalizer: BaseNormalizer) -> None:
+    _addons.add_component(normalizer)
+
+
+class NormalizerChoice(autoPyTorchChoice):
+    """
+    Allows for dynamically choosing normalizer component at runtime
+    """
+
+    def get_components(self) -> Dict[str, autoPyTorchComponent]:
+        """Returns the available normalizer components
+
+        Args:
+            None
+
+        Returns:
+            Dict[str, autoPyTorchComponent]: all BaseNormalizer components available
+                as choices for encoding the categorical columns
+        """
+        components = OrderedDict()
+        components.update(_normalizers)
+        components.update(_addons.components)
+        return components
+
+    def get_hyperparameter_search_space(self,
+                                        dataset_properties: Optional[Dict[str, Any]] = None,
+                                        default: Optional[str] = None,
+                                        include: Optional[List[str]] = None,
+                                        exclude: Optional[List[str]] = None) -> ConfigurationSpace:
+        cs = ConfigurationSpace()
+
+        if dataset_properties is None:
+            dataset_properties = dict()
+
+        dataset_properties = {**self.dataset_properties, **dataset_properties}
+
+        available_preprocessors = self.get_available_components(dataset_properties=dataset_properties,
+                                                                include=include,
+                                                                exclude=exclude)
+
+        if len(available_preprocessors) == 0:
+            raise ValueError("no image normalizers found, please add an image normalizer")
+
+        if default is None:
+            defaults = ['ImageNormalizer', 'NoNormalizer']
+            for default_ in defaults:
+                if default_ in available_preprocessors:
+                    if include is not None and default_ not in include:
+                        continue
+                    if exclude is not None and default_ in exclude:
+                        continue
+                    default = default_
+                    break
+
+        updates = self._get_search_space_updates()
+        if '__choice__' in updates.keys():
+            choice_hyperparameter = updates['__choice__']
+            if not set(choice_hyperparameter.value_range).issubset(available_preprocessors):
+                raise ValueError("Expected given update for {} to have "
+                                 "choices in {} got {}".format(self.__class__.__name__,
+                                                               available_preprocessors,
+                                                               choice_hyperparameter.value_range))
+            preprocessor = CSH.CategoricalHyperparameter('__choice__',
+                                                         choice_hyperparameter.value_range,
+                                                         default_value=choice_hyperparameter.default_value)
+        else:
+            preprocessor = CSH.CategoricalHyperparameter('__choice__',
+                                                         list(available_preprocessors.keys()),
+                                                         default_value=default)
+        cs.add_hyperparameter(preprocessor)
+
+        # add only child hyperparameters of preprocessor choices
+        for name in preprocessor.choices:
+            preprocessor_configuration_space = available_preprocessors[name].\
+                get_hyperparameter_search_space(dataset_properties)
+            parent_hyperparameter = {'parent': preprocessor, 'value': name}
+            cs.add_configuration_space(name, preprocessor_configuration_space,
+                                       parent_hyperparameter=parent_hyperparameter)
+
+        self.configuration_space = cs
+        self.dataset_properties = dataset_properties
+        return cs
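For orientation, a hedged usage sketch of the new choice class, assuming a working autoPyTorch install where the sibling ImageNormalizer/NoNormalizer components exist and autoPyTorchChoice accepts dataset_properties in its constructor:

# Build the conditional search space over image normalizers.
from autoPyTorch.pipeline.components.preprocessing.image_preprocessing.normalise import (
    NormalizerChoice,
)

# Constructor signature assumed from autoPyTorchChoice.
choice = NormalizerChoice(dataset_properties={})
print(list(choice.get_components()))   # e.g. ['ImageNormalizer', 'NoNormalizer']

# '__choice__' selects a normalizer; each component's own
# hyperparameters are attached conditionally under it.
cs = choice.get_hyperparameter_search_space()
print(cs.sample_configuration())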
