@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 /**
@@ -79,6 +79,10 @@ class GBTClassifier @Since("1.4.0") (
   @Since("1.4.0")
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   /** @group setParam */
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -152,36 +156,34 @@ class GBTClassifier @Since("1.4.0") (
     set(validationIndicatorCol, value)
   }
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * By default the weightCol is not set, so all instances have weight 1.0.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-    val categoricalFeatures: Map[Int, Int] =
-      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-
     val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
 
-    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
-    // 2 classes now. This lets us provide a more precise error message.
-    val convert2LabeledPoint = (dataset: Dataset[_]) => {
-      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-        case Row(label: Double, features: Vector) =>
-          require(label == 0 || label == 1, s"GBTClassifier was given " +
-            s"dataset with invalid label $label. Labels must be in {0,1}; note that " +
-            s"GBTClassifier currently only supports binary classification.")
-          LabeledPoint(label, features)
-      }
+    val validateInstance = (instance: Instance) => {
+      val label = instance.label
+      require(label == 0 || label == 1, s"GBTClassifier was given " +
+        s"dataset with invalid label $label. Labels must be in {0,1}; note that " +
+        s"GBTClassifier currently only supports binary classification.")
     }
 
     val (trainDataset, validationDataset) = if (withValidation) {
-      (
-        convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
-        convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
-      )
+      (extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance),
+        extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance))
     } else {
-      (convert2LabeledPoint(dataset), null)
+      (extractInstances(dataset, validateInstance), null)
     }
 
-    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-
     val numClasses = 2
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
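
(Aside: `extractInstances` applies the closure above as an `Instance => Unit` validator to every extracted row. Below is a minimal standalone sketch of that callback pattern; the simplified `Instance` case class here is a stand-in for Spark's `org.apache.spark.ml.feature.Instance`, which additionally carries the weight read from `weightCol`.)

// Simplified stand-in for org.apache.spark.ml.feature.Instance.
case class Instance(label: Double, weight: Double, features: Array[Double])

// Same shape as the closure in the diff: fail fast on non-binary labels.
val validateInstance: Instance => Unit = { instance =>
  val label = instance.label
  require(label == 0 || label == 1,
    s"Invalid label $label: GBTClassifier only supports binary labels in {0,1}.")
}

validateInstance(Instance(label = 1.0, weight = 2.0, features = Array(0.5, 1.5))) // passes
// validateInstance(Instance(2.0, 1.0, Array(0.0)))  // would throw IllegalArgumentException
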
@@ -191,12 +193,14 @@ class GBTClassifier @Since("1.4.0") (
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity,
-      lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
-      validationIndicatorCol, validationTol)
+    instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
+      impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
+      minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
+      checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol)
     instr.logNumClasses(numClasses)
 
+    val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val (baseLearners, learnerWeights) = if (withValidation) {
       GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
         $(seed), $(featureSubsetStrategy))
@@ -374,12 +378,9 @@ class GBTClassificationModel private[ml](
    */
   @Since("2.4.0")
   def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
-    val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
-    }
+    val data = extractInstances(dataset)
     GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
-      OldAlgo.Classification
-    )
+      OldAlgo.Classification)
   }
 
   @Since("2.0.0")
@@ -423,10 +424,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
       val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
-      val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+      val trees = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
           treeMetadata.getAndSetParams(tree)
           tree
       }
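
The user-visible effect of this change: `GBTClassifier` now accepts per-instance weights. A usage sketch (assumes Spark 3.0+ and a `SparkSession` in scope as `spark`; the column names and toy data are illustrative only):

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors

// Toy training data: (label, weight, features).
val train = spark.createDataFrame(Seq(
  (0.0, 1.0, Vectors.dense(0.0, 1.1)),
  (1.0, 2.0, Vectors.dense(2.0, 1.0)),
  (0.0, 0.5, Vectors.dense(0.1, 1.3)),
  (1.0, 1.5, Vectors.dense(1.9, 0.8))
)).toDF("label", "weight", "features")

val gbt = new GBTClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")            // added in this change; unset or empty => all weights 1.0
  .setMinWeightFractionPerNode(0.05) // also added here: minimum share of total weight per child
  .setMaxIter(10)

val model = gbt.fit(train)

Leaving `weightCol` unset (or setting it to an empty string) reproduces the previous unweighted behavior, per the doc comment added above.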