
Commit 8d6befe

Author: Liquan Pei
Commit message: initial commit
1 parent c475540 commit 8d6befe

File tree: 2 files changed, +393 -0 lines changed
@@ -0,0 +1,353 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.feature

import scala.util._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._

/**
 * Entry in the vocabulary: a word, its count, and the Huffman code
 * (code, codeLen) and inner-node path (point) assigned by createBinaryTree.
 */
private case class VocabWord(
  var word: String,
  var cn: Int,
  var point: Array[Int],
  var code: Array[Int],
  var codeLen: Int
)

/**
 * Word2Vec trains skip-gram word vectors with hierarchical softmax.
 *
 * @param size dimensionality of the word vectors
 * @param startingAlpha initial learning rate
 * @param window context window size (words on each side of the center word)
 * @param minCount minimum frequency for a word to enter the vocabulary
 */
class Word2Vec(
    val size: Int,
    val startingAlpha: Double,
    val window: Int,
    val minCount: Int)
  extends Serializable with Logging {

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000
  private val layer1Size = size

  private var trainWordsCount = 0
  private var vocabSize = 0
  private var vocab: Array[VocabWord] = null
  private var vocabHash = mutable.HashMap.empty[String, Int]
  private var alpha = startingAlpha

  private def learnVocab(dataset: RDD[String]) {
    vocab = dataset.flatMap(line => line.split(" "))
      .map(w => (w, 1))
      .reduceByKey(_ + _)
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .filter(_.cn >= minCount)
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo("trainWordsCount = " + trainWordsCount)
  }
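
  // Illustrative example (not from the source): for an input RDD with lines
  // "a a b" and "a c" and minCount = 1, vocab is sorted by descending count,
  // i.e. VocabWord("a", 3, ...) comes first, vocabHash maps "a" -> 0, and
  // trainWordsCount = 5.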

  private def createExpTable(): Array[Double] = {
    val expTable = new Array[Double](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = tmp / (tmp + 1)
      i += 1
    }
    expTable
  }
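
  // The table above caches the logistic sigmoid 1 / (1 + e^(-x)) at
  // EXP_TABLE_SIZE evenly spaced points on (-MAX_EXP, MAX_EXP), letting the
  // training loop replace calls to math.exp with an array lookup. A sketch of
  // the lookup, mirroring the indexing used in fit() below for a dot product
  // f already known to lie in (-MAX_EXP, MAX_EXP):
  //
  //   val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
  //   val sigmoid = expTable(ind)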

  private def createBinaryTree() {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    var min1i = 0
    var min2i = 0

    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }
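
  // Note on the construction above: it relies on vocab being sorted by
  // descending count. Leaves occupy count(0 until vocabSize), merged inner
  // nodes are appended from index vocabSize upward, and pos1/pos2 walk the
  // two sorted frontiers so each step merges the two least frequent remaining
  // nodes. Frequent words therefore receive short Huffman codes, which keeps
  // the hierarchical-softmax path short for the common cases in fit().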

  /**
   * Computes the vector representation of each word in the vocabulary.
   * @param dataset an RDD of strings, one sentence per line
   */
  def fit(dataset: RDD[String]): Word2VecModel = {

    learnVocab(dataset)

    createBinaryTree()

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val V = sc.broadcast(vocab)
    val VHash = sc.broadcast(vocabHash)

    // Convert the token stream into sentences of word indices, cutting every
    // MAX_SENTENCE_LENGTH known words; words below minCount are skipped.
    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
      iter => new Iterator[Array[Int]] {
        def hasNext = iter.hasNext
        def next(): Array[Int] = {
          val sentence = new ArrayBuffer[Int]
          var sentenceLength = 0
          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
            val word = VHash.value.get(iter.next())
            word match {
              case Some(w) =>
                sentence += w
                sentenceLength += 1
              case None =>
            }
          }
          sentence.toArray
        }
      }
    }

    val newSentences = sentences.repartition(1).cache()
    val temp = Array.fill[Double](vocabSize * layer1Size)(
      (Random.nextDouble - 0.5) / layer1Size)
    val (aggSyn0, _, _, _) =
      // TODO: broadcast temp instead of serializing it directly
      // or initialize the model in each executor
      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
        seqOp = (c, v) => (c, v) match {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
            var wc = wordCount
            // Anneal the learning rate every 10000 words.
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
              alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
            wc += sentence.size
            var pos = 0
            while (pos < sentence.size) {
              val word = sentence(pos)
              // TODO: fix random seed
              val b = Random.nextInt(window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {
                if (a != window) {
                  val c = pos - window + a
                  if (c >= 0 && c < sentence.size) {
                    val lastWord = sentence(c)
                    val l1 = lastWord * layer1Size
                    val neu1e = new Array[Double](layer1Size)
                    // Hierarchical softmax over the Huffman path of `word`
                    var d = 0
                    while (d < vocab(word).codeLen) {
                      val l2 = vocab(word).point(d) * layer1Size
                      // Propagate hidden -> output
                      var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)
                        val g = (1 - vocab(word).code(d) - f) * alpha
                        blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
                        blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                      }
                      d += 1
                    }
                    blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                  }
                }
                a += 1
              }
              pos += 1
            }
            (syn0, syn1, lwc, wc)
        },
        combOp = (c1, c2) => (c1, c2) match {
          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
            val n = syn0_1.length
            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
        })

    val wordMap = new Array[(String, Array[Double])](vocabSize)
    var i = 0
    while (i < vocabSize) {
      val word = vocab(i).word
      val vector = new Array[Double](layer1Size)
      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
      wordMap(i) = (word, vector)
      i += 1
    }
    val modelRDD = sc.parallelize(wordMap, 100).partitionBy(new HashPartitioner(100))
    new Word2VecModel(modelRDD)
  }
}
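
// Design note: `sentences` is repartitioned to a single partition above, so
// this initial version trains within one task; combOp only needs its daxpy
// additions of the per-partition parameter vectors if more partitions are
// ever used. As the TODO says, `temp` is serialized into the closure rather
// than broadcast, so the whole vocabSize * layer1Size array ships with the job.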

class Word2VecModel(val _model: RDD[(String, Array[Double])]) extends Serializable {

  val model = _model

  /** Cosine similarity of v1 and v2; returns 0.0 if either vector has zero norm. */
  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.dnrm2(n, v1, 1)
    val norm2 = blas.dnrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) return 0.0
    blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
  }

  /** Returns the vector for word, or an empty array if word is out of vocabulary. */
  def transform(word: String): Array[Double] = {
    val result = model.lookup(word)
    if (result.isEmpty) Array[Double]()
    else result(0)
  }

  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
    dataset.map(word => transform(word))
  }

  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    if (vector.isEmpty) Array[(String, Double)]()
    else findSynonyms(vector, num)
  }

  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should be > 0")
    // Take num + 1 matches and drop the best one, which is the query word
    // itself when the query vector comes from the vocabulary.
    val topK = model.map { case (w, vec) => (distance(vector, vec), w) }
      .sortByKey(ascending = false)
      .take(num + 1)
      .map { case (dist, w) => (w, dist) }
      .drop(1)

    topK
  }
}

object Word2Vec extends Serializable with Logging {
  def train(
      input: RDD[String],
      size: Int,
      startingAlpha: Double,
      window: Int,
      minCount: Int): Word2VecModel = {
    new Word2Vec(size, startingAlpha, window, minCount).fit(input)
  }

  def main(args: Array[String]) {
    if (args.length < 6) {
      println("Usage: word2vec input size startingAlpha window minCount num")
      sys.exit(1)
    }
    val conf = new SparkConf()
      .setAppName("word2vec")

    val sc = new SparkContext(conf)
    val input = sc.textFile(args(0))
    val size = args(1).toInt
    val startingAlpha = args(2).toDouble
    val window = args(3).toInt
    val minCount = args(4).toInt
    val num = args(5).toInt
    val model = train(input, size, startingAlpha, window, minCount)
    val vec = model.findSynonyms("china", num)
    for ((w, dist) <- vec) logInfo(w + " " + dist)
    sc.stop()
  }
}
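
A minimal programmatic usage sketch of the API above (the corpus path and the
hyperparameter values are illustrative, not defaults defined by this code):

  val input = sc.textFile("corpus.txt")
  val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
    window = 5, minCount = 5)
  val synonyms = model.findSynonyms("spark", 10)  // Array[(String, Double)]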
@@ -0,0 +1,40 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {
  test("word2vec") {
    val num = 2
    val localModel = Seq(
      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
    )
    val model = new Word2VecModel(sc.parallelize(localModel, 2))
    val synons = model.findSynonyms("china", num)
    assert(synons.length == num)
    assert(synons(0)._1 == "taiwan")
    assert(synons(1)._1 == "japan")
  }
}
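
For reference, the expected ordering follows from the cosine similarity that
Word2VecModel.distance computes against "china": roughly 0.997 for "taiwan",
0.996 for "japan", and 0.993 for "korea". findSynonyms takes num + 1 matches
and drops the best one ("china" itself, similarity 1.0), leaving taiwan then
japan.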

0 commit comments