Skip to content

Commit 1dc522c

Browse files
committed
[SPARK-6598] Python API for IDFModel
1 parent 4bdfb7b commit 1dc522c

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

python/pyspark/mllib/feature.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,12 @@ def transform(self, x):
244244
x = _convert_to_vector(x)
245245
return JavaVectorTransformer.transform(self, x)
246246

247+
def idf(self):
248+
"""
249+
Returns the current IDF vector.
250+
"""
251+
return self.call('idf')
252+
247253

248254
class IDF(object):
249255
"""

python/pyspark/mllib/tests.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from pyspark.mllib.regression import LabeledPoint
4242
from pyspark.mllib.random import RandomRDDs
4343
from pyspark.mllib.stat import Statistics
44+
from pyspark.mllib.feature import IDF
4445
from pyspark.serializers import PickleSerializer
4546
from pyspark.sql import SQLContext
4647
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -620,6 +621,19 @@ def test_right_number_of_results(self):
620621
self.assertEqual(len(chi), num_cols)
621622
self.assertIsNotNone(chi[1000])
622623

624+
625+
class FeatureTest(PySparkTestCase):
626+
def test_idf_model(self):
627+
data = [
628+
Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
629+
Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
630+
Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
631+
Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
632+
]
633+
model = IDF().fit(self.sc.parallelize(data, 2))
634+
idf = model.idf()
635+
self.assertEqual(len(idf), 11)
636+
623637
if __name__ == "__main__":
624638
if not _have_scipy:
625639
print "NOTE: Skipping SciPy tests as it does not seem to be installed"

0 commit comments

Comments
 (0)