[SPARK-19291][SPARKR][ML] spark.gaussianMixture supports output log-likelihood.

## What changes were proposed in this pull request?
```spark.gaussianMixture``` supports outputting the total log-likelihood of the fitted model, like R's ```mvnormalmixEM```.
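
A minimal SparkR sketch of the intended usage (data and column names are hypothetical; `loglik` is the field added by this change):

```r
library(SparkR)
sparkR.session()

# Hypothetical two-column numeric data; any numeric SparkDataFrame works.
df <- createDataFrame(data.frame(x = rnorm(50), y = rnorm(50)))

model <- spark.gaussianMixture(df, ~ x + y, k = 2)
stats <- summary(model)
stats$loglik  # total log-likelihood of the fitted model, analogous to mvnormalmixEM's $loglik
```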

## How was this patch tested?
R unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#16646 from yanboliang/spark-19291.
yanboliang committed Jan 22, 2017
1 parent 3dcad9f commit 0c589e3
Showing 3 changed files with 19 additions and 5 deletions.
5 changes: 3 additions & 2 deletions R/pkg/R/mllib_clustering.R
@@ -98,7 +98,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula =
#' @param object a fitted gaussian mixture model.
#' @return \code{summary} returns summary of the fitted model, which is a list.
#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu),
#' \code{sigma} (sigma), and \code{posterior} (posterior).
#' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior).
#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
#' @rdname spark.gaussianMixture
#' @export
@@ -112,6 +112,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
sigmaList <- callJMethod(jobj, "sigma")
k <- callJMethod(jobj, "k")
dim <- callJMethod(jobj, "dim")
loglik <- callJMethod(jobj, "logLikelihood")
mu <- c()
for (i in 1 : k) {
start <- (i - 1) * dim + 1
@@ -129,7 +130,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
} else {
dataFrame(callJMethod(jobj, "posterior"))
}
list(lambda = lambda, mu = mu, sigma = sigma,
list(lambda = lambda, mu = mu, sigma = sigma, loglik = loglik,
posterior = posterior, is.loaded = is.loaded)
})

7 changes: 7 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -56,6 +56,10 @@ test_that("spark.gaussianMixture", {
# [,1] [,2]
# [1,] 0.2961543 0.160783
# [2,] 0.1607830 1.008878
#
#' model$loglik
#
# [1] -46.89499
# nolint end
data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
@@ -72,9 +76,11 @@ test_that("spark.gaussianMixture", {
rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
0.2961543, 0.160783, 0.1607830, 1.008878)
rLoglik <- -46.89499
expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
expect_equal(unlist(stats$loglik), rLoglik, tolerance = 1e-3)
p <- collect(select(predict(model, df), "prediction"))
expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))

@@ -88,6 +94,7 @@ test_that("spark.gaussianMixture", {
expect_equal(stats$lambda, stats2$lambda)
expect_equal(unlist(stats$mu), unlist(stats2$mu))
expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
expect_equal(unlist(stats$loglik), unlist(stats2$loglik))

unlink(modelPath)
})
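
For context, the commented reference values in this test (rLambda, rMu, rSigma, rLoglik) come from fitting the same points with R's mixtools; a rough sketch of producing such reference values (assumes the mixtools package; data here are hypothetical and results vary with EM initialization):

```r
# Sketch only: mixtools' mvnormalmixEM reports the total log-likelihood as $loglik,
# which is what stats$loglik is compared against above.
library(mixtools)

set.seed(1)
x <- rbind(matrix(rnorm(30, mean = 0),  ncol = 2),   # hypothetical cluster near 0
           matrix(rnorm(30, mean = 10), ncol = 2))   # hypothetical cluster near 10

fit <- mvnormalmixEM(x, k = 2)
fit$lambda  # mixing weights   (cf. rLambda)
fit$mu      # component means  (cf. rMu)
fit$loglik  # log-likelihood   (cf. rLoglik)
```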
12 changes: 9 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._
private[r] class GaussianMixtureWrapper private (
val pipeline: PipelineModel,
val dim: Int,
val logLikelihood: Double,
val isLoaded: Boolean = false) extends MLWritable {

private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
@@ -91,7 +92,10 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
.setStages(Array(rFormulaModel, gm))
.fit(data)

new GaussianMixtureWrapper(pipeline, dim)
val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
val logLikelihood: Double = gmm.summary.logLikelihood

new GaussianMixtureWrapper(pipeline, dim, logLikelihood)
}

override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
@@ -105,7 +109,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
val pipelinePath = new Path(path, "pipeline").toString

val rMetadata = ("class" -> instance.getClass.getName) ~
("dim" -> instance.dim)
("dim" -> instance.dim) ~
("logLikelihood" -> instance.logLikelihood)
val rMetadataJson: String = compact(render(rMetadata))

sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
@@ -124,7 +129,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val dim = (rMetadata \ "dim").extract[Int]
new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
new GaussianMixtureWrapper(pipeline, dim, logLikelihood, isLoaded = true)
}
}
}
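
Because the wrapper now writes `logLikelihood` into its `rMetadata` and reads it back, the value survives model persistence; a minimal SparkR sketch of that round trip (data, column names, and path are hypothetical):

```r
# Sketch only: loglik survives write.ml()/read.ml() because the wrapper
# persists logLikelihood in its JSON metadata (rMetadata) alongside dim.
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(x = rnorm(50), y = rnorm(50)))  # hypothetical data
model <- spark.gaussianMixture(df, ~ x + y, k = 2)

modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
write.ml(model, modelPath)
model2 <- read.ml(modelPath)

stopifnot(isTRUE(all.equal(summary(model)$loglik, summary(model2)$loglik)))
unlink(modelPath, recursive = TRUE)
```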
