Skip to content

Commit c3ba4d4

Browse files
committed
[MLlib] Minor: UDF style update.
Author: Reynold Xin <rxin@databricks.com> Closes apache#4388 from rxin/mllib-style and squashes the following commits: 61d465b [Reynold Xin] oops 3364295 [Reynold Xin] Missed one .. 5e068e3 [Reynold Xin] [MLlib] Minor: UDF style update.
1 parent 7d789e1 commit c3ba4d4

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,14 @@ class LogisticRegressionModel private[ml] (
132132
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
133133
transformSchema(dataset.schema, paramMap, logging = true)
134134
val map = this.paramMap ++ paramMap
135-
val scoreFunction = udf((v: Vector) => {
135+
val scoreFunction = udf { v: Vector =>
136136
val margin = BLAS.dot(v, weights)
137137
1.0 / (1.0 + math.exp(-margin))
138-
} : Double)
138+
}
139139
val t = map(threshold)
140-
val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double)
140+
val predictFunction = udf { score: Double =>
141+
if (score > t) 1.0 else 0.0
142+
}
141143
dataset
142144
.select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
143145
.select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,13 @@ class ALSModel private[ml] (
129129

130130
// Register a UDF for DataFrame, and then
131131
// create a new column named map(predictionCol) by running the predict UDF.
132-
val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => {
132+
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
133133
if (userFeatures != null && itemFeatures != null) {
134134
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
135135
} else {
136136
Float.NaN
137137
}
138-
} : Float)
138+
}
139139
dataset
140140
.join(users, dataset(map(userCol)) === users("id"), "left")
141141
.join(items, dataset(map(itemCol)) === items("id"), "left")

0 commit comments

Comments
 (0)