Skip to content

Commit 4b1481f

Browse files
committed
Some changes and tests
1 parent d8b066a commit 4b1481f

File tree

2 files changed

+78
-5
lines changed

2 files changed

+78
-5
lines changed

python/pyspark/mllib/clustering.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
from pyspark import RDD
2929
from pyspark import SparkContext
3030
from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
31-
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
31+
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
3232
from pyspark.mllib.stat.distribution import MultivariateGaussian
3333
from pyspark.mllib.util import Saveable, Loader, inherit_doc
34+
from pyspark.streaming import DStream
3435

3536
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
3637

@@ -269,14 +270,46 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
269270
class StreamingKMeansModel(KMeansModel):
270271
"""
271272
.. note:: Experimental
273+
274+
>>> initCenters, initWeights = [[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]
275+
>>> stkm = StreamingKMeansModel(initCenters, initWeights)
276+
>>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
277+
... [0.9, 0.9], [1.1, 1.1]])
278+
>>> stkm = stkm.update(data, 1.0, "batches")
279+
>>> stkm.centers
280+
array([[ 0., 0.],
281+
[ 1., 1.]])
282+
>>> stkm.predict([-0.1, -0.1]) == stkm.predict([0.1, 0.1]) == 0
283+
True
284+
>>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
285+
True
286+
>>> stkm.getClusterWeights
287+
[3.0, 3.0]
288+
>>> decayFactor = 0.0
289+
>>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
290+
>>> stkm = stkm.update(data, 0.0, "batches")
291+
>>> stkm.centers
292+
array([[ 0.2, 0.2],
293+
[ 1.5, 1.5]])
294+
>>> stkm.getClusterWeights
295+
[1.0, 1.0]
296+
>>> stkm.predict([0.2, 0.2])
297+
0
298+
>>> stkm.predict([1.5, 1.5])
299+
1
272300
"""
273301
def __init__(self, clusterCenters, clusterWeights):
    """
    Create the model from initial cluster centers and their weights.

    :param clusterCenters: passed to the parent constructor as ``centers``.
    :param clusterWeights: iterable of weights; copied into a new list.
    """
    super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
    # Keep a private copy so the stored weights are always a plain list.
    weights = list(clusterWeights)
    self._clusterWeights = weights
276304

305+
@property
def getClusterWeights(self):
    """
    Weights currently stored for the clusters.

    :return: the list held in ``self._clusterWeights`` (not a copy).
    """
    weights = self._clusterWeights
    return weights
308+
277309
def update(self, data, decayFactor, timeUnit):
278310
if not isinstance(data, RDD):
279311
raise TypeError("data should be of a RDD, got %s." % type(data))
312+
data = data.map(_convert_to_vector)
280313
decayFactor = float(decayFactor)
281314
if timeUnit not in ["batches", "points"]:
282315
raise ValueError(
@@ -306,7 +339,7 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
306339
def _validate(self, dstream):
307340
if self.model is None:
308341
raise ValueError(
309-
"Initial centers should be set either by setInitialCenters ")
342+
"Initial centers should be set either by setInitialCenters "
310343
"or setRandomCenters.")
311344
if not isinstance(dstream, DStream):
312345
raise TypeError(
@@ -342,18 +375,18 @@ def trainOn(self, dstream):
342375

343376
def update(_, rdd):
344377
if rdd:
345-
self.model = self.model.update(rdd)
378+
self.model = self.model.update(rdd, self._decayFactor, self._timeUnit)
346379

347380
dstream.foreachRDD(update)
348381
return self
349382

350383
def predictOn(self, dstream):
    """
    Apply the current model's ``predict`` to every element of a stream.

    :param dstream: stream of points; checked by ``self._validate`` first,
                    which raises if the model or the stream is unusable.
    :return: the stream produced by ``dstream.map(self.model.predict)``.
    """
    self._validate(dstream)
    # The mapped stream must be returned: previously it was built and
    # discarded, so callers always received None.
    return dstream.map(self.model.predict)
353386

354387
def predictOnValues(self, dstream):
    """
    Apply the current model's ``predict`` to the values of a keyed stream.

    :param dstream: stream whose values are points; checked by
                    ``self._validate`` first, which raises if the model or
                    the stream is unusable.
    :return: the stream produced by
             ``dstream.mapValues(self.model.predict)``.
    """
    self._validate(dstream)
    # The mapped stream must be returned: previously it was built and
    # discarded, so callers always received None.
    return dstream.mapValues(self.model.predict)
357390

358391

359392
def _test():

python/pyspark/mllib/tests.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from pyspark import SparkContext
4040
from pyspark.mllib.common import _to_java_object_rdd
41+
from pyspark.mllib.clustering import StreamingKMeans
4142
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
4243
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
4344
from pyspark.mllib.regression import LabeledPoint
@@ -48,6 +49,7 @@
4849
from pyspark.mllib.feature import StandardScaler
4950
from pyspark.mllib.feature import ElementwiseProduct
5051
from pyspark.serializers import PickleSerializer
52+
from pyspark.streaming import StreamingContext
5153
from pyspark.sql import SQLContext
5254

5355
_have_scipy = False
@@ -863,6 +865,44 @@ def test_model_transform(self):
863865
eprod.transform(sparsevec), SparseVector(3, [0], [3]))
864866

865867

868+
class StreamingKMeansTest(MLlibTestCase):
    """Tests for StreamingKMeans parameter setters and training on a queue stream."""

    def test_model_params(self):
        # Setters are chainable and store their values on the instance.
        stkm = StreamingKMeans()
        stkm.setK(5).setDecayFactor(0.0)
        self.assertEqual(stkm._k, 5)
        self.assertEqual(stkm._decayFactor, 0.0)

        # Model not set yet.
        self.assertIsNone(stkm.model)
        self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])

        stkm.setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0])
        self.assertEqual(stkm.model.centers, [[0.0, 0.0], [1.0, 1.0]])
        self.assertEqual(stkm.model.getClusterWeights, [1.0, 1.0])

    def test_model(self):
        stkm = StreamingKMeans()
        initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
        weights = [1.0, 1.0, 1.0, 1.0]
        stkm.setInitialCenters(initCenters, weights)

        # One batch per offset: every initial center shifted by that offset.
        offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
        batches = [
            [[dx + cx, dy + cy] for cx, cy in initCenters]
            for dx, dy in offsets
        ]

        batches = [self.sc.parallelize(batch, 1) for batch in batches]
        ssc = StreamingContext(self.sc, 2.0)
        input_stream = ssc.queueStream(batches)
        stkm.trainOn(input_stream)
        ssc.start()
        try:
            # NOTE(review): the model is read right after start(), without
            # waiting for any batch to be processed -- confirm whether the
            # test should wait for the queued batches first.
            finalModel = stkm.model
            self.assertEqual(finalModel.centers, initCenters)
            # TODO: once the test waits for all batches, also check
            # finalModel.getClusterWeights == [5.0, 5.0, 5.0, 5.0].
        finally:
            # Stop only the streaming context so it does not leak into other
            # tests; self.sc is presumably owned by MLlibTestCase -- verify.
            ssc.stop(stopSparkContext=False)
904+
905+
866906
if __name__ == "__main__":
867907
if not _have_scipy:
868908
print("NOTE: Skipping SciPy tests as it does not seem to be installed")

0 commit comments

Comments
 (0)