Skip to content

Commit f32b991

Browse files
committed
Added doc string to explain dataset properties
1 parent 26ad514 commit f32b991

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,17 @@ def replace_data(self, X_train: BaseDatasetInputType,
348348

349349
def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]:
350350
"""
351-
Gets the dataset properties required in the fit dictionary
351+
Gets the dataset properties required in the fit dictionary.
352+
This depends on the components that are active in the
353+
pipeline and returns the properties they need about the dataset.
354+
Information of the required properties of each component
355+
can be found in their documentation.
352356
Args:
353357
dataset_requirements (List[FitRequirement]): List of
354358
fit requirements that the dataset properties must
355-
contain.
359+
contain. This is created using the `get_dataset_requirements
360+
function
361+
<https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/utils/pipeline.py#L25>`_
356362
357363
Returns:
358364
dataset_properties (Dict[str, Any]):
@@ -362,19 +368,15 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
362368
for dataset_requirement in dataset_requirements:
363369
dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
364370

365-
# Add task type, output type and issparse to dataset properties as
366-
# they are not a dataset requirement in the pipeline
367-
dataset_properties.update({'task_type': self.task_type,
368-
'output_type': self.output_type,
369-
'issparse': self.issparse,
370-
'input_shape': self.input_shape,
371-
'output_shape': self.output_shape
372-
})
371+
# Add the required dataset info to dataset properties as
372+
# they might not be a dataset requirement in the pipeline
373+
dataset_properties.update(self.get_required_dataset_info())
373374
return dataset_properties
374375

375376
def get_required_dataset_info(self) -> Dict[str, Any]:
376377
"""
377-
Returns a dictionary containing required dataset properties to instantiate a pipeline,
378+
Returns a dictionary containing required dataset
379+
properties to instantiate a pipeline.
378380
"""
379381
info = {'output_type': self.output_type,
380382
'issparse': self.issparse}

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,24 @@ def __init__(self,
112112

113113
def get_required_dataset_info(self) -> Dict[str, Any]:
114114
"""
115-
Returns a dictionary containing required dataset properties to instantiate a pipeline,
115+
Returns a dictionary containing required dataset
116+
properties to instantiate a pipeline.
117+
For a Tabular Dataset this includes:
118+
1. 'output_type'- Enum indicating the type of the output for this problem.
119+
We currently use the `sklearn type_of_target
120+
<https://scikit-learn.org/stable/modules/generated/sklearn.utils.multiclass.type_of_target.html>`
121+
to infer the output type from the data and we encode it to an
122+
Enum for which you can find more info in `autopytorch/constants.py
123+
<https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/constants.py>`
124+
2. 'issparse'- A flag indicating if the input is a sparse matrix.
125+
3. 'numerical_columns'- a list which contains the column numbers
126+
for the numerical columns in the input dataset
127+
4. 'categorical_columns'- a list which contains the column numbers
128+
for the categorical columns in the input dataset
129+
5. 'task_type'- Enum indicating the type of task. For tabular datasets,
130+
currently we support 'tabular_classification' and 'tabular_regression', and we encode it to an
131+
Enum for which you can find more info in `autopytorch/constants.py
132+
<https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/constants.py>`
116133
"""
117134
info = super().get_required_dataset_info()
118135
info.update({

0 commit comments

Comments
 (0)