@@ -304,6 +304,101 @@ def __repr__(self) -> str:
304
304
string += '_' * 40
305
305
return string
306
306
307
+ def _get_search_space_modifications (
308
+ self ,
309
+ include : Optional [Dict [str , Any ]],
310
+ exclude : Optional [Dict [str , Any ]],
311
+ pipeline : List [Tuple [str , autoPyTorchChoice ]]
312
+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
313
+ """
314
+ Get what pipeline or components to
315
+ include in or exclude from the searching space
316
+
317
+ Args:
318
+ include (Optional[Dict[str, Any]]): Components to include in the searching
319
+ exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
320
+ pipeline (List[Tuple[str, autoPyTorchChoice]]):
321
+ Available components
322
+
323
+ Returns:
324
+ include, exclude (Tuple[Dict[str, Any]], Dict[str, Any]]]):
325
+ modified version of `include` and `exclude`
326
+ """
327
+
328
+ key_exist = {pair [0 ]: True for pair in pipeline }
329
+
330
+ if include is None :
331
+ include = {} if self .include is None else self .include
332
+
333
+ for key in include .keys ():
334
+ if key_exist .get (key , False ):
335
+ raise ValueError ('Keys in `include` must be {}, but got {}' .format (
336
+ list (key_exist .keys ()), key )
337
+ )
338
+
339
+ if exclude is None :
340
+ exclude = {} if self .exclude is None else self .exclude
341
+
342
+ for key in exclude :
343
+ if key_exist .get (key , False ):
344
+ raise ValueError ('Keys in `exclude` must be {}, but got {}' .format (
345
+ list (key_exist .keys ()), key )
346
+ )
347
+
348
+ return include , exclude
349
+
350
+ @staticmethod
351
+ def _update_search_space_per_node (
352
+ cs : ConfigurationSpace ,
353
+ dataset_properties : Dict [str , Any ],
354
+ include : Optional [Dict [str , Any ]],
355
+ exclude : Optional [Dict [str , Any ]],
356
+ matches : np .ndarray ,
357
+ node_name : str ,
358
+ node_idx : int ,
359
+ node : autoPyTorchChoice
360
+ ) -> ConfigurationSpace :
361
+ """
362
+ Args:
363
+ cs (ConfigurationSpace): Searching space information
364
+ dataset_properties (Dict[str, Any]): The properties of dataset
365
+ include (Optional[Dict[str, Any]]): Components to include in the searching
366
+ exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
367
+ node_name (str): The name of the component choice
368
+ node_idx (int): The index of the component in a provided list of components
369
+ node (autoPyTorchChoice): The module of the component
370
+ matches (np.ndarray): ...
371
+
372
+ Returns:
373
+ modified cs (ConfigurationSpace):
374
+ modified searching space information based on the arguments.
375
+ """
376
+ is_choice = isinstance (node , autoPyTorchChoice )
377
+
378
+ if not is_choice :
379
+ # if the node isn't a choice we can add it immediately because
380
+ # it must be active (if it wasn't, np.sum(matches) would be zero
381
+ assert not isinstance (node , autoPyTorchChoice )
382
+ cs .add_configuration_space (
383
+ node_name ,
384
+ node .get_hyperparameter_search_space (dataset_properties , # type: ignore[arg-type]
385
+ ** node ._get_search_space_updates ()),
386
+ )
387
+ else :
388
+ # If the node is a choice, we have to figure out which of
389
+ # its choices are actually legal choices
390
+ choices_list = find_active_choices (
391
+ matches , node , node_idx ,
392
+ dataset_properties ,
393
+ include .get (node_name ),
394
+ exclude .get (node_name )
395
+ )
396
+ sub_config_space = node .get_hyperparameter_search_space (
397
+ dataset_properties , include = choices_list )
398
+ cs .add_configuration_space (node_name , sub_config_space )
399
+
400
+ return cs
401
+
307
402
def _get_base_search_space (
308
403
self ,
309
404
cs : ConfigurationSpace ,
@@ -312,75 +407,50 @@ def _get_base_search_space(
312
407
exclude : Optional [Dict [str , Any ]],
313
408
pipeline : List [Tuple [str , autoPyTorchChoice ]]
314
409
) -> ConfigurationSpace :
315
- if include is None :
316
- if self .include is None :
317
- include = {}
318
- else :
319
- include = self .include
320
-
321
- keys = [pair [0 ] for pair in pipeline ]
322
- for key in include :
323
- if key not in keys :
324
- raise ValueError ('Invalid key in include: %s; should be one '
325
- 'of %s' % (key , keys ))
326
-
327
- if exclude is None :
328
- if self .exclude is None :
329
- exclude = {}
330
- else :
331
- exclude = self .exclude
410
+ """
411
+ Get the searching space
332
412
333
- keys = [pair [0 ] for pair in pipeline ]
334
- for key in exclude :
335
- if key not in keys :
336
- raise ValueError ('Invalid key in exclude: %s; should be one '
337
- 'of %s' % (key , keys ))
413
+ Args:
414
+ cs (ConfigurationSpace): Searching space information
415
+ dataset_properties (Dict[str, Any]): The properties of dataset
416
+ include (Optional[Dict[str, Any]]): Components to include in the searching
417
+ exclude (Optional[Dict[str, Any]]): Components to exclude from the searching
418
+ pipeline (List[Tuple[str, autoPyTorchChoice]]):
419
+ Available components
338
420
421
+ Returns:
422
+ modified cs (ConfigurationSpace):
423
+ modified searching space information based on the arguments.
424
+ """
425
+ include , exclude = self ._get_search_space_modifications (
426
+ include = include , exclude = exclude , pipeline = pipeline
427
+ )
339
428
if self .search_space_updates is not None :
340
429
self ._check_search_space_updates (include = include ,
341
430
exclude = exclude )
342
431
self .search_space_updates .apply (pipeline = pipeline )
343
432
433
+ # The size of this array exponentially grows, so it will be better to remove
344
434
matches = get_match_array (
345
435
pipeline , dataset_properties , include = include , exclude = exclude )
346
436
347
437
# Now we have only legal combinations at this step of the pipeline
348
438
# Simple sanity checks
349
- assert np .sum (matches ) != 0 , "No valid pipeline found."
350
-
351
- assert np .sum (matches ) <= np .size (matches ), \
352
- "'matches' is not binary; %s <= %d, %s" % \
353
- (str (np .sum (matches )), np .size (matches ), str (matches .shape ))
439
+ if np .sum (matches ) == 0 :
440
+ raise ValueError ("No valid pipeline found." )
441
+ if np .sum (matches ) > np .size (matches ):
442
+ raise TypeError ("'matches' is not binary; {} <= {}, {}" .format (
443
+ np .sum (matches ), np .size (matches ), str (matches .shape )
444
+ ))
354
445
355
446
# Iterate each dimension of the matches array (each step of the
356
447
# pipeline) to see if we can add a hyperparameter for that step
357
- for node_idx , n_ in enumerate (pipeline ):
358
- node_name , node = n_
359
-
360
- is_choice = isinstance (node , autoPyTorchChoice )
361
-
362
- # if the node isn't a choice we can add it immediately because it
363
- # must be active (if it wasn't, np.sum(matches) would be zero
364
- if not is_choice :
365
- # for mypy
366
- assert not isinstance (node , autoPyTorchChoice )
367
- cs .add_configuration_space (
368
- node_name ,
369
- node .get_hyperparameter_search_space (dataset_properties , # type: ignore[arg-type]
370
- ** node ._get_search_space_updates ()),
371
- )
372
- # If the node is a choice, we have to figure out which of its
373
- # choices are actually legal choices
374
- else :
375
- choices_list = find_active_choices (
376
- matches , node , node_idx ,
377
- dataset_properties ,
378
- include .get (node_name ),
379
- exclude .get (node_name )
380
- )
381
- sub_config_space = node .get_hyperparameter_search_space (
382
- dataset_properties , include = choices_list )
383
- cs .add_configuration_space (node_name , sub_config_space )
448
+ for node_idx , (node_name , node ) in enumerate (pipeline ):
449
+ cs = self ._update_search_space_per_node (
450
+ cs = ConfigurationSpace , dataset_properties = dataset_properties ,
451
+ include = include , exclude = exclude , matches = matches ,
452
+ node = node , node_idx = node_idx , node_name = node_name
453
+ )
384
454
385
455
# And now add forbidden parameter configurations
386
456
# According to matches
0 commit comments