@@ -738,7 +738,7 @@ def __init__(self, model, dependent_variable,
738738 etl = None ):
739739 self .model = model
740740 self ._input_model = model # In case we need to modify the input
741- if isinstance (dependent_variable , str ):
741+ if isinstance (dependent_variable , six . string_types ):
742742 # Standardize the dependent variable as a list.
743743 dependent_variable = [dependent_variable ]
744744 self .dependent_variable = dependent_variable
@@ -1133,7 +1133,8 @@ def predict(self, df=None, csv_path=None,
11331133 manifest = None , file_id = None , sql_where = None , sql_limit = None ,
11341134 primary_key = SENTINEL , output_table = None , output_db = None ,
11351135 if_exists = 'fail' , n_jobs = None , polling_interval = None ,
1136- cpu = None , memory = None , disk_space = None ):
1136+ cpu = None , memory = None , disk_space = None ,
1137+ targets_to_predict = None ):
11371138 """Make predictions on a trained model
11381139
11391140 Provide input through one of
@@ -1219,6 +1220,13 @@ def predict(self, df=None, csv_path=None,
12191220 RAM requested by the user for a single job.
12201221 disk_space : float, optional
12211222 disk space requested by the user for a single job.
1223+ targets_to_predict : list of str, optional
1224+ If this is a multi-output model, you may list a subset of
1225+ targets for which you wish to generate predictions.
1226+ This list must be a subset of the original `dependent_variable`
1227+ input. Ignoring some of the model's outputs will let predictions
1228+ complete faster and use less disk space. The default is to
1229+ produce scores for all targets.
12221230
12231231 Returns
12241232 -------
@@ -1265,6 +1273,12 @@ def predict(self, df=None, csv_path=None,
12651273 predict_args ['LIMITSQL' ] = sql_limit
12661274 if n_jobs :
12671275 predict_args ['N_JOBS' ] = n_jobs
1276+ if targets_to_predict :
1277+ if isinstance (targets_to_predict , six .string_types ):
1278+ targets_to_predict = [targets_to_predict ]
1279+ if self .predict_template_id > 10600 :
1280+ # This feature was added in v2.2.
1281+ predict_args ['TARGET_COLUMN' ] = ' ' .join (targets_to_predict )
12681282 if self .predict_template_id >= 9969 :
12691283 if cpu :
12701284 predict_args ['CPU' ] = cpu
0 commit comments