@@ -1353,6 +1353,15 @@ def refit(
1353
1353
def fit_pipeline (
1354
1354
self ,
1355
1355
configuration : Configuration ,
1356
+ * ,
1357
+ dataset : Optional [BaseDataset ] = None ,
1358
+ X_train : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1359
+ y_train : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1360
+ X_test : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1361
+ y_test : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1362
+ dataset_name : Optional [str ] = None ,
1363
+ resampling_strategy : Optional [Union [HoldoutValTypes , CrossValTypes ]] = None ,
1364
+ resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
1356
1365
run_time_limit_secs : int = 60 ,
1357
1366
memory_limit : Optional [int ] = None ,
1358
1367
eval_metric : Optional [str ] = None ,
@@ -1364,7 +1373,6 @@ def fit_pipeline(
1364
1373
budget : Optional [float ] = None ,
1365
1374
pipeline_options : Optional [Dict ] = None ,
1366
1375
disable_file_output : Optional [List [Union [str , DisableFileOutputParameters ]]] = None ,
1367
- ** dataset_kwargs : Any
1368
1376
) -> Tuple [Optional [BasePipeline ], RunInfo , RunValue , BaseDataset ]:
1369
1377
"""
1370
1378
Fit a pipeline on the given task for the budget.
@@ -1376,6 +1384,26 @@ def fit_pipeline(
1376
1384
methods.
1377
1385
1378
1386
Args:
1387
+ configuration (Configuration):
1388
+ configuration to fit the pipeline with.
1389
+ dataset (BaseDataset):
1390
+ An object of the appropriate child class of `BaseDataset`,
1391
+ that will be used to fit the pipeline
1392
+ X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
1393
+ A pair of features (X_train) and targets (y_train) used to fit a
1394
+ pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
1395
+ be provided to track the generalization performance of each stage.
1396
+ dataset_name (Optional[str]):
1397
+ Name of the dataset, if None, random value is used.
1398
+ resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
1399
+ Strategy to split the training data. if None, uses
1400
+ HoldoutValTypes.holdout_validation.
1401
+ resampling_strategy_args (Optional[Dict[str, Any]]):
1402
+ Arguments required for the chosen resampling strategy. If None, uses
1403
+ the default values provided in DEFAULT_RESAMPLING_PARAMETERS
1404
+ in ```datasets/resampling_strategy.py```.
1405
+ dataset_name (Optional[str]):
1406
+ name of the dataset, used as experiment name.
1379
1407
run_time_limit_secs (int: default=60):
1380
1408
Time limit for a single call to the machine learning model.
1381
1409
Model fitting will be terminated if the machine learning algorithm
@@ -1445,15 +1473,6 @@ def fit_pipeline(
1445
1473
+ `all`:
1446
1474
do not save any of the above.
1447
1475
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
1448
- configuration (Configuration):
1449
- configuration to fit the pipeline with.
1450
- **dataset_kwargs (Any):
1451
- Can contain either `dataset (BaseDataset)` object or
1452
- keyword arguments specifying the dataset like X_train, y_train,
1453
- X_test, y_test (Optional[Union[List, pd.DataFrame, np.ndarray]] = None)
1454
- and other parameters like dataset_name (str),
1455
- resampling_strategy (Union[HoldoutValTypes, CrossValTypes]),
1456
- resampling_strategy_args (Dict[str, Any]).
1457
1476
1458
1477
Returns:
1459
1478
(BasePipeline):
@@ -1466,16 +1485,20 @@ def fit_pipeline(
1466
1485
Dataset created from the given tensors
1467
1486
"""
1468
1487
1469
- if ' dataset' not in dataset_kwargs :
1488
+ if dataset is None :
1470
1489
if (
1471
- dataset_kwargs . get ( ' X_train' , None ) is not None
1472
- and dataset_kwargs . get ( ' y_train' , None ) is not None
1490
+ X_train is not None
1491
+ and y_train is not None
1473
1492
):
1474
1493
raise ValueError ("No dataset provided, must provide X_train, y_train tensors" )
1475
-
1476
- dataset = self .get_dataset (** dataset_kwargs )
1477
- else :
1478
- dataset = dataset_kwargs ['dataset' ]
1494
+ dataset = self .get_dataset (X_train = X_train ,
1495
+ y_train = y_train ,
1496
+ X_test = X_test ,
1497
+ y_test = y_test ,
1498
+ resampling_strategy = resampling_strategy ,
1499
+ resampling_strategy_args = resampling_strategy_args ,
1500
+ dataset_name = dataset_name
1501
+ )
1479
1502
1480
1503
# dataset_name is created inside the constructor of BaseDataset
1481
1504
# we expect it to be not None. This is for mypy
0 commit comments