Skip to content

Commit a9483ef

Browse files
committed
add PCA transform
1 parent da43731 commit a9483ef

File tree

1 file changed

+15
-2
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/feature

1 file changed

+15
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -150,13 +150,26 @@ class PCAModel private[ml] (
150150
OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix],
151151
OldVectors.fromML(explainedVariance).asInstanceOf[OldDenseVector])
152152

153-
// TODO: Make the transformer natively in ml framework to avoid extra conversion.
154-
val transformer: Vector => Vector = v => pcaModel.transform(OldVectors.fromML(v)).asML
153+
val transformer: Vector => Vector = v => transform(pcaModel.pc.asML, v)
155154

156155
val pcaOp = udf(transformer)
157156
dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
158157
}
159158

159+
/**
 * Projects a single feature vector onto the principal components.
 *
 * @param pc     principal-components matrix; its transpose is multiplied by `vector`,
 *               so `pc.numRows` is expected to equal `vector.size`
 * @param vector input feature vector, dense or sparse
 * @return the projection `pc.transpose * vector` as a dense vector
 */
private def transform(pc: DenseMatrix, vector: Vector): Vector = {
  // Matrix.multiply(Vector) handles both dense and sparse inputs, so there is no need to
  // build and transpose a temporary one-column SparseMatrix for every sparse row.
  pc.transpose.multiply(vector)
}
172+
160173
@Since("1.5.0")
161174
override def transformSchema(schema: StructType): StructType = {
162175
validateAndTransformSchema(schema)

0 commit comments

Comments
 (0)