|
41 | 41 | from pyspark.mllib.regression import LabeledPoint
|
42 | 42 | from pyspark.mllib.random import RandomRDDs
|
43 | 43 | from pyspark.mllib.stat import Statistics
|
| 44 | +from pyspark.mllib.feature import IDF |
44 | 45 | from pyspark.serializers import PickleSerializer
|
45 | 46 | from pyspark.sql import SQLContext
|
46 | 47 | from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
|
@@ -620,6 +621,19 @@ def test_right_number_of_results(self):
|
620 | 621 | self.assertEqual(len(chi), num_cols)
|
621 | 622 | self.assertIsNotNone(chi[1000])
|
622 | 623 |
|
| 624 | + |
| 625 | +class FeatureTest(PySparkTestCase): |
| 626 | + def test_idf_model(self): |
| 627 | + data = [ |
| 628 | + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), |
| 629 | + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), |
| 630 | + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), |
| 631 | + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) |
| 632 | + ] |
| 633 | + model = IDF().fit(self.sc.parallelize(data, 2)) |
| 634 | + idf = model.idf() |
| 635 | + self.assertEqual(len(idf), 11) |
| 636 | + |
623 | 637 | if __name__ == "__main__":
|
624 | 638 | if not _have_scipy:
|
625 | 639 | print "NOTE: Skipping SciPy tests as it does not seem to be installed"
|
|
0 commit comments