diff --git a/tpot2/config/get_configspace.py b/tpot2/config/get_configspace.py index 9b03718b..6e2fb733 100644 --- a/tpot2/config/get_configspace.py +++ b/tpot2/config/get_configspace.py @@ -436,6 +436,22 @@ def get_configspace(name, n_classes=3, n_samples=1000, n_features=100, random_st #raise error raise ValueError(f"Could not find configspace for {name}") +def flatten_group_names(name): + #if string + if isinstance(name, str): + if name in GROUPNAMES: + return flatten_group_names(GROUPNAMES[name]) + else: + return name + + flattened_list = [] + for key in name: + if key in GROUPNAMES: + flattened_list.extend(flatten_group_names(GROUPNAMES[key])) + else: + flattened_list.append(key) + + return flattened_list def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_state=None, return_choice_pipeline=True, base_node=EstimatorNode): """ @@ -471,6 +487,7 @@ def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_s Note: for some special cases with methods using wrapped estimators, the returned search space is a TPOT2.search_spaces.pipelines.WrapperPipeline object. """ + name = flatten_group_names(name) #if list of names, return a list of EstimatorNodes if isinstance(name, list) or isinstance(name, np.ndarray): @@ -483,9 +500,9 @@ def get_search_space(name, n_classes=3, n_samples=1000, n_features=100, random_s else: return np.hstack(search_spaces) - if name in GROUPNAMES: - name_list = GROUPNAMES[name] - return get_search_space(name_list, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, return_choice_pipeline=return_choice_pipeline, base_node=base_node) + # if name in GROUPNAMES: + # name_list = GROUPNAMES[name] + # return get_search_space(name_list, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, return_choice_pipeline=return_choice_pipeline, base_node=base_node) return get_node(name, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state, base_node=base_node)