@@ -738,7 +738,7 @@ def __init__(self, model, dependent_variable,
738738 etl = None ):
739739 self .model = model
740740 self ._input_model = model # In case we need to modify the input
741- if isinstance (dependent_variable , str ):
741+ if isinstance (dependent_variable , six . string_types ):
742742 # Standardize the dependent variable as a list.
743743 dependent_variable = [dependent_variable ]
744744 self .dependent_variable = dependent_variable
@@ -1133,7 +1133,8 @@ def predict(self, df=None, csv_path=None,
11331133 manifest = None , file_id = None , sql_where = None , sql_limit = None ,
11341134 primary_key = SENTINEL , output_table = None , output_db = None ,
11351135 if_exists = 'fail' , n_jobs = None , polling_interval = None ,
1136- cpu = None , memory = None , disk_space = None ):
1136+ cpu = None , memory = None , disk_space = None ,
1137+ dvs_to_predict = None ):
11371138 """Make predictions on a trained model
11381139
11391140 Provide input through one of
@@ -1219,6 +1220,15 @@ def predict(self, df=None, csv_path=None,
12191220 RAM requested by the user for a single job.
12201221 disk_space : float, optional
12211222 disk space requested by the user for a single job.
1223+ dvs_to_predict : list of str, optional
1224+ If this is a multi-output model, you may list a subset of
1225+ dependent variables for which you wish to generate predictions.
1226+ This list must be a subset of the original `dependent_variable`
1227+ input. The scores for the returned subset will be identical to
1228+ the scores which those outputs would have had if all outputs
1229+ were written, but ignoring some of the model's outputs will
1230+ let predictions complete faster and use less disk space.
1231+ The default is to produce scores for all DVs.
12221232
12231233 Returns
12241234 -------
@@ -1265,6 +1275,12 @@ def predict(self, df=None, csv_path=None,
12651275 predict_args ['LIMITSQL' ] = sql_limit
12661276 if n_jobs :
12671277 predict_args ['N_JOBS' ] = n_jobs
1278+ if dvs_to_predict :
1279+ if isinstance (dvs_to_predict , six .string_types ):
1280+ dvs_to_predict = [dvs_to_predict ]
1281+ if self .predict_template_id > 10583 :
1282+ # This feature was added in v2.2; 10583 is the v2.1 template
1283+ predict_args ['TARGET_COLUMN' ] = ' ' .join (dvs_to_predict )
12681284 if self .predict_template_id >= 9969 :
12691285 if cpu :
12701286 predict_args ['CPU' ] = cpu
0 commit comments