Commit ec9df6a

rnowling authored and mengxr committed
[SPARK-3614][MLLIB] Add minimumOccurence filtering to IDF
This PR for [SPARK-3614](https://issues.apache.org/jira/browse/SPARK-3614) adds functionality for filtering out terms which do not appear in at least a minimum number of documents. This is implemented using a minimumOccurence parameter (default 0). When a term's document frequency is below minimumOccurence, its IDF is set to 0, just as when its DF is 0. As a result, the TF-IDFs for these terms are 0, as if the terms were not present in the documents.

This PR makes the following changes:

* Add a minimumOccurence parameter to the IDF and DocumentFrequencyAggregator classes.
* Create a parameter-less constructor for IDF with a default minimumOccurence value of 0 to retain backwards-compatibility with the original IDF API.
* Set the IDFs to 0 for terms whose DFs are less than minimumOccurence.
* Add tests to the Spark IDFSuite and Java JavaTfIdfSuite test suites.
* Update the MLlib Feature Extraction programming guide to describe the new feature.

Author: RJ Nowling <rnowling@gmail.com>

Closes #2494 from rnowling/spark-3614-idf-filter and squashes the following commits:

0aa3c63 [RJ Nowling] Fix identation
e6523a8 [RJ Nowling] Remove unnecessary toDouble's from IDFSuite
bfa82ec [RJ Nowling] Add space after if
30d20b3 [RJ Nowling] Add spaces around equals signs
9013447 [RJ Nowling] Add space before division operator
79978fc [RJ Nowling] Remove unnecessary semi-colon
40fd70c [RJ Nowling] Change minimumOccurence to minDocFreq in code and docs
47850ab [RJ Nowling] Changed minimumOccurence to Int from Long
9fb4093 [RJ Nowling] Remove unnecessary lines from IDF class docs
1fc09d8 [RJ Nowling] Add backwards-compatible constructor to DocumentFrequencyAggregator
1801fd2 [RJ Nowling] Fix style errors in IDF.scala
6897252 [RJ Nowling] Preface minimumOccurence members with val to make them final and immutable
a200bab [RJ Nowling] Remove unnecessary else statement
4b974f5 [RJ Nowling] Remove accidentally-added import from testing
c0cc643 [RJ Nowling] Add minimumOccurence filtering to IDF
1 parent d16e161 commit ec9df6a
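In short, the patch forces a term's IDF to 0 whenever its document frequency falls below the threshold. A minimal standalone sketch of that rule (illustrative only, not the patch's code; object and method names are hypothetical — in the patch this logic lives inside DocumentFrequencyAggregator.idf()):

{% highlight scala %}
// Hypothetical standalone sketch of the filtering rule described above.
// m: total number of documents; df: per-term document frequencies.
object MinDocFreqIdfSketch {
  def idf(m: Long, df: Array[Long], minDocFreq: Int): Array[Double] =
    df.map { d =>
      // Below the threshold the IDF stays 0, so TF * IDF is also 0.
      if (d >= minDocFreq) math.log((m + 1.0) / (d + 1.0)) else 0.0
    }

  def main(args: Array[String]): Unit = {
    // 3 documents, term document frequencies (0, 3, 1, 2), threshold 2:
    // terms with df < 2 come out as 0.0.
    println(idf(3L, Array(0L, 3L, 1L, 2L), minDocFreq = 2).mkString(", "))
  }
}
{% endhighlight %}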

File tree

4 files changed: +103 −5 lines changed

docs/mllib-feature-extraction.md

Lines changed: 15 additions & 0 deletions
@@ -82,6 +82,21 @@ tf.cache()
 val idf = new IDF().fit(tf)
 val tfidf: RDD[Vector] = idf.transform(tf)
 {% endhighlight %}
+
+MLLib's IDF implementation provides an option for ignoring terms which occur in less than a
+minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
+can be used by passing the `minDocFreq` value to the IDF constructor.
+
+{% highlight scala %}
+import org.apache.spark.mllib.feature.IDF
+
+// ... continue from the previous example
+tf.cache()
+val idf = new IDF(minDocFreq = 2).fit(tf)
+val tfidf: RDD[Vector] = idf.transform(tf)
+{% endhighlight %}
+
+
 </div>
 </div>
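For context, the `tf` RDD in the guide's snippet comes from its preceding HashingTF example. A self-contained sketch of the full pipeline (the SparkContext `sc` and the "data.txt" path are illustrative assumptions, not part of this patch):

{% highlight scala %}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext `sc`; "data.txt" is a placeholder path
// containing one whitespace-tokenized document per line.
val documents: RDD[Seq[String]] = sc.textFile("data.txt").map(_.split(" ").toSeq)

val hashingTF = new HashingTF()
val tf: RDD[Vector] = hashingTF.transform(documents)
tf.cache()

// Terms appearing in fewer than 2 documents get an IDF of 0,
// and therefore a TF-IDF of 0.
val idf = new IDF(minDocFreq = 2).fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}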

mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala

Lines changed: 33 additions & 4 deletions
@@ -30,9 +30,18 @@ import org.apache.spark.rdd.RDD
  * Inverse document frequency (IDF).
  * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
  * number of documents and `d(t)` is the number of documents that contain term `t`.
+ *
+ * This implementation supports filtering out terms which do not appear in a minimum number
+ * of documents (controlled by the variable `minDocFreq`). For terms that are not in
+ * at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0.
+ *
+ * @param minDocFreq minimum of documents in which a term
+ *                   should appear for filtering
  */
 @Experimental
-class IDF {
+class IDF(val minDocFreq: Int) {
+
+  def this() = this(0)
 
   // TODO: Allow different IDF formulations.
 
@@ -41,7 +50,8 @@ class IDF {
    * @param dataset an RDD of term frequency vectors
    */
   def fit(dataset: RDD[Vector]): IDFModel = {
-    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
+    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(
+      minDocFreq = minDocFreq))(
       seqOp = (df, v) => df.add(v),
       combOp = (df1, df2) => df1.merge(df2)
     ).idf()
@@ -60,13 +70,16 @@ class IDF {
 private object IDF {
 
   /** Document frequency aggregator. */
-  class DocumentFrequencyAggregator extends Serializable {
+  class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {
 
     /** number of documents */
     private var m = 0L
     /** document frequency vector */
     private var df: BDV[Long] = _
 
+
+    def this() = this(0)
+
     /** Adds a new document. */
     def add(doc: Vector): this.type = {
       if (isEmpty) {
@@ -123,7 +136,18 @@ private object IDF {
       val inv = new Array[Double](n)
       var j = 0
       while (j < n) {
-        inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
+        /*
+         * If the term is not present in the minimum
+         * number of documents, set IDF to 0. This
+         * will cause multiplication in IDFModel to
+         * set TF-IDF to 0.
+         *
+         * Since arrays are initialized to 0 by default,
+         * we just omit changing those entries.
+         */
+        if(df(j) >= minDocFreq) {
+          inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
+        }
         j += 1
       }
       Vectors.dense(inv)
@@ -140,6 +164,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {
 
   /**
    * Transforms term frequency (TF) vectors to TF-IDF vectors.
+   *
+   * If `minDocFreq` was set for the IDF calculation,
+   * the terms which occur in fewer than `minDocFreq`
+   * documents will have an entry of 0.
+   *
    * @param dataset an RDD of term frequency vectors
    * @return an RDD of TF-IDF vectors
    */
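To make the aggregation step concrete: treeAggregate folds each TF vector into a DocumentFrequencyAggregator (seqOp) and merges partial aggregators across partitions (combOp). A local, non-distributed sketch of what the aggregation effectively computes (a hypothetical helper, not code from the patch):

{% highlight scala %}
// Hypothetical local equivalent of the document-frequency aggregation:
// m counts documents; df(j) counts documents where term j has a nonzero TF.
def documentFrequencies(docs: Seq[Array[Double]], n: Int): (Long, Array[Long]) = {
  val df = new Array[Long](n)
  var m = 0L
  docs.foreach { doc =>
    var j = 0
    while (j < n) {
      if (doc(j) > 0.0) df(j) += 1  // term j appears in this document
      j += 1
    }
    m += 1
  }
  (m, df)
}
{% endhighlight %}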

mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java

Lines changed: 20 additions & 0 deletions
@@ -63,4 +63,24 @@ public void tfIdf() {
       Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
     }
   }
+
+  @Test
+  public void tfIdfMinimumDocumentFrequency() {
+    // The tests are to check Java compatibility.
+    HashingTF tf = new HashingTF();
+    JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("this is a sentence".split(" ")),
+      Lists.newArrayList("this is another sentence".split(" ")),
+      Lists.newArrayList("this is still a sentence".split(" "))), 2);
+    JavaRDD<Vector> termFreqs = tf.transform(documents);
+    termFreqs.collect();
+    IDF idf = new IDF(2);
+    JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
+    List<Vector> localTfIdfs = tfIdfs.collect();
+    int indexOfThis = tf.indexOf("this");
+    for (Vector v: localTfIdfs) {
+      Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
+    }
+  }
+
 }
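A note on why the assertion holds: "this" occurs in all three documents, so d(t) = m = 3 and, per the formula in IDF.scala, idf = log((3 + 1) / (3 + 1)) = log 1 = 0. The TF-IDF entry for "this" is therefore 0 whether or not the df(j) >= minDocFreq branch is taken.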

mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala

Lines changed: 35 additions & 1 deletion
@@ -38,7 +38,7 @@ class IDFSuite extends FunSuite with LocalSparkContext {
     val idf = new IDF
     val model = idf.fit(termFrequencies)
     val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
-      math.log((m.toDouble + 1.0) / (x + 1.0))
+      math.log((m + 1.0) / (x + 1.0))
     })
     assert(model.idf ~== expected absTol 1e-12)
     val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
@@ -54,4 +54,38 @@ class IDFSuite extends FunSuite with LocalSparkContext {
     assert(tfidf2.indices === Array(1))
     assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
   }
+
+  test("idf minimum document frequency filtering") {
+    val n = 4
+    val localTermFrequencies = Seq(
+      Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)),
+      Vectors.dense(0.0, 1.0, 2.0, 3.0),
+      Vectors.sparse(n, Array(1), Array(1.0))
+    )
+    val m = localTermFrequencies.size
+    val termFrequencies = sc.parallelize(localTermFrequencies, 2)
+    val idf = new IDF(minDocFreq = 1)
+    val model = idf.fit(termFrequencies)
+    val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+      if (x > 0) {
+        math.log((m + 1.0) / (x + 1.0))
+      } else {
+        0
+      }
+    })
+    assert(model.idf ~== expected absTol 1e-12)
+    val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
+    assert(tfidf.size === 3)
+    val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
+    assert(tfidf0.indices === Array(1, 3))
+    assert(Vectors.dense(tfidf0.values) ~==
+      Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
+    val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
+    assert(Vectors.dense(tfidf1.values) ~==
+      Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
+    val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
+    assert(tfidf2.indices === Array(1))
+    assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+  }
+
 }
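Working through the expected values in this test: m = 3 and the per-index document frequencies are (0, 3, 1, 2). With minDocFreq = 1, index 0 (df = 0) is forced to 0 by the filter, while the remaining indices follow the standard formula: index 1 gives log(4/4) = 0, index 2 gives log(4/2) = log 2, and index 3 gives log(4/3).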
