Skip to content

Commit

Permalink
update get_configspace so that group names are flattened and individu…
Browse files Browse the repository at this point in the history
…al modules have equal probability
  • Loading branch information
perib committed Oct 11, 2024
1 parent 29c27bd commit d5dd1eb
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions tpot2/config/get_configspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,22 @@ def get_configspace(name, n_classes=3, n_samples=1000, n_features=100, random_st
#raise error
raise ValueError(f"Could not find configspace for {name}")

def flatten_group_names(name):
#if string
if isinstance(name, str):
if name in GROUPNAMES:
return flatten_group_names(GROUPNAMES[name])
else:
return name

flattened_list = []
for key in name:
if key in GROUPNAMES:
flattened_list.extend(flatten_group_names(GROUPNAMES[key]))
else:
flattened_list.append(key)

return flattened_list

def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_state=None, return_choice_pipeline=True, base_node=EstimatorNode):
"""
Expand Down Expand Up @@ -471,6 +487,7 @@ def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_s
Note: for some special cases with methods using wrapped estimators, the returned search space is a TPOT2.search_spaces.pipelines.WrapperPipeline object.
"""
name = flatten_group_names(name)

#if list of names, return a list of EstimatorNodes
if isinstance(name, list) or isinstance(name, np.ndarray):
Expand All @@ -483,9 +500,9 @@ def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_s
else:
return np.hstack(search_spaces)

if name in GROUPNAMES:
name_list = GROUPNAMES[name]
return get_search_space(name_list, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, return_choice_pipeline=return_choice_pipeline, base_node=base_node)
# if name in GROUPNAMES:
# name_list = GROUPNAMES[name]
# return get_search_space(name_list, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, return_choice_pipeline=return_choice_pipeline, base_node=base_node)

return get_node(name, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, base_node=base_node)

Expand Down

0 comments on commit d5dd1eb

Please sign in to comment.