@@ -1353,14 +1353,6 @@ def refit(
1353
1353
def fit_pipeline (
1354
1354
self ,
1355
1355
configuration : Configuration ,
1356
- dataset : Optional [BaseDataset ] = None ,
1357
- X_train : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1358
- y_train : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1359
- X_test : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1360
- y_test : Optional [Union [List , pd .DataFrame , np .ndarray ]] = None ,
1361
- dataset_name : Optional [str ] = None ,
1362
- resampling_strategy : Optional [Union [HoldoutValTypes , CrossValTypes ]] = None ,
1363
- resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
1364
1356
run_time_limit_secs : int = 60 ,
1365
1357
memory_limit : Optional [int ] = None ,
1366
1358
eval_metric : Optional [str ] = None ,
@@ -1372,6 +1364,7 @@ def fit_pipeline(
1372
1364
budget : Optional [float ] = None ,
1373
1365
pipeline_options : Optional [Dict ] = None ,
1374
1366
disable_file_output : Optional [List [Union [str , DisableFileOutputParameters ]]] = None ,
1367
+ ** dataset_kwargs : Any
1375
1368
) -> Tuple [Optional [BasePipeline ], RunInfo , RunValue , BaseDataset ]:
1376
1369
"""
1377
1370
Fit a pipeline on the given task for the budget.
@@ -1383,19 +1376,6 @@ def fit_pipeline(
1383
1376
methods.
1384
1377
1385
1378
Args:
1386
- X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
1387
- A pair of features (X_train) and targets (y_train) used to fit a
1388
- pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
1389
- be provided to track the generalization performance of each stage.
1390
- dataset_name (Optional[str]):
1391
- Name of the dataset, if None, random value is used.
1392
- resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
1393
- Strategy to split the training data. if None, uses
1394
- HoldoutValTypes.holdout_validation.
1395
- resampling_strategy_args (Optional[Dict[str, Any]]):
1396
- Arguments required for the chosen resampling strategy. If None, uses
1397
- the default values provided in DEFAULT_RESAMPLING_PARAMETERS
1398
- in ```datasets/resampling_strategy.py```.
1399
1379
run_time_limit_secs (int: default=60):
1400
1380
Time limit for a single call to the machine learning model.
1401
1381
Model fitting will be terminated if the machine learning algorithm
@@ -1465,8 +1445,15 @@ def fit_pipeline(
1465
1445
+ `all`:
1466
1446
do not save any of the above.
1467
1447
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
1468
- configuration: (Configuration)
1448
+ configuration (Configuration):
1469
1449
configuration to fit the pipeline with.
1450
+ **dataset_kwargs (Any):
1451
+ Either a ready-made `dataset` (BaseDataset) object or, when no
1452
+ dataset is passed, the keyword arguments used to build one: X_train, y_train
1453
+ (both required), X_test, y_test (Optional[Union[List, pd.DataFrame, np.ndarray]] = None),
1454
+ and further optional parameters such as dataset_name (str),
1455
+ resampling_strategy (Union[HoldoutValTypes, CrossValTypes]) and
1456
+ resampling_strategy_args (Dict[str, Any]).
1470
1457
1471
1458
Returns:
1472
1459
(BasePipeline):
@@ -1477,19 +1464,18 @@ def fit_pipeline(
1477
1464
Result of fitting the pipeline
1478
1465
(BaseDataset):
1479
1466
Dataset created from the given tensors
1480
- """
1467
+ """
1468
+
1469
+ if 'dataset' not in dataset_kwargs :
1470
+ if (
1471
+ dataset_kwargs .get ('X_train' , None ) is not None
1472
+ and dataset_kwargs .get ('y_train' , None ) is not None
1473
+ ):
1474
+ raise ValueError ("No dataset provided, must provide X_train, y_train tensors" )
1481
1475
1482
- if dataset is None :
1483
- assert X_train is not None and \
1484
- y_train is not None , "No dataset provided, must provide X_train, y_train tensors"
1485
- dataset = self .get_dataset (X_train = X_train ,
1486
- y_train = y_train ,
1487
- X_test = X_test ,
1488
- y_test = y_test ,
1489
- resampling_strategy = resampling_strategy ,
1490
- resampling_strategy_args = resampling_strategy_args ,
1491
- dataset_name = dataset_name
1492
- )
1476
+ dataset = self .get_dataset (** dataset_kwargs )
1477
+ else :
1478
+ dataset = dataset_kwargs ['dataset' ]
1493
1479
1494
1480
# dataset_name is created inside the constructor of BaseDataset
1495
1481
# we expect it to be not None. This is for mypy
0 commit comments