Merged
32 commits
cdcf766
Add fit pipeline with tests
ravinkohli Nov 30, 2021
bc5b469
Add documentation for get dataset
ravinkohli Nov 30, 2021
0359c8c
update documentation
ravinkohli Nov 30, 2021
75eb604
fix tests
ravinkohli Nov 30, 2021
136f619
remove permutation importance from visualisation example
ravinkohli Nov 30, 2021
4731363
change disable_file_output
ravinkohli Nov 30, 2021
af48ebf
add
ravinkohli Dec 6, 2021
3df4e06
fix flake
ravinkohli Dec 6, 2021
e8289e4
fix test and examples
ravinkohli Dec 6, 2021
4018d02
change type of disable_file_output
ravinkohli Dec 6, 2021
add8890
Address comments from eddie
ravinkohli Dec 6, 2021
d8739cd
fix docstring in api
ravinkohli Dec 6, 2021
f1ea974
fix tests for base api
ravinkohli Dec 6, 2021
38471f1
fix tests for base api
ravinkohli Dec 6, 2021
02ac9de
fix tests after rebase
ravinkohli Dec 6, 2021
fd32939
reduce dataset size in example
ravinkohli Dec 7, 2021
3958750
remove optional from doc string
ravinkohli Dec 7, 2021
c33381a
Handle unsuccessful fitting of pipeline better
ravinkohli Dec 7, 2021
dff0e5c
fix flake in tests
ravinkohli Dec 7, 2021
eb648e5
change to default configuration for documentation
ravinkohli Dec 7, 2021
974ea1c
add warning for no ensemble created when y_optimization in disable_fi…
ravinkohli Dec 7, 2021
cc19e4c
reduce budget for single configuration
ravinkohli Dec 7, 2021
ab93ee6
address comments from eddie
ravinkohli Dec 7, 2021
c246b20
address comments from shuhei
ravinkohli Dec 9, 2021
a0a4e75
Add autoPyTorchEnum
ravinkohli Dec 9, 2021
a0fef77
fix flake in tests
ravinkohli Dec 10, 2021
8094ff1
address comments from shuhei
ravinkohli Dec 19, 2021
4d90706
Apply suggestions from code review
ravinkohli Dec 19, 2021
c7cc712
fix flake
ravinkohli Dec 19, 2021
14113f9
use **dataset_kwargs
ravinkohli Dec 20, 2021
5b2f75f
fix flake
ravinkohli Dec 20, 2021
24aac05
change to enforce keyword args
ravinkohli Dec 20, 2021
change to enforce keyword args
ravinkohli committed Dec 20, 2021
commit 24aac05da7b522d9e1214b4dbff8dc4e99871b66
57 changes: 40 additions & 17 deletions autoPyTorch/api/base_task.py
@@ -1353,6 +1353,15 @@ def refit(
     def fit_pipeline(
         self,
         configuration: Configuration,
+        *,
+        dataset: Optional[BaseDataset] = None,
+        X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        dataset_name: Optional[str] = None,
+        resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None,
+        resampling_strategy_args: Optional[Dict[str, Any]] = None,
         run_time_limit_secs: int = 60,
         memory_limit: Optional[int] = None,
         eval_metric: Optional[str] = None,
@@ -1364,7 +1373,6 @@ def fit_pipeline(
         budget: Optional[float] = None,
         pipeline_options: Optional[Dict] = None,
         disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-        **dataset_kwargs: Any
     ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue, BaseDataset]:
         """
         Fit a pipeline on the given task for the budget.
@@ -1376,6 +1384,26 @@ def fit_pipeline(
         methods.

         Args:
+            configuration (Configuration):
+                configuration to fit the pipeline with.
+            dataset (BaseDataset):
+                An object of the appropriate child class of `BaseDataset`,
+                that will be used to fit the pipeline
+            X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
+                A pair of features (X_train) and targets (y_train) used to fit a
+                pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
+                be provided to track the generalization performance of each stage.
+            dataset_name (Optional[str]):
+                Name of the dataset, if None, random value is used.
+            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+                Strategy to split the training data. if None, uses
+                HoldoutValTypes.holdout_validation.
+            resampling_strategy_args (Optional[Dict[str, Any]]):
+                Arguments required for the chosen resampling strategy. If None, uses
+                the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+                in ```datasets/resampling_strategy.py```.
+            dataset_name (Optional[str]):
+                name of the dataset, used as experiment name.
             run_time_limit_secs (int: default=60):
                 Time limit for a single call to the machine learning model.
                 Model fitting will be terminated if the machine learning algorithm
@@ -1445,15 +1473,6 @@ def fit_pipeline(
                 + `all`:
                     do not save any of the above.
                 For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
-            configuration (Configuration):
-                configuration to fit the pipeline with.
-            **dataset_kwargs (Any):
-                Can contain either `dataset (BaseDataset)` object or
-                keyword arguments specifying the dataset like X_train, y_train,
-                X_test, y_test (Optional[Union[List, pd.DataFrame, np.ndarray]] = None)
-                and other parameters like dataset_name (str),
-                resampling_strategy (Union[HoldoutValTypes, CrossValTypes]),
-                resampling_strategy_args (Dict[str, Any]).

         Returns:
             (BasePipeline):
@@ -1466,16 +1485,20 @@ def fit_pipeline(
                 Dataset created from the given tensors
         """

-        if 'dataset' not in dataset_kwargs:
+        if dataset is None:
             if (
-                dataset_kwargs.get('X_train', None) is not None
-                and dataset_kwargs.get('y_train', None) is not None
+                X_train is not None
+                and y_train is not None
             ):
                 raise ValueError("No dataset provided, must provide X_train, y_train tensors")
-
-            dataset = self.get_dataset(**dataset_kwargs)
-        else:
-            dataset = dataset_kwargs['dataset']
+            dataset = self.get_dataset(X_train=X_train,
+                                       y_train=y_train,
+                                       X_test=X_test,
+                                       y_test=y_test,
+                                       resampling_strategy=resampling_strategy,
+                                       resampling_strategy_args=resampling_strategy_args,
+                                       dataset_name=dataset_name
+                                       )

         # dataset_name is created inside the constructor of BaseDataset
         # we expect it to be not None. This is for mypy
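The effect of the bare `*` in the new signature can be seen in a small standalone sketch. It does not use autoPyTorch: `resolve_dataset` and the plain dicts standing in for `Configuration` and `BaseDataset` are hypothetical, chosen only to mirror the shape of `fit_pipeline`. It also assumes the intended behaviour is to raise only when neither a dataset nor both `X_train` and `y_train` are supplied, which is what the error message suggests.

```python
from typing import Any, Dict, Optional, Sequence


def resolve_dataset(
    configuration: Dict[str, Any],
    *,  # everything after this marker must be passed by keyword
    dataset: Optional[Dict[str, Any]] = None,
    X_train: Optional[Sequence[Any]] = None,
    y_train: Optional[Sequence[Any]] = None,
    dataset_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for the dataset handling in fit_pipeline."""
    if dataset is not None:
        # A ready-made dataset takes precedence over raw tensors.
        return dataset
    if X_train is None or y_train is None:
        # Assumed intent: fail only when the training pair is incomplete.
        raise ValueError("No dataset provided, must provide X_train, y_train tensors")
    # Stand-in for self.get_dataset(...): bundle the inputs into a dict.
    return {
        "X_train": list(X_train),
        "y_train": list(y_train),
        "name": dataset_name or "unnamed",
    }


if __name__ == "__main__":
    built = resolve_dataset({}, X_train=[[0.1], [0.2]], y_train=[0, 1], dataset_name="toy")
    print(built["name"])

    try:
        # Passing the data positionally is rejected by the bare `*`.
        resolve_dataset({}, [[0.1], [0.2]], [0, 1])
    except TypeError as err:
        print("positional call rejected:", type(err).__name__)
```

Compared with `**dataset_kwargs`, explicit keyword-only parameters let a type checker and the interpreter catch misspelled or misplaced arguments at the call site rather than inside the function body.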