Skip to content

Commit d8b066a

Browse files
committed
[SPARK-4118] [MLlib] [PySpark] Python bindings for StreamingKMeans
1 parent 3b61077 commit d8b066a

File tree

2 files changed: +108, -1 lines

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,21 @@ private[python] class PythonMLLibAPI extends Serializable {
964964
points.asScala.toArray)
965965
}
966966

967+
/**
968+
* Java stub for the update method of StreamingKMeansModel.
969+
*/
970+
def updateStreamingKMeansModel(
971+
clusterCenters: java.util.ArrayList[Vector],
972+
clusterWeights: java.util.ArrayList[Double],
973+
data: JavaRDD[Vector], decayFactor: Double,
974+
timeUnit: String) : JList[Object] = {
975+
val model = new StreamingKMeansModel(
976+
clusterCenters.asScala.toArray, clusterWeights.asScala.toArray)
977+
.update(data, decayFactor, timeUnit)
978+
List(model.clusterCenters, model.clusterWeights).
979+
map(_.asInstanceOf[Object]).asJava
980+
}
981+
967982
}
968983

969984
/**

python/pyspark/mllib/clustering.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
if sys.version > '3':
2222
xrange = range
2323

24-
from numpy import array
24+
from math import exp, log
25+
26+
from numpy import array, random, tile
2527

2628
from pyspark import RDD
2729
from pyspark import SparkContext
@@ -264,6 +266,96 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
264266
return GaussianMixtureModel(weight, mvg_obj)
265267

266268

class StreamingKMeansModel(KMeansModel):
    """Clustering model that can be updated incrementally with new data.

    .. note:: Experimental

    Extends :class:`KMeansModel` with per-cluster weights and an
    ``update`` method that folds a batch of points into the existing
    centers via the Scala ``updateStreamingKMeansModel`` stub.

    :param clusterCenters: Initial cluster centers.
    :param clusterWeights: Initial weight assigned to each cluster
                           (one entry per center).
    """
    def __init__(self, clusterCenters, clusterWeights):
        super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
        # Copy so later updates cannot mutate the caller's list.
        self._clusterWeights = list(clusterWeights)

    @property
    def clusterWeights(self):
        """Return the current cluster weights (read-only view)."""
        return self._clusterWeights

    def update(self, data, decayFactor, timeUnit):
        """Update the centers in place from a batch of new points.

        :param data: RDD of points convertible to MLlib Vectors.
        :param decayFactor: Forgetfulness of the previous centroids
                            (coerced to float before the JVM call).
        :param timeUnit: Either "batches" or "points"; how the decay
                         factor is interpreted.
        :return: self, with refreshed centers and weights.
        :raise TypeError: if ``data`` is not an RDD.
        :raise ValueError: if ``timeUnit`` is not a recognised unit.
        """
        if not isinstance(data, RDD):
            raise TypeError("data should be of a RDD, got %s." % type(data))
        decayFactor = float(decayFactor)
        if timeUnit not in ["batches", "points"]:
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
        vectorCenters = [_convert_to_vector(center) for center in self.centers]
        updatedModel = callMLlibFunc(
            "updateStreamingKMeansModel", vectorCenters, self._clusterWeights,
            data, decayFactor, timeUnit)
        # The JVM returns [centers, weights]; mirror both back locally.
        self.centers = array(updatedModel[0])
        self._clusterWeights = list(updatedModel[1])
        return self
class StreamingKMeans(object):
    """Driver that trains a streaming k-means model on a DStream.

    .. note:: Experimental

    :param k: Number of clusters.
    :param decayFactor: Forgetfulness of the previous centroids.
    :param timeUnit: Either "batches" or "points"; unit over which the
                     decay factor is applied.
    :raise ValueError: if ``timeUnit`` is not a recognised unit.
    """
    def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
        self._k = k
        self._decayFactor = decayFactor
        if timeUnit not in ["batches", "points"]:
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
        self._timeUnit = timeUnit
        self.model = None

    def _validate(self, dstream):
        """Ensure a model has been initialized and the input is a DStream."""
        if self.model is None:
            # Fixed: the message was split across two statements, leaving
            # an orphaned string literal and an unbalanced parenthesis.
            raise ValueError(
                "Initial centers should be set either by setInitialCenters "
                "or setRandomCenters.")
        if not isinstance(dstream, DStream):
            # Fixed: %d cannot format a type object; use %s.
            raise TypeError(
                "Expected dstream to be of type DStream, "
                "got type %s" % type(dstream))

    def setK(self, k):
        """Set the number of clusters. Returns self for chaining."""
        self._k = k
        return self

    def setDecayFactor(self, decayFactor):
        """Set the decay factor. Returns self for chaining."""
        self._decayFactor = decayFactor
        return self

    def setHalfLife(self, halfLife, timeUnit):
        """Set decay so a point's contribution halves after ``halfLife``
        units of ``timeUnit``. Returns self for chaining.

        :raise ValueError: if ``timeUnit`` is not a recognised unit.
        """
        # Validate here too, for consistency with __init__.
        if timeUnit not in ["batches", "points"]:
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
        self._timeUnit = timeUnit
        self._decayFactor = exp(log(0.5) / halfLife)
        return self

    def setInitialCenters(self, centers, weights):
        """Set explicit initial centers and weights. Returns self."""
        self.model = StreamingKMeansModel(centers, weights)
        return self

    def setRandomCenters(self, dim, weight, seed):
        """Initialize k random Gaussian centers of dimension ``dim``,
        each with initial weight ``weight``. Returns self."""
        rng = random.RandomState(seed)
        clusterCenters = rng.randn(self._k, dim)
        clusterWeights = tile(weight, self._k)
        self.model = StreamingKMeansModel(clusterCenters, clusterWeights)
        return self

    def trainOn(self, dstream):
        """Fold each RDD of ``dstream`` into the model. Returns self."""
        self._validate(dstream)

        def update(_, rdd):
            if rdd:
                # Fixed: update requires the configured decay settings;
                # the original call omitted them (TypeError at runtime).
                self.model = self.model.update(
                    rdd, self._decayFactor, self._timeUnit)

        dstream.foreachRDD(update)
        return self

    def predictOn(self, dstream):
        """Return a DStream of predicted cluster indices for ``dstream``."""
        self._validate(dstream)
        # Fixed: bare `model` was an undefined name, and the mapped
        # stream was dropped instead of returned.
        return dstream.map(self.model.predict)

    def predictOnValues(self, dstream):
        """Return a DStream of (key, predicted index) pairs."""
        self._validate(dstream)
        return dstream.mapValues(self.model.predict)
267359
def _test():
268360
import doctest
269361
globs = globals().copy()

0 commit comments

Comments (0)