Skip to content

Commit 14113f9

Browse files
committed
use **dataset_kwargs
1 parent c7cc712 commit 14113f9

File tree

1 file changed

+20
-34
lines changed

1 file changed

+20
-34
lines changed

autoPyTorch/api/base_task.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,14 +1353,6 @@ def refit(
13531353
def fit_pipeline(
13541354
self,
13551355
configuration: Configuration,
1356-
dataset: Optional[BaseDataset] = None,
1357-
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1358-
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1359-
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1360-
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
1361-
dataset_name: Optional[str] = None,
1362-
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None,
1363-
resampling_strategy_args: Optional[Dict[str, Any]] = None,
13641356
run_time_limit_secs: int = 60,
13651357
memory_limit: Optional[int] = None,
13661358
eval_metric: Optional[str] = None,
@@ -1372,6 +1364,7 @@ def fit_pipeline(
13721364
budget: Optional[float] = None,
13731365
pipeline_options: Optional[Dict] = None,
13741366
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
1367+
**dataset_kwargs: Any
13751368
) -> Tuple[Optional[BasePipeline], RunInfo, RunValue, BaseDataset]:
13761369
"""
13771370
Fit a pipeline on the given task for the budget.
@@ -1383,19 +1376,6 @@ def fit_pipeline(
13831376
methods.
13841377
13851378
Args:
1386-
X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
1387-
A pair of features (X_train) and targets (y_train) used to fit a
1388-
pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
1389-
be provided to track the generalization performance of each stage.
1390-
dataset_name (Optional[str]):
1391-
Name of the dataset, if None, random value is used.
1392-
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
1393-
Strategy to split the training data. if None, uses
1394-
HoldoutValTypes.holdout_validation.
1395-
resampling_strategy_args (Optional[Dict[str, Any]]):
1396-
Arguments required for the chosen resampling strategy. If None, uses
1397-
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
1398-
in ```datasets/resampling_strategy.py```.
13991379
run_time_limit_secs (int: default=60):
14001380
Time limit for a single call to the machine learning model.
14011381
Model fitting will be terminated if the machine learning algorithm
@@ -1465,8 +1445,15 @@ def fit_pipeline(
14651445
+ `all`:
14661446
do not save any of the above.
14671447
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
1468-
configuration: (Configuration)
1448+
configuration (Configuration):
14691449
configuration to fit the pipeline with.
1450+
**dataset_kwargs (Any):
1451+
Can contain either a `dataset` (BaseDataset) object or
1452+
keyword arguments specifying the dataset like X_train, y_train,
1453+
X_test, y_test (Optional[Union[List, pd.DataFrame, np.ndarray]] = None)
1454+
and other parameters like dataset_name (str),
1455+
resampling_strategy (Union[HoldoutValTypes, CrossValTypes]),
1456+
resampling_strategy_args (Dict[str, Any]).
14701457
14711458
Returns:
14721459
(BasePipeline):
@@ -1477,19 +1464,18 @@ def fit_pipeline(
14771464
Result of fitting the pipeline
14781465
(BaseDataset):
14791466
Dataset created from the given tensors
1480-
"""
1467+
"""
1468+
1469+
if 'dataset' not in dataset_kwargs:
1470+
if (
1471+
dataset_kwargs.get('X_train', None) is not None
1472+
and dataset_kwargs.get('y_train', None) is not None
1473+
):
1474+
raise ValueError("No dataset provided, must provide X_train, y_train tensors")
14811475

1482-
if dataset is None:
1483-
assert X_train is not None and \
1484-
y_train is not None, "No dataset provided, must provide X_train, y_train tensors"
1485-
dataset = self.get_dataset(X_train=X_train,
1486-
y_train=y_train,
1487-
X_test=X_test,
1488-
y_test=y_test,
1489-
resampling_strategy=resampling_strategy,
1490-
resampling_strategy_args=resampling_strategy_args,
1491-
dataset_name=dataset_name
1492-
)
1476+
dataset = self.get_dataset(**dataset_kwargs)
1477+
else:
1478+
dataset = dataset_kwargs['dataset']
14931479

14941480
# dataset_name is created inside the constructor of BaseDataset
14951481
# we expect it to be not None. This is for mypy

0 commit comments

Comments
 (0)