Skip to content

Commit 94e0250

Browse files
committed
add 1 un
1 parent c8228fb commit 94e0250

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import org.apache.spark.ml.feature.Instance
2222
import org.apache.spark.ml.linalg._
2323

2424
/**
25-
* LinearSVCAggregator computes the gradient and loss for loss function ("hinge" or
25+
* HingeAggregator computes the gradient and loss for loss function ("hinge" or
2626
* "squared_hinge", as used in binary classification for instances in sparse or dense
2727
* vector in an online fashion.
2828
*
29-
* Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
29+
* Two HingeAggregators can be merged together to have a summary of loss and gradient of
3030
* the corresponding joint dataset.
3131
*
3232
* This class standardizes feature values during computation using bcFeaturesStd.
@@ -50,11 +50,11 @@ private[ml] class HingeAggregator(
5050
protected override val dim: Int = numFeaturesPlusIntercept
5151

5252
/**
53-
* Add a new training instance to this LinearSVCAggregator, and update the loss and gradient
53+
* Add a new training instance to this HingeAggregator, and update the loss and gradient
5454
* of the objective function.
5555
*
5656
* @param instance The instance of data point to be added.
57-
* @return This LinearSVCAggregator object.
57+
* @return This HingeAggregator object.
5858
*/
5959
def add(instance: Instance): this.type = {
6060
instance match { case Instance(label, weight, features) =>

mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
6767
val interceptArray = Array(2.0)
6868
val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
6969
fitIntercept = true)
70-
withClue("LogisticAggregator does not support negative instance weights") {
70+
withClue("HingeAggregator does not support negative instance weights") {
7171
intercept[IllegalArgumentException] {
7272
agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0)))
7373
}
@@ -133,4 +133,18 @@ class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
133133
assert(gradient ~== agg.gradient relTol 0.01)
134134
}
135135

136+
test("check with zero standard deviation") {
137+
val instancesConstantFeature = Array(
138+
Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
139+
Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
140+
Instance(1.0, 0.3, Vectors.dense(1.0, 0.5)))
141+
val binaryCoefArray = Array(1.0, 2.0)
142+
val intercept = 1.0
143+
val aggConstantFeatureBinary = getNewAggregator(instancesConstantFeature,
144+
Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true)
145+
instances.foreach(aggConstantFeatureBinary.add)
146+
// constant features should not affect gradient
147+
assert(aggConstantFeatureBinary.gradient(0) === 0.0)
148+
}
149+
136150
}

0 commit comments

Comments
 (0)