@@ -28,9 +28,10 @@
 from pyspark import RDD
 from pyspark import SparkContext
 from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
 from pyspark.mllib.stat.distribution import MultivariateGaussian
 from pyspark.mllib.util import Saveable, Loader, inherit_doc
+from pyspark.streaming import DStream
 
 __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
 
@@ -269,14 +270,46 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
 class StreamingKMeansModel(KMeansModel):
     """
     .. note:: Experimental
+
+    >>> initCenters, initWeights = [[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]
+    >>> stkm = StreamingKMeansModel(initCenters, initWeights)
+    >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
+    ...                        [0.9, 0.9], [1.1, 1.1]])
+    >>> stkm = stkm.update(data, 1.0, "batches")
+    >>> stkm.centers
+    array([[ 0.,  0.],
+           [ 1.,  1.]])
+    >>> stkm.predict([-0.1, -0.1]) == stkm.predict([0.1, 0.1]) == 0
+    True
+    >>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
+    True
+    >>> stkm.getClusterWeights
+    [3.0, 3.0]
+    >>> decayFactor = 0.0
+    >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
+    >>> stkm = stkm.update(data, decayFactor, "batches")
+    >>> stkm.centers
+    array([[ 0.2,  0.2],
+           [ 1.5,  1.5]])
+    >>> stkm.getClusterWeights
+    [1.0, 1.0]
+    >>> stkm.predict([0.2, 0.2])
+    0
+    >>> stkm.predict([1.5, 1.5])
+    1
     """
     def __init__(self, clusterCenters, clusterWeights):
         super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
         self._clusterWeights = list(clusterWeights)
 
+    @property
+    def getClusterWeights(self):
+        return self._clusterWeights
+
     def update(self, data, decayFactor, timeUnit):
         if not isinstance(data, RDD):
-            raise TypeError("data should be of a RDD, got %s." % type(data))
+            raise TypeError("data should be an RDD, got %s." % type(data))
+        data = data.map(_convert_to_vector)
         decayFactor = float(decayFactor)
         if timeUnit not in ["batches", "points"]:
             raise ValueError(
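The heavy lifting in `update` happens in the Scala backend, which applies MLlib's "forgetful" rule: each center is pulled toward the mean of the new points assigned to it, while the old weight is discounted by the decay factor (per batch here; with "points" the discount is applied per point instead). A minimal NumPy sketch of that rule — the function name and pure-Python form are illustrative only, not part of this patch:

```python
import numpy as np

def forgetful_update(centers, weights, batch, decay_factor):
    """One streaming k-means step: assign each point to its nearest
    center, then blend each center with the mean of its new points,
    after discounting the old weight by decay_factor."""
    centers = np.asarray(centers, dtype=float).copy()
    weights = np.asarray(weights, dtype=float) * decay_factor
    batch = np.asarray(batch, dtype=float)
    # Nearest-center assignment for every point in the batch.
    labels = np.linalg.norm(
        batch[:, None, :] - centers[None, :, :], axis=2).argmin(axis=1)
    for j in range(len(centers)):
        points = batch[labels == j]
        if len(points) == 0:
            continue
        new_weight = weights[j] + len(points)
        lam = len(points) / new_weight   # share given to the new data
        centers[j] = (1.0 - lam) * centers[j] + lam * points.mean(axis=0)
        weights[j] = new_weight
    return centers, weights

# Mirrors the doctest: decay 1.0 keeps all history, so weights grow to 3.0.
c, w = forgetful_update([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0],
                        [[-0.1, -0.1], [0.1, 0.1], [0.9, 0.9], [1.1, 1.1]], 1.0)
print(w)  # [3. 3.]
```

With decayFactor=1.0 all history is kept; with 0.0 the centers jump to the latest batch, which is exactly what the second half of the doctest exercises.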
@@ -306,7 +339,7 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
     def _validate(self, dstream):
         if self.model is None:
             raise ValueError(
-                "Initial centers should be set either by setInitialCenters ")
+                "Initial centers should be set either by setInitialCenters "
                 "or setRandomCenters.")
         if not isinstance(dstream, DStream):
             raise TypeError(
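The fix above relies on Python's implicit concatenation of adjacent string literals: the stray closing parenthesis in the removed line ended the `ValueError` call after the first fragment and left the second line syntactically orphaned. A quick illustration:

```python
# Adjacent string literals are joined at compile time into one constant,
# so the two-line message reads as a single sentence.
msg = ("Initial centers should be set either by setInitialCenters "
       "or setRandomCenters.")
print(msg)
# Initial centers should be set either by setInitialCenters or setRandomCenters.
```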
@@ -342,18 +375,18 @@ def trainOn(self, dstream):
 
         def update(_, rdd):
             if rdd:
-                self.model = self.model.update(rdd)
+                self.model = self.model.update(rdd, self._decayFactor, self._timeUnit)
 
         dstream.foreachRDD(update)
         return self
 
     def predictOn(self, dstream):
         self._validate(dstream)
-        dstream.map(model.predict)
+        return dstream.map(self.model.predict)
 
     def predictOnValues(self, dstream):
         self._validate(dstream)
-        dstream.mapValues(model.predict)
+        return dstream.mapValues(self.model.predict)
 
 
 def _test():
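Putting the pieces together, here is a sketch of how the streaming estimator could be driven end to end. The `StreamingContext` setup, the queueStream-fed batches, and the `setInitialCenters` signature are assumptions for illustration and may differ from this patch:

```python
# Hypothetical end-to-end run; batches are fed through queueStream so the
# example is self-contained rather than reading from a live source.
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.clustering import StreamingKMeans

sc = SparkContext(appName="StreamingKMeansSketch")
ssc = StreamingContext(sc, batchDuration=1)

# Two initial centers with unit weights, mirroring the doctest above.
model = StreamingKMeans(k=2, decayFactor=1.0, timeUnit="batches")
model.setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0])

training = ssc.queueStream(
    [sc.parallelize([[-0.1, -0.1], [0.1, 0.1], [0.9, 0.9], [1.1, 1.1]])])
test = ssc.queueStream([sc.parallelize([[0.2, 0.2], [1.5, 1.5]])])

model.trainOn(training)          # update the model as each batch arrives
model.predictOn(test).pprint()   # print the predicted cluster indices

ssc.start()
ssc.awaitTermination(timeout=6)  # short run, enough for the queued batches
ssc.stop(stopSparkContext=True)
```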