Skip to content

Commit 24aac05

Browse files
committed
change to enforce keyword args
1 parent 5b2f75f commit 24aac05

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

autoPyTorch/api/base_task.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,15 @@ def refit(
13531353
def fit_pipeline(
13541354
self,
13551355
configuration: Configuration,
1356+
*,
1357+
dataset: Optional[BaseDataset] = None,
1358+
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1359+
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1360+
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1361+
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1362+
dataset_name: Optional[str] = None,
1363+
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None,
1364+
resampling_strategy_args: Optional[Dict[str, Any]] = None,
13561365
run_time_limit_secs: int = 60,
13571366
memory_limit: Optional[int] = None,
13581367
eval_metric: Optional[str] = None,
@@ -1364,7 +1373,6 @@ def fit_pipeline(
13641373
budget: Optional[float] = None,
13651374
pipeline_options: Optional[Dict] = None,
13661375
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
1367-
**dataset_kwargs: Any
13681376
) -> Tuple[Optional[BasePipeline], RunInfo, RunValue, BaseDataset]:
13691377
"""
13701378
Fit a pipeline on the given task for the budget.
@@ -1376,6 +1384,26 @@ def fit_pipeline(
13761384
methods.
13771385
13781386
Args:
1387+
configuration (Configuration):
1388+
configuration to fit the pipeline with.
1389+
dataset (BaseDataset):
1390+
An object of the appropriate child class of `BaseDataset`,
1391+
that will be used to fit the pipeline
1392+
X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
1393+
A pair of features (X_train) and targets (y_train) used to fit a
1394+
pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
1395+
be provided to track the generalization performance of each stage.
1396+
dataset_name (Optional[str]):
1397+
Name of the dataset, if None, random value is used.
1398+
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
1399+
Strategy to split the training data. if None, uses
1400+
HoldoutValTypes.holdout_validation.
1401+
resampling_strategy_args (Optional[Dict[str, Any]]):
1402+
Arguments required for the chosen resampling strategy. If None, uses
1403+
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
1404+
in ```datasets/resampling_strategy.py```.
1405+
dataset_name (Optional[str]):
1406+
name of the dataset, used as experiment name.
13791407
run_time_limit_secs (int: default=60):
13801408
Time limit for a single call to the machine learning model.
13811409
Model fitting will be terminated if the machine learning algorithm
@@ -1445,15 +1473,6 @@ def fit_pipeline(
14451473
+ `all`:
14461474
do not save any of the above.
14471475
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
1448-
configuration (Configuration):
1449-
configuration to fit the pipeline with.
1450-
**dataset_kwargs (Any):
1451-
Can contain either `dataset (BaseDataset)` object or
1452-
keyword arguments specifying the dataset like X_train, y_train,
1453-
X_test, y_test (Optional[Union[List, pd.DataFrame, np.ndarray]] = None)
1454-
and other parameters like dataset_name (str),
1455-
resampling_strategy (Union[HoldoutValTypes, CrossValTypes]),
1456-
resampling_strategy_args (Dict[str, Any]).
14571476
14581477
Returns:
14591478
(BasePipeline):
@@ -1466,16 +1485,20 @@ def fit_pipeline(
14661485
Dataset created from the given tensors
14671486
"""
14681487

1469-
if 'dataset' not in dataset_kwargs:
1488+
if dataset is None:
14701489
if (
1471-
dataset_kwargs.get('X_train', None) is not None
1472-
and dataset_kwargs.get('y_train', None) is not None
1490+
X_train is not None
1491+
and y_train is not None
14731492
):
14741493
raise ValueError("No dataset provided, must provide X_train, y_train tensors")
1475-
1476-
dataset = self.get_dataset(**dataset_kwargs)
1477-
else:
1478-
dataset = dataset_kwargs['dataset']
1494+
dataset = self.get_dataset(X_train=X_train,
1495+
y_train=y_train,
1496+
X_test=X_test,
1497+
y_test=y_test,
1498+
resampling_strategy=resampling_strategy,
1499+
resampling_strategy_args=resampling_strategy_args,
1500+
dataset_name=dataset_name
1501+
)
14791502

14801503
# dataset_name is created inside the constructor of BaseDataset
14811504
# we expect it to be not None. This is for mypy

0 commit comments

Comments
 (0)