|
| 1 | +/* |
| 2 | +* Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +* contributor license agreements. See the NOTICE file distributed with |
| 4 | +* this work for additional information regarding copyright ownership. |
| 5 | +* The ASF licenses this file to You under the Apache License, Version 2.0 |
| 7 | +* (the "License"); you may not use this file except in compliance with |
| 8 | +* the License. You may obtain a copy of the License at |
| 9 | +* |
| 10 | +* http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +* |
| 12 | +* Unless required by applicable law or agreed to in writing, software |
| 13 | +* distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +* See the License for the specific language governing permissions and |
| 16 | +* limitations under the License. |
| 17 | +*/ |
| 18 | + |
| 19 | +package org.apache.spark.mllib.feature |
| 20 | + |
| 21 | +import scala.collection.mutable
| 22 | +import scala.collection.mutable.ArrayBuffer
| 23 | +import scala.util.Random
| 25 | + |
| 26 | +import com.github.fommil.netlib.BLAS.{getInstance => blas} |
| 27 | + |
| 28 | +import org.apache.spark._
| 29 | +import org.apache.spark.SparkContext._
| 30 | +import org.apache.spark.rdd._
| 33 | + |
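| | +/**
| | + * A vocabulary entry: the word, its count (cn), and the Huffman-tree path of
| | + * inner nodes (point), binary code (code) and code length (codeLen) that are
| | + * filled in by createBinaryTree and used for hierarchical softmax.
| | + */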
| 34 | +private case class VocabWord( |
| 35 | + var word: String, |
| 36 | + var cn: Int, |
| 37 | + var point: Array[Int], |
| 38 | + var code: Array[Int], |
| 39 | + var codeLen: Int
| 40 | +) |
| 41 | + |
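| | +/**
| | + * Trains skip-gram word vector representations with hierarchical softmax, in
| | + * the style of the original word2vec.
| | + *
| | + * A rough usage sketch (assumes a SparkContext sc; the input path and
| | + * parameter values are illustrative):
| | + * {{{
| | + *   val input: RDD[String] = sc.textFile("corpus.txt")
| | + *   val model = new Word2Vec(100, 0.025, 5, 5).fit(input)
| | + *   val synonyms = model.findSynonyms("china", 40)
| | + * }}}
| | + *
| | + * @param size dimensionality of the word vectors
| | + * @param startingAlpha initial learning rate
| | + * @param window size of the skip-gram context window
| | + * @param minCount minimum number of times a word must occur to be kept
| | + */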
| 42 | +class Word2Vec( |
| 43 | + val size: Int, |
| 44 | + val startingAlpha: Double, |
| 45 | + val window: Int, |
| 46 | + val minCount: Int) |
| 47 | + extends Serializable with Logging { |
| 48 | + |
| 49 | + private val EXP_TABLE_SIZE = 1000 |
| 50 | + private val MAX_EXP = 6 |
| 51 | + private val MAX_CODE_LENGTH = 40 |
| 52 | + private val MAX_SENTENCE_LENGTH = 1000 |
| 53 | + private val layer1Size = size |
| 54 | + |
| 55 | + private var trainWordsCount = 0 |
| 56 | + private var vocabSize = 0 |
| 57 | + private var vocab: Array[VocabWord] = null |
| 58 | + private var vocabHash = mutable.HashMap.empty[String, Int] |
| 59 | + private var alpha = startingAlpha |
| 60 | + |
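| | + /**
| | + * Counts word frequencies, drops words occurring fewer than minCount times,
| | + * and sorts the remainder by descending count, populating vocab, vocabHash
| | + * and trainWordsCount.
| | + */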
| 61 | + private def learnVocab(dataset: RDD[String]) { |
| 62 | + vocab = dataset.flatMap(line => line.split(" ")) |
| 63 | + .map(w => (w, 1)) |
| 64 | + .reduceByKey(_ + _) |
| 65 | + .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) |
| 66 | + .filter(_.cn >= minCount) |
| 67 | + .collect() |
| 68 | + .sortWith((a, b) => a.cn > b.cn)
| 69 | + |
| 70 | + vocabSize = vocab.length |
| 71 | + var a = 0 |
| 72 | + while (a < vocabSize) { |
| 73 | + vocabHash += vocab(a).word -> a |
| 74 | + trainWordsCount += vocab(a).cn |
| 75 | + a += 1 |
| 76 | + } |
| 77 | + logInfo("trainWordsCount = " + trainWordsCount) |
| 78 | + } |
| 79 | + |
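| | + /**
| | + * Precomputes the sigmoid function on (-MAX_EXP, MAX_EXP) so training can use
| | + * a table lookup instead of calling exp() on every update.
| | + */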
| 80 | + private def createExpTable(): Array[Double] = { |
| 81 | + val expTable = new Array[Double](EXP_TABLE_SIZE) |
| 82 | + var i = 0 |
| 83 | + while (i < EXP_TABLE_SIZE) { |
| 84 | + val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) |
| 85 | + expTable(i) = tmp / (tmp + 1) |
| 86 | + i += 1 |
| 87 | + } |
| 88 | + expTable |
| 89 | + } |
| 90 | + |
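| | + /**
| | + * Builds a Huffman tree over the vocabulary, so frequent words get short
| | + * codes, and stores each word's binary code and path of inner nodes for
| | + * hierarchical softmax.
| | + */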
| 91 | + private def createBinaryTree() { |
| 92 | + val count = new Array[Long](vocabSize * 2 + 1) |
| 93 | + val binary = new Array[Int](vocabSize * 2 + 1) |
| 94 | + val parentNode = new Array[Int](vocabSize * 2 + 1) |
| 95 | + val code = new Array[Int](MAX_CODE_LENGTH) |
| 96 | + val point = new Array[Int](MAX_CODE_LENGTH) |
| 97 | + var a = 0 |
| 98 | + while (a < vocabSize) { |
| 99 | + count(a) = vocab(a).cn |
| 100 | + a += 1 |
| 101 | + } |
| 102 | + while (a < 2 * vocabSize) { |
| 103 | + count(a) = 1e9.toInt |
| 104 | + a += 1 |
| 105 | + } |
| 106 | + var pos1 = vocabSize - 1 |
| 107 | + var pos2 = vocabSize |
| 108 | + |
| 109 | + var min1i = 0 |
| 110 | + var min2i = 0 |
| 111 | + |
| 112 | + a = 0 |
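| | + // Repeatedly merge the two lowest-count nodes: pos1 scans the leaves
| | + // (sorted by descending count) and pos2 scans the newly created inner
| | + // nodes; min1i and min2i index the two smallest counts.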
| 113 | + while (a < vocabSize - 1) { |
| 114 | + if (pos1 >= 0) { |
| 115 | + if (count(pos1) < count(pos2)) { |
| 116 | + min1i = pos1 |
| 117 | + pos1 -= 1 |
| 118 | + } else { |
| 119 | + min1i = pos2 |
| 120 | + pos2 += 1 |
| 121 | + } |
| 122 | + } else { |
| 123 | + min1i = pos2 |
| 124 | + pos2 += 1 |
| 125 | + } |
| 126 | + if (pos1 >= 0) { |
| 127 | + if (count(pos1) < count(pos2)) { |
| 128 | + min2i = pos1 |
| 129 | + pos1 -= 1 |
| 130 | + } else { |
| 131 | + min2i = pos2 |
| 132 | + pos2 += 1 |
| 133 | + } |
| 134 | + } else { |
| 135 | + min2i = pos2 |
| 136 | + pos2 += 1 |
| 137 | + } |
| 138 | + count(vocabSize + a) = count(min1i) + count(min2i) |
| 139 | + parentNode(min1i) = vocabSize + a |
| 140 | + parentNode(min2i) = vocabSize + a |
| 141 | + binary(min2i) = 1 |
| 142 | + a += 1 |
| 143 | + } |
| 144 | + // Now assign binary code to each vocabulary word |
| 145 | + var i = 0 |
| 146 | + a = 0 |
| 147 | + while (a < vocabSize) { |
| 148 | + var b = a |
| 149 | + i = 0 |
| 150 | + while (b != vocabSize * 2 - 2) { |
| 151 | + code(i) = binary(b) |
| 152 | + point(i) = b |
| 153 | + i += 1 |
| 154 | + b = parentNode(b) |
| 155 | + } |
| 156 | + vocab(a).codeLen = i |
| 157 | + vocab(a).point(0) = vocabSize - 2 |
| 158 | + b = 0 |
| 159 | + while (b < i) { |
| 160 | + vocab(a).code(i - b - 1) = code(b) |
| 161 | + vocab(a).point(i - b) = point(b) - vocabSize |
| 162 | + b += 1 |
| 163 | + } |
| 164 | + a += 1 |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + /**
| 169 | + * Computes the vector representation of each word in the vocabulary.
| 170 | + * @param dataset an RDD of sentences, each given as a whitespace-separated
| 171 | + *                string of words
| 172 | + * @return a Word2VecModel holding the learned word vectors
| 173 | + */
| 174 | + def fit(dataset: RDD[String]): Word2VecModel = {
| 175 | + |
| 176 | + learnVocab(dataset) |
| 177 | + |
| 178 | + createBinaryTree() |
| 179 | + |
| 180 | + val sc = dataset.context |
| 181 | + |
| 182 | + val expTable = sc.broadcast(createExpTable()) |
| 183 | + val V = sc.broadcast(vocab) |
| 184 | + val VHash = sc.broadcast(vocabHash) |
| 185 | + |
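| | + // Map the word stream to vocabulary indices, dropping out-of-vocabulary
| | + // words, and cut it into sentences of at most MAX_SENTENCE_LENGTH words.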
| 186 | + val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions { |
| 187 | + iter => { new Iterator[Array[Int]] { |
| 188 | + def hasNext: Boolean = iter.hasNext
| 189 | + def next(): Array[Int] = {
| 190 | + val sentence = new ArrayBuffer[Int]
| 191 | + var sentenceLength = 0
| 192 | + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { |
| 193 | + val word = VHash.value.get(iter.next) |
| 194 | + word match { |
| 195 | + case Some(w) => { |
| 196 | + sentence += w |
| 197 | + sentenceLength += 1 |
| 198 | + } |
| 199 | + case None => |
| 200 | + } |
| 201 | + } |
| 202 | + sentence.toArray |
| 203 | + } |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + val newSentences = sentences.repartition(1).cache() |
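| | + // syn0 holds the word vectors and syn1 the inner-node vectors of the Huffman
| | + // tree, both as flat arrays of vocabSize * layer1Size doubles; syn0 starts
| | + // with small random values and syn1 with zeros.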
| 209 | + val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) |
| 210 | + val (aggSyn0, _, _, _) = |
| 211 | + // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor |
| 212 | + newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))( |
| 213 | + seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => |
| 214 | + var lwc = lastWordCount |
| 215 | + var wc = wordCount |
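| | + // Every ~10000 words, decay the learning rate linearly with progress
| | + // through the corpus, but never below 0.0001 of the starting rate.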
| 216 | + if (wordCount - lastWordCount > 10000) { |
| 217 | + lwc = wordCount |
| 218 | + alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1)) |
| 219 | + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 |
| 220 | + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) |
| 221 | + } |
| 222 | + wc += sentence.size |
| 223 | + var pos = 0 |
| 224 | + while (pos < sentence.size) { |
| 225 | + val word = sentence(pos) |
| 226 | + // TODO: fix random seed |
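| | + // Randomly shrink the effective context window, as in the original C word2vec.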
| 227 | + val b = Random.nextInt(window) |
| 228 | + // Train Skip-gram |
| 229 | + var a = b |
| 230 | + while (a < window * 2 + 1 - b) { |
| 231 | + if (a != window) { |
| 232 | + val c = pos - window + a |
| 233 | + if (c >= 0 && c < sentence.size) { |
| 234 | + val lastWord = sentence(c) |
| 235 | + val l1 = lastWord * layer1Size |
| 236 | + val neu1e = new Array[Double](layer1Size) |
| 237 | + // Hierarchical softmax: walk the Huffman path of the current word
| 238 | + var d = 0 |
| 239 | + while (d < vocab(word).codeLen) { |
| 240 | + val l2 = vocab(word).point(d) * layer1Size |
| 241 | + // Propagate hidden -> output |
| 242 | + var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) |
| 243 | + if (f > -MAX_EXP && f < MAX_EXP) { |
| 244 | + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt |
| 245 | + f = expTable.value(ind) |
| 246 | + val g = (1 - vocab(word).code(d) - f) * alpha |
| 247 | + blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) |
| 248 | + blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) |
| 249 | + } |
| 250 | + d += 1 |
| 251 | + } |
| 252 | + blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1) |
| 253 | + } |
| 254 | + } |
| 255 | + a += 1 |
| 256 | + } |
| 257 | + pos += 1 |
| 258 | + } |
| 259 | + (syn0, syn1, lwc, wc) |
| 260 | + }, |
| 261 | + combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => |
| 262 | + val n = syn0_1.length |
| 263 | + blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) |
| 264 | + blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) |
| 265 | + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
| 266 | + }) |
| 267 | + |
| 268 | + val wordMap = new Array[(String, Array[Double])](vocabSize) |
| 269 | + var i = 0 |
| 270 | + while (i < vocabSize) { |
| 271 | + val word = vocab(i).word |
| 272 | + val vector = new Array[Double](layer1Size) |
| 273 | + Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size) |
| 274 | + wordMap(i) = (word, vector) |
| 275 | + i += 1 |
| 276 | + } |
| 277 | + val modelRDD = sc.parallelize(wordMap, 100).partitionBy(new HashPartitioner(100))
| 278 | + new Word2VecModel(modelRDD) |
| 279 | + } |
| 280 | +} |
| 281 | + |
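| | +/**
| | + * Model produced by Word2Vec: an RDD that maps every vocabulary word to its
| | + * learned vector, hash-partitioned by word so that single-word lookups are cheap.
| | + */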
| 282 | +class Word2VecModel(val model: RDD[(String, Array[Double])]) extends Serializable {
| 285 | + |
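| | + // Despite the name, this is the cosine similarity of the two vectors;
| | + // it returns 0.0 if either vector has zero norm.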
| 286 | + private def distance(v1: Array[Double], v2: Array[Double]): Double = { |
| 287 | + require(v1.length == v2.length, "Vectors should have the same length") |
| 288 | + val n = v1.length |
| 289 | + val norm1 = blas.dnrm2(n, v1, 1) |
| 290 | + val norm2 = blas.dnrm2(n, v2, 1) |
| 291 | + if (norm1 == 0 || norm2 == 0) return 0.0 |
| 292 | + blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
| 293 | + } |
| 294 | + |
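| | + /**
| | + * Returns the vector of the given word, or an empty array if the word is not
| | + * in the vocabulary.
| | + */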
| 295 | + def transform(word: String): Array[Double] = { |
| 296 | + val result = model.lookup(word) |
| 297 | + if (result.isEmpty) Array[Double]() |
| 298 | + else result(0) |
| 299 | + } |
| 300 | + |
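| | + /** Transforms an RDD of words into an RDD of their vectors (empty arrays for unknown words). */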
| 301 | + def transform(dataset: RDD[String]): RDD[Array[Double]] = {
| 302 | + // Nested RDD operations are not supported, so model.lookup cannot be used
| 303 | + // inside dataset.map; ship the vectors to the executors as a broadcast map.
| 304 | + val wordVectors = dataset.context.broadcast(model.collectAsMap())
| 305 | + dataset.map(word => wordVectors.value.getOrElse(word, Array.empty[Double]))
| 306 | + }
| 304 | + |
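| | + /**
| | + * Finds the num words closest to the given word by cosine similarity,
| | + * excluding the word itself; returns an empty array for unknown words.
| | + */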
| 305 | + def findSynonyms(word: String, num: Int): Array[(String, Double)] = { |
| 306 | + val vector = transform(word) |
| 307 | + if (vector.isEmpty) Array[(String, Double)]() |
| 308 | + else findSynonyms(vector, num)
| 309 | + } |
| 310 | + |
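| | + /**
| | + * Finds the num words closest to the given vector by cosine similarity. The
| | + * single closest match is dropped, which assumes the query vector is itself
| | + * the vector of a vocabulary word.
| | + */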
| 311 | + def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = { |
| 312 | + require(num > 0, "Number of similar words should be > 0")
| 313 | + val topK = model.map( |
| 314 | + {case(w, vec) => (distance(vector, vec), w)}) |
| 315 | + .sortByKey(ascending = false) |
| 316 | + .take(num + 1) |
| 317 | + .map({case (dist, w) => (w, dist)}).drop(1) |
| 318 | + |
| 319 | + topK |
| 320 | + } |
| 321 | +} |
| 322 | + |
| 323 | +object Word2Vec extends Serializable with Logging { |
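| | + /**
| | + * Trains a Word2Vec model on the input corpus with the given vector size,
| | + * initial learning rate, context window and minimum word count.
| | + */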
| 324 | + def train( |
| 325 | + input: RDD[String], |
| 326 | + size: Int, |
| 327 | + startingAlpha: Double, |
| 328 | + window: Int, |
| 329 | + minCount: Int): Word2VecModel = { |
| 330 | + new Word2Vec(size, startingAlpha, window, minCount).fit(input)
| 331 | + } |
| 332 | + |
| 333 | + def main(args: Array[String]) { |
| 334 | + if (args.length < 6) { |
| 335 | + println("Usage: word2vec input size startingAlpha window minCount num") |
| 336 | + sys.exit(1) |
| 337 | + } |
| 338 | + val conf = new SparkConf() |
| 339 | + .setAppName("word2vec") |
| 340 | + |
| 341 | + val sc = new SparkContext(conf) |
| 342 | + val input = sc.textFile(args(0)) |
| 343 | + val size = args(1).toInt |
| 344 | + val startingAlpha = args(2).toDouble |
| 345 | + val window = args(3).toInt |
| 346 | + val minCount = args(4).toInt |
| 347 | + val num = args(5).toInt |
| 348 | + val model = train(input, size, startingAlpha, window, minCount) |
| 349 | + val vec = model.findSynonyms("china", num) |
| 350 | + for ((w, dist) <- vec) logInfo(w + " " + dist)
| 351 | + sc.stop() |
| 352 | + } |
| 353 | +} |