@@ -55,9 +55,26 @@ class GaussianMixtureModelEM private (
55
55
// number of samples per cluster to use when initializing Gaussians
56
56
private val nSamples = 5
57
57
58
+ // an initializing GMM can be provided rather than using the
59
+ // default random starting point
60
+ private var initialGmm : Option [GaussianMixtureModel ] = None
61
+
58
62
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
59
63
def this () = this (2 , 0.01 , 100 )
60
64
65
+  /**
+   * Set the initial GMM starting point, bypassing the random initialization.
+   *
+   * The supplied model is accepted only if its cluster count equals this
+   * instance's current `k`; the comparison is made once, when this setter
+   * is called.
+   *
+   * @param gmm model to use as the starting point instead of a random one
+   * @return this instance, so that setter calls can be chained
+   * @throws IllegalArgumentException if `gmm.k` differs from `k`
+   */
+  def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
+    // `require` raises IllegalArgumentException (same exception type as the
+    // explicit throw it replaces); the message reports both counts.
+    require(gmm.k == k,
+      s"Initial GMM has mismatched cluster count (gmm.k ${gmm.k} != k $k)")
+    initialGmm = Some(gmm)
+    this
+  }
74
+
75
+  /**
+   * Return the user-supplied initial GMM, or `None` if none has been set.
+   */
+  def getInitialGmm: Option[GaussianMixtureModel] = initialGmm
+
+  /**
+   * Misspelled alias of `getInitialGmm`, kept so that code written against
+   * the original name continues to compile. Prefer `getInitialGmm`.
+   */
+  def getInitialiGmm: Option[GaussianMixtureModel] = getInitialGmm
77
+
61
78
/** Set the number of Gaussians in the mixture model. Default: 2 */
62
79
def setK (k : Int ): this .type = {
63
80
this .k = k
@@ -103,20 +120,35 @@ class GaussianMixtureModelEM private (
103
120
// Get length of the input vectors
104
121
val d = breezeData.first.length
105
122
106
- // For each Gaussian, we will initialize the mean as the average
107
- // of some random samples from the data
108
- val samples = breezeData.takeSample(true , k * nSamples, scala.util.Random .nextInt)
109
-
110
- // gaussians will be array of (weight, mean, covariance) tuples
123
+ // gaussians will be array of (weight, mean, covariance) tuples.
124
+ // If the user supplied an initial GMM, we use those values, otherwise
111
125
// we start with uniform weights, a random mean from the data, and
112
126
// diagonal covariance matrices using component variances
113
127
// derived from the samples
114
- var gaussians = (0 until k).map{ i =>
128
+ var gaussians = initialGmm match {
129
+ case Some (gmm) => (0 until k).map{ i =>
130
+ (gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
131
+ }.toArray
132
+
133
+ case None => {
134
+ // For each Gaussian, we will initialize the mean as the average
135
+ // of some random samples from the data
136
+ val samples = breezeData.takeSample(true , k * nSamples, scala.util.Random .nextInt)
137
+
138
+ (0 until k).map{ i =>
139
+ (1.0 / k,
140
+ vectorMean(samples.slice(i * nSamples, (i + 1 ) * nSamples)),
141
+ initCovariance(samples.slice(i * nSamples, (i + 1 ) * nSamples)))
142
+ }.toArray
143
+ }
144
+ }
145
+
146
+ /* var gaussians = (0 until k).map{ i =>
115
147
(1.0 / k,
116
148
vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
117
149
initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
118
150
}.toArray
119
-
151
+ */
120
152
val accW = new Array [Accumulator [Double ]](k)
121
153
val accMu = new Array [Accumulator [DenseDoubleVector ]](k)
122
154
val accSigma = new Array [Accumulator [DenseDoubleMatrix ]](k)
0 commit comments