@@ -148,71 +148,62 @@ def get_hyperparameter_search_space(
148
148
if default is None :
149
149
defaults = [
150
150
'NoEmbedding' ,
151
- # 'LearnedEntityEmbedding',
151
+ 'LearnedEntityEmbedding' ,
152
152
]
153
153
for default_ in defaults :
154
154
if default_ in available_embedding :
155
155
default = default_
156
156
break
157
157
158
- # Restrict embedding to NoEmbedding until preprocessing is fixed
159
- embedding = CSH .CategoricalHyperparameter ('__choice__' ,
160
- ['NoEmbedding' ],
161
- default_value = default )
158
+ categorical_columns = dataset_properties ['categorical_columns' ] \
159
+ if isinstance (dataset_properties ['categorical_columns' ], List ) else []
160
+
161
+ updates = self ._get_search_space_updates ()
162
+ if '__choice__' in updates .keys ():
163
+ choice_hyperparameter = updates ['__choice__' ]
164
+ if not set (choice_hyperparameter .value_range ).issubset (available_embedding ):
165
+ raise ValueError ("Expected given update for {} to have "
166
+ "choices in {} got {}" .format (self .__class__ .__name__ ,
167
+ available_embedding ,
168
+ choice_hyperparameter .value_range ))
169
+ if len (categorical_columns ) == 0 :
170
+ assert len (choice_hyperparameter .value_range ) == 1
171
+ if 'NoEmbedding' not in choice_hyperparameter .value_range :
172
+ raise ValueError ("Provided {} in choices, however, the dataset "
173
+ "is incompatible with it" .format (choice_hyperparameter .value_range ))
174
+ embedding = CSH .CategoricalHyperparameter ('__choice__' ,
175
+ choice_hyperparameter .value_range ,
176
+ default_value = choice_hyperparameter .default_value )
177
+ else :
178
+
179
+ if len (categorical_columns ) == 0 :
180
+ default = 'NoEmbedding'
181
+ if include is not None and default not in include :
182
+ raise ValueError ("Provided {} in include, however, the dataset "
183
+ "is incompatible with it" .format (include ))
184
+ embedding = CSH .CategoricalHyperparameter ('__choice__' ,
185
+ ['NoEmbedding' ],
186
+ default_value = default )
187
+ else :
188
+ embedding = CSH .CategoricalHyperparameter ('__choice__' ,
189
+ list (available_embedding .keys ()),
190
+ default_value = default )
191
+
162
192
cs .add_hyperparameter (embedding )
193
+ for name in embedding .choices :
194
+ updates = self ._get_search_space_updates (prefix = name )
195
+ config_space = available_embedding [name ].get_hyperparameter_search_space (dataset_properties , # type: ignore
196
+ ** updates )
197
+ parent_hyperparameter = {'parent' : embedding , 'value' : name }
198
+ cs .add_configuration_space (
199
+ name ,
200
+ config_space ,
201
+ parent_hyperparameter = parent_hyperparameter
202
+ )
203
+
163
204
self .configuration_space_ = cs
164
205
self .dataset_properties_ = dataset_properties
165
206
return cs
166
- # categorical_columns = dataset_properties['categorical_columns'] \
167
- # if isinstance(dataset_properties['categorical_columns'], List) else []
168
-
169
- # updates = self._get_search_space_updates()
170
- # if '__choice__' in updates.keys():
171
- # choice_hyperparameter = updates['__choice__']
172
- # if not set(choice_hyperparameter.value_range).issubset(available_embedding):
173
- # raise ValueError("Expected given update for {} to have "
174
- # "choices in {} got {}".format(self.__class__.__name__,
175
- # available_embedding,
176
- # choice_hyperparameter.value_range))
177
- # if len(categorical_columns) == 0:
178
- # assert len(choice_hyperparameter.value_range) == 1
179
- # if 'NoEmbedding' not in choice_hyperparameter.value_range:
180
- # raise ValueError("Provided {} in choices, however, the dataset "
181
- # "is incompatible with it".format(choice_hyperparameter.value_range))
182
- # embedding = CSH.CategoricalHyperparameter('__choice__',
183
- # choice_hyperparameter.value_range,
184
- # default_value=choice_hyperparameter.default_value)
185
- # else:
186
-
187
- # if len(categorical_columns) == 0:
188
- # default = 'NoEmbedding'
189
- # if include is not None and default not in include:
190
- # raise ValueError("Provided {} in include, however, the dataset "
191
- # "is incompatible with it".format(include))
192
- # embedding = CSH.CategoricalHyperparameter('__choice__',
193
- # ['NoEmbedding'],
194
- # default_value=default)
195
- # else:
196
- # embedding = CSH.CategoricalHyperparameter('__choice__',
197
- # list(available_embedding.keys()),
198
- # default_value=default)
199
-
200
- # cs.add_hyperparameter(embedding)
201
- # for name in embedding.choices:
202
- # updates = self._get_search_space_updates(prefix=name)
203
- # config_space = available_embedding[name].get_hyperparameter_search_space(
204
- # dataset_properties, # type: ignore
205
- # **updates)
206
- # parent_hyperparameter = {'parent': embedding, 'value': name}
207
- # cs.add_configuration_space(
208
- # name,
209
- # config_space,
210
- # parent_hyperparameter=parent_hyperparameter
211
- # )
212
-
213
- # self.configuration_space_ = cs
214
- # self.dataset_properties_ = dataset_properties
215
- # return cs
216
207
217
208
def transform (self , X : np .ndarray ) -> np .ndarray :
218
209
assert self .choice is not None , "Cannot call transform before the object is initialized"
0 commit comments