|
21 | 21 | if sys.version > '3':
|
22 | 22 | xrange = range
|
23 | 23 |
|
24 |
| -from numpy import array |
| 24 | +from math import exp, log |
| 25 | + |
| 26 | +from numpy import array, random, tile |
25 | 27 |
|
26 | 28 | from pyspark import RDD
|
27 | 29 | from pyspark import SparkContext
|
@@ -264,6 +266,96 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
|
264 | 266 | return GaussianMixtureModel(weight, mvg_obj)
|
265 | 267 |
|
266 | 268 |
|
class StreamingKMeansModel(KMeansModel):
    """Clustering model that supports online updates of its centroids.

    Wraps a :class:`KMeansModel` and additionally tracks a per-cluster
    weight, so that new batches of data can be folded into the existing
    centers with an exponential forgetting factor.

    .. note:: Experimental
    """
    def __init__(self, clusterCenters, clusterWeights):
        super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
        # Keep a private, mutable copy of the per-cluster weights.
        self._clusterWeights = list(clusterWeights)

    def update(self, data, decayFactor, timeUnit):
        """Fold one batch of points into the current model.

        :param data: RDD of vectors to absorb into the model.
        :param decayFactor: forgetting factor applied to the old weights.
        :param timeUnit: either "batches" or "points", controlling how
            the decay is interpreted.
        :return: self, with centers and weights updated in place.
        """
        if not isinstance(data, RDD):
            raise TypeError("data should be of a RDD, got %s." % type(data))
        decayFactor = float(decayFactor)
        if timeUnit not in ["batches", "points"]:
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
        vectorCenters = [_convert_to_vector(c) for c in self.centers]
        # Delegate the actual center/weight update to the JVM implementation.
        updated = callMLlibFunc(
            "updateStreamingKMeansModel", vectorCenters, self._clusterWeights,
            data, decayFactor, timeUnit)
        self.centers = array(updated[0])
        self._clusterWeights = list(updated[1])
        return self
| 291 | + |
| 292 | + |
class StreamingKMeans(object):
    """Builder for streaming k-means: configures k, the decay factor and
    the time unit, then trains/predicts on DStreams via an underlying
    :class:`StreamingKMeansModel`.

    .. note:: Experimental

    :param k: number of clusters.
    :param decayFactor: forgetting factor for old data (1.0 = no decay).
    :param timeUnit: "batches" or "points" — how decay is applied.
    """
    def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
        self._k = k
        self._decayFactor = decayFactor
        if timeUnit not in ["batches", "points"]:
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
        self._timeUnit = timeUnit
        self.model = None

    def _validate(self, dstream):
        """Raise unless initial centers are set and *dstream* is a DStream."""
        if self.model is None:
            # BUG FIX: the continuation string was outside the raise(...)
            # call (stray closing paren), which was a SyntaxError.
            raise ValueError(
                "Initial centers should be set either by setInitialCenters "
                "or setRandomCenters.")
        if not isinstance(dstream, DStream):
            # BUG FIX: %d cannot format a type object; use %s.
            raise TypeError(
                "Expected dstream to be of type DStream, "
                "got type %s" % type(dstream))

    def setK(self, k):
        """Set the number of clusters. Returns self for chaining."""
        self._k = k
        return self

    def setDecayFactor(self, decayFactor):
        """Set the forgetting factor. Returns self for chaining."""
        self._decayFactor = decayFactor
        return self

    def setHalfLife(self, halfLife, timeUnit):
        """Set decay so data loses half its weight after *halfLife* units
        of *timeUnit*. Returns self for chaining."""
        self._timeUnit = timeUnit
        self._decayFactor = exp(log(0.5) / halfLife)
        return self

    def setInitialCenters(self, centers, weights):
        """Seed the model with explicit centers and weights."""
        self.model = StreamingKMeansModel(centers, weights)
        return self

    def setRandomCenters(self, dim, weight, seed):
        """Seed the model with k Gaussian-random centers of dimension *dim*,
        each with the same initial *weight*."""
        rng = random.RandomState(seed)
        clusterCenters = rng.randn(self._k, dim)
        clusterWeights = tile(weight, self._k)
        self.model = StreamingKMeansModel(clusterCenters, clusterWeights)
        return self

    def trainOn(self, dstream):
        """Update the model with every RDD that arrives on *dstream*."""
        self._validate(dstream)

        def update(_, rdd):
            if rdd:
                # BUG FIX: update() requires decayFactor and timeUnit;
                # calling it with only the RDD raised TypeError at runtime.
                self.model = self.model.update(
                    rdd, self._decayFactor, self._timeUnit)

        dstream.foreachRDD(update)
        return self

    def predictOn(self, dstream):
        """Return a DStream of cluster indices for the vectors in *dstream*."""
        self._validate(dstream)
        # BUG FIX: `model` was an undefined name and the mapped DStream was
        # discarded; use self.model and return the result.
        return dstream.map(lambda x: self.model.predict(x))

    def predictOnValues(self, dstream):
        """Return a DStream of (key, cluster index) pairs for a keyed
        DStream of vectors."""
        self._validate(dstream)
        # BUG FIX: same undefined `model` / missing return as predictOn.
        return dstream.mapValues(lambda x: self.model.predict(x))
| 357 | + |
| 358 | + |
267 | 359 | def _test():
|
268 | 360 | import doctest
|
269 | 361 | globs = globals().copy()
|
|
0 commit comments