Skip to content

Commit d7a23b0

Browse files
committed
[refactor] Refactor _get_base_search_space in base_pipeline
1 parent 57111e9 commit d7a23b0

File tree

1 file changed

+124
-54
lines changed

1 file changed

+124
-54
lines changed

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 124 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,101 @@ def __repr__(self) -> str:
304304
string += '_' * 40
305305
return string
306306

307+
def _get_search_space_modifications(
308+
self,
309+
include: Optional[Dict[str, Any]],
310+
exclude: Optional[Dict[str, Any]],
311+
pipeline: List[Tuple[str, autoPyTorchChoice]]
312+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
313+
"""
314+
Get what pipeline or components to
315+
include in or exclude from the searching space
316+
317+
Args:
318+
include (Optional[Dict[str, Any]]): Components to include in the searching
319+
exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
320+
pipeline (List[Tuple[str, autoPyTorchChoice]]):
321+
Available components
322+
323+
Returns:
324+
include, exclude (Tuple[Dict[str, Any]], Dict[str, Any]]]):
325+
modified version of `include` and `exclude`
326+
"""
327+
328+
key_exist = {pair[0]: True for pair in pipeline}
329+
330+
if include is None:
331+
include = {} if self.include is None else self.include
332+
333+
for key in include.keys():
334+
if key_exist.get(key, False):
335+
raise ValueError('Keys in `include` must be {}, but got {}'.format(
336+
list(key_exist.keys()), key)
337+
)
338+
339+
if exclude is None:
340+
exclude = {} if self.exclude is None else self.exclude
341+
342+
for key in exclude:
343+
if key_exist.get(key, False):
344+
raise ValueError('Keys in `exclude` must be {}, but got {}'.format(
345+
list(key_exist.keys()), key)
346+
)
347+
348+
return include, exclude
349+
350+
@staticmethod
351+
def _update_search_space_per_node(
352+
cs: ConfigurationSpace,
353+
dataset_properties: Dict[str, Any],
354+
include: Optional[Dict[str, Any]],
355+
exclude: Optional[Dict[str, Any]],
356+
matches: np.ndarray,
357+
node_name: str,
358+
node_idx: int,
359+
node: autoPyTorchChoice
360+
) -> ConfigurationSpace:
361+
"""
362+
Args:
363+
cs (ConfigurationSpace): Searching space information
364+
dataset_properties (Dict[str, Any]): The properties of dataset
365+
include (Optional[Dict[str, Any]]): Components to include in the searching
366+
exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
367+
node_name (str): The name of the component choice
368+
node_idx (int): The index of the component in a provided list of components
369+
node (autoPyTorchChoice): The module of the component
370+
matches (np.ndarray): ...
371+
372+
Returns:
373+
modified cs (ConfigurationSpace):
374+
modified searching space information based on the arguments.
375+
"""
376+
is_choice = isinstance(node, autoPyTorchChoice)
377+
378+
if not is_choice:
379+
# if the node isn't a choice we can add it immediately because
380+
# it must be active (if it wasn't, np.sum(matches) would be zero
381+
assert not isinstance(node, autoPyTorchChoice)
382+
cs.add_configuration_space(
383+
node_name,
384+
node.get_hyperparameter_search_space(dataset_properties, # type: ignore[arg-type]
385+
**node._get_search_space_updates()),
386+
)
387+
else:
388+
# If the node is a choice, we have to figure out which of
389+
# its choices are actually legal choices
390+
choices_list = find_active_choices(
391+
matches, node, node_idx,
392+
dataset_properties,
393+
include.get(node_name),
394+
exclude.get(node_name)
395+
)
396+
sub_config_space = node.get_hyperparameter_search_space(
397+
dataset_properties, include=choices_list)
398+
cs.add_configuration_space(node_name, sub_config_space)
399+
400+
return cs
401+
307402
def _get_base_search_space(
308403
self,
309404
cs: ConfigurationSpace,
@@ -312,75 +407,50 @@ def _get_base_search_space(
312407
exclude: Optional[Dict[str, Any]],
313408
pipeline: List[Tuple[str, autoPyTorchChoice]]
314409
) -> ConfigurationSpace:
315-
if include is None:
316-
if self.include is None:
317-
include = {}
318-
else:
319-
include = self.include
320-
321-
keys = [pair[0] for pair in pipeline]
322-
for key in include:
323-
if key not in keys:
324-
raise ValueError('Invalid key in include: %s; should be one '
325-
'of %s' % (key, keys))
326-
327-
if exclude is None:
328-
if self.exclude is None:
329-
exclude = {}
330-
else:
331-
exclude = self.exclude
410+
"""
411+
Get the searching space
332412
333-
keys = [pair[0] for pair in pipeline]
334-
for key in exclude:
335-
if key not in keys:
336-
raise ValueError('Invalid key in exclude: %s; should be one '
337-
'of %s' % (key, keys))
413+
Args:
414+
cs (ConfigurationSpace): Searching space information
415+
dataset_properties (Dict[str, Any]): The properties of dataset
416+
include (Optional[Dict[str, Any]]): Components to include in the searching
417+
exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
418+
pipeline (List[Tuple[str, autoPyTorchChoice]]):
419+
Available components
338420
421+
Returns:
422+
modified cs (ConfigurationSpace):
423+
modified searching space information based on the arguments.
424+
"""
425+
include, exclude = self._get_search_space_modifications(
426+
include=include, exclude=exclude, pipeline=pipeline
427+
)
339428
if self.search_space_updates is not None:
340429
self._check_search_space_updates(include=include,
341430
exclude=exclude)
342431
self.search_space_updates.apply(pipeline=pipeline)
343432

433+
# The size of this array exponentially grows, so it will be better to remove
344434
matches = get_match_array(
345435
pipeline, dataset_properties, include=include, exclude=exclude)
346436

347437
# Now we have only legal combinations at this step of the pipeline
348438
# Simple sanity checks
349-
assert np.sum(matches) != 0, "No valid pipeline found."
350-
351-
assert np.sum(matches) <= np.size(matches), \
352-
"'matches' is not binary; %s <= %d, %s" % \
353-
(str(np.sum(matches)), np.size(matches), str(matches.shape))
439+
if np.sum(matches) == 0:
440+
raise ValueError("No valid pipeline found.")
441+
if np.sum(matches) > np.size(matches):
442+
raise TypeError("'matches' is not binary; {} <= {}, {}".format(
443+
np.sum(matches), np.size(matches), str(matches.shape)
444+
))
354445

355446
# Iterate each dimension of the matches array (each step of the
356447
# pipeline) to see if we can add a hyperparameter for that step
357-
for node_idx, n_ in enumerate(pipeline):
358-
node_name, node = n_
359-
360-
is_choice = isinstance(node, autoPyTorchChoice)
361-
362-
# if the node isn't a choice we can add it immediately because it
363-
# must be active (if it wasn't, np.sum(matches) would be zero
364-
if not is_choice:
365-
# for mypy
366-
assert not isinstance(node, autoPyTorchChoice)
367-
cs.add_configuration_space(
368-
node_name,
369-
node.get_hyperparameter_search_space(dataset_properties, # type: ignore[arg-type]
370-
**node._get_search_space_updates()),
371-
)
372-
# If the node is a choice, we have to figure out which of its
373-
# choices are actually legal choices
374-
else:
375-
choices_list = find_active_choices(
376-
matches, node, node_idx,
377-
dataset_properties,
378-
include.get(node_name),
379-
exclude.get(node_name)
380-
)
381-
sub_config_space = node.get_hyperparameter_search_space(
382-
dataset_properties, include=choices_list)
383-
cs.add_configuration_space(node_name, sub_config_space)
448+
for node_idx, (node_name, node) in enumerate(pipeline):
449+
cs = self._update_search_space_per_node(
450+
cs=ConfigurationSpace, dataset_properties=dataset_properties,
451+
include=include, exclude=exclude, matches=matches,
452+
node=node, node_idx=node_idx, node_name=node_name
453+
)
384454

385455
# And now add forbidden parameter configurations
386456
# According to matches

0 commit comments

Comments
 (0)