@@ -348,11 +348,17 @@ def replace_data(self, X_train: BaseDatasetInputType,
 
     def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]:
         """
-        Gets the dataset properties required in the fit dictionary
+        Gets the dataset properties required in the fit dictionary.
+        This depends on the components that are active in the
+        pipeline and returns the properties they need about the dataset.
+        Information of the required properties of each component
+        can be found in their documentation.
         Args:
             dataset_requirements (List[FitRequirement]): List of
                 fit requirements that the dataset properties must
-                contain.
+                contain. This is created using the `get_dataset_requirements
+                function in
+                <https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/utils/pipeline.py#L25>`
 
         Returns:
             dataset_properties (Dict[str, Any]):
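
The new docstring points to get_dataset_requirements in autoPyTorch/utils/pipeline.py as the producer of the dataset_requirements argument. The snippet below is a minimal sketch of the intended round trip; the variable dataset and the exact call form of get_dataset_requirements are assumptions, not shown in this diff:

# Sketch only: `dataset` stands for an instance of some BaseDataset subclass, and
# the argument name/position accepted by get_dataset_requirements may differ on this branch.
from autoPyTorch.utils.pipeline import get_dataset_requirements

# Properties the dataset can report about itself.
info = dataset.get_required_dataset_info()
# FitRequirement list derived from the components active in the pipeline.
dataset_requirements = get_dataset_requirements(info)
# Collect the matching property values for the fit dictionary.
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
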
@@ -362,19 +368,15 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
         for dataset_requirement in dataset_requirements:
             dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
 
-        # Add task type, output type and issparse to dataset properties as
-        # they are not a dataset requirement in the pipeline
-        dataset_properties.update({'task_type': self.task_type,
-                                   'output_type': self.output_type,
-                                   'issparse': self.issparse,
-                                   'input_shape': self.input_shape,
-                                   'output_shape': self.output_shape
-                                   })
+        # Add the required dataset info to dataset properties as
+        # they might not be a dataset requirement in the pipeline
+        dataset_properties.update(self.get_required_dataset_info())
         return dataset_properties
 
     def get_required_dataset_info(self) -> Dict[str, Any]:
         """
-        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        Returns a dictionary containing required dataset
+        properties to instantiate a pipeline.
         """
         info = {'output_type': self.output_type,
                 'issparse': self.issparse}
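
The removed update({...}) block hard-coded task_type, input_shape and the other properties in the base class; routing them through get_required_dataset_info() turns this into a hook that subclasses can extend. Below is a hypothetical sketch of such an override; the subclass name and its extra attributes are illustrative only, not taken from this diff:

from typing import Any, Dict

class MyTabularDataset(BaseDataset):  # BaseDataset as defined in this module
    def get_required_dataset_info(self) -> Dict[str, Any]:
        # Start from the base properties ('output_type', 'issparse') ...
        info = super().get_required_dataset_info()
        # ... and add whatever this dataset type can additionally report.
        # These attributes are assumed to exist on the subclass.
        info.update({'task_type': self.task_type,
                     'input_shape': self.input_shape,
                     'output_shape': self.output_shape})
        return info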