 package com.medallia.word2vec;

-import java.io.DataInput;
-import java.io.DataInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.ArrayList;
 import java.util.List;

-import org.apache.commons.io.input.SwappedDataInputStream;
-
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.primitives.Doubles;
 import com.medallia.word2vec.thrift.Word2VecModelThrift;
 import com.medallia.word2vec.util.Common;
+import com.medallia.word2vec.util.ProfilingTimer;
+import com.medallia.word2vec.util.AC;
+

 /**
  * Represents the Word2Vec model, containing vectors for each word
- * <p>
+ * <p/>
  * Instances of this class are obtained via:
  * <ul>
  * <li> {@link #trainer()}
  * <li> {@link #fromThrift(Word2VecModelThrift)}
  * </ul>
- *
+ *
  * @see {@link #forSearch()}
  */
 public class Word2VecModel {
   final List<String> vocab;
   final int layerSize;
   final double[] vectors;
-
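+  // Chunk size used when memory-mapping the model file in fromBinFile; the mapped window is
+  // re-mapped one gigabyte at a time.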
+  private final static long ONE_GB = 1024 * 1024 * 1024;
+
   Word2VecModel(Iterable<String> vocab, int layerSize, double[] vectors) {
     this.vocab = ImmutableList.copyOf(vocab);
     this.layerSize = layerSize;
     this.vectors = vectors;
   }
-
+
   /** @return Vocabulary */
   public Iterable<String> getVocab() {
     return vocab;
@@ -49,15 +53,15 @@ public Iterable<String> getVocab() {
   public Searcher forSearch() {
     return new SearcherImpl(this);
   }
-
+
   /** @return Serializable thrift representation */
   public Word2VecModelThrift toThrift() {
     return new Word2VecModelThrift()
-      .setVocab(vocab)
-      .setLayerSize(layerSize)
-      .setVectors(Doubles.asList(vectors));
+        .setVocab(vocab)
+        .setLayerSize(layerSize)
+        .setVectors(Doubles.asList(vectors));
   }
-
+
   /** @return {@link Word2VecModel} created from a thrift representation */
   public static Word2VecModel fromThrift(Word2VecModelThrift thrift) {
     return new Word2VecModel(
@@ -76,69 +80,126 @@ public static Word2VecModel fromTextFile(File file) throws IOException {
   }

   /**
-   * Forwards to {@link #fromBinFile(File, ByteOrder)} with the default
-   * ByteOrder.LITTLE_ENDIAN
-   */
-  public static Word2VecModel fromBinFile(File file)
-      throws IOException {
-    return fromBinFile(file, ByteOrder.LITTLE_ENDIAN);
-  }
-
-  /**
-   * @return {@link Word2VecModel} created from the binary representation output
-   * by the open source C version of word2vec using the given byte order.
-   */
-  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
-      throws IOException {
-
-    try (FileInputStream fis = new FileInputStream(file);) {
-      DataInput in = (byteOrder == ByteOrder.BIG_ENDIAN) ?
-          new DataInputStream(fis) : new SwappedDataInputStream(fis);
-
-      StringBuilder sb = new StringBuilder();
-      char c = (char) in.readByte();
-      while (c != '\n') {
-        sb.append(c);
-        c = (char) in.readByte();
-      }
-      String firstLine = sb.toString();
-      int index = firstLine.indexOf(' ');
-      Preconditions.checkState(index != -1,
-          "Expected a space in the first line of file '%s': '%s'",
-          file.getAbsolutePath(), firstLine);
-
-      int vocabSize = Integer.parseInt(firstLine.substring(0, index));
-      int layerSize = Integer.parseInt(firstLine.substring(index + 1));
-
-      List<String> vocabs = Lists.newArrayList();
-      List<Double> vectors = Lists.newArrayList();
-
-      for (int lineno = 0; lineno < vocabSize; lineno++) {
-        sb = new StringBuilder();
-        c = (char) in.readByte();
-        while (c != ' ') {
-          // ignore newlines in front of words (some binary files have newline,
-          // some don't)
-          if (c != '\n') {
-            sb.append(c);
-          }
-          c = (char) in.readByte();
-        }
-        vocabs.add(sb.toString());
-
-        for (int i = 0; i < layerSize; i++) {
-          vectors.add((double) in.readFloat());
-        }
-      }
-
-      return fromThrift(new Word2VecModelThrift()
-          .setLayerSize(layerSize)
-          .setVocab(vocabs)
-          .setVectors(vectors));
-    }
-  }
-
-  /**
+   * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default
+   * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer
+   */
+  public static Word2VecModel fromBinFile(File file)
+      throws IOException {
+    return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);
+  }
+
+  /**
+   * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer
+   */
+  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
+      throws IOException {
+    return fromBinFile(file, byteOrder, ProfilingTimer.NONE);
+  }
+
+  /**
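+   * A minimal usage sketch (the file name below is only an example):
+   * <pre>{@code
+   * Word2VecModel model = Word2VecModel.fromBinFile(
+   *     new File("vectors.bin"), ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);
+   * Searcher searcher = model.forSearch();
+   * }</pre>
+   * <p/>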
+   * @return {@link Word2VecModel} created from the binary representation output
+   * by the open source C version of word2vec using the given byte order.
+   */
+  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer)
+      throws IOException {
+
+    try (
+        final FileInputStream fis = new FileInputStream(file);
+        final AC ac = timer.start("Loading vectors from bin file")
+    ) {
+      final FileChannel channel = fis.getChannel();
+      timer.start("Reading gigabyte #1");
+      MappedByteBuffer buffer =
+          channel.map(
+              FileChannel.MapMode.READ_ONLY,
+              0,
+              Math.min(channel.size(), Integer.MAX_VALUE));
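+      // Each mapped window is capped at Integer.MAX_VALUE bytes because ByteBuffer positions are
+      // int-indexed; a freshly mapped buffer defaults to big-endian, so the requested byte order
+      // has to be applied explicitly after every map() call.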
+      buffer.order(byteOrder);
+      int bufferCount = 1;
+      // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map
+      // every gigabyte. To calculate offsets correctly, we have to keep track of how many
+      // gigabytes we've already skipped. That's what this is for.
+
+      StringBuilder sb = new StringBuilder();
+      char c = (char) buffer.get();
+      while (c != '\n') {
+        sb.append(c);
+        c = (char) buffer.get();
+      }
+      String firstLine = sb.toString();
+      int index = firstLine.indexOf(' ');
+      Preconditions.checkState(index != -1,
+          "Expected a space in the first line of file '%s': '%s'",
+          file.getAbsolutePath(), firstLine);
+
+      final int vocabSize = Integer.parseInt(firstLine.substring(0, index));
+      final int layerSize = Integer.parseInt(firstLine.substring(index + 1));
+      timer.appendToLog(String.format(
+          "Loading %d vectors with dimensionality %d",
+          vocabSize,
+          layerSize));
+
+      List<String> vocabs = new ArrayList<String>(vocabSize);
+      double vectors[] = new double[vocabSize * layerSize];
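+      // Vectors are stored in a single flat array; the vector for word i occupies
+      // indices [i * layerSize, (i + 1) * layerSize).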
+
+      long lastLogMessage = System.currentTimeMillis();
+      final float[] floats = new float[layerSize];
+      for (int lineno = 0; lineno < vocabSize; lineno++) {
+        // read vocab
+        sb.setLength(0);
+        c = (char) buffer.get();
+        while (c != ' ') {
+          // ignore newlines in front of words (some binary files have newline,
+          // some don't)
+          if (c != '\n') {
+            sb.append(c);
+          }
+          c = (char) buffer.get();
+        }
+        vocabs.add(sb.toString());
+
+        // read vector
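+        // asFloatBuffer() creates a float view starting at the current position; bulk-reading
+        // from the view does not advance the underlying buffer, so the position is bumped
+        // manually by 4 bytes per float below.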
+        final FloatBuffer floatBuffer = buffer.asFloatBuffer();
+        floatBuffer.get(floats);
+        for (int i = 0; i < floats.length; ++i) {
+          vectors[lineno * layerSize + i] = floats[i];
+        }
+        buffer.position(buffer.position() + 4 * layerSize);
+
+        // print log
+        final long now = System.currentTimeMillis();
+        if (now - lastLogMessage > 1000) {
+          final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
+          timer.appendToLog(
+              String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
+          lastLogMessage = now;
+        }
+
+        // remap file
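+        // Once more than a gigabyte of the current window has been consumed, map a new window
+        // that starts ONE_GB further into the file and rebase the read position into it, so no
+        // single mapping ever has to cover more than Integer.MAX_VALUE bytes.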
+        if (buffer.position() > ONE_GB) {
+          final int newPosition = (int) (buffer.position() - ONE_GB);
+          final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
+          timer.endAndStart(
+              "Reading gigabyte #%d. Start: %d, size: %d",
+              bufferCount,
+              ONE_GB * bufferCount,
+              size);
+          buffer = channel.map(
+              FileChannel.MapMode.READ_ONLY,
+              ONE_GB * bufferCount,
+              size);
+          buffer.order(byteOrder);
+          buffer.position(newPosition);
+          bufferCount += 1;
+        }
+      }
+      timer.end();
+
+      return new Word2VecModel(vocabs, layerSize, vectors);
+    }
+  }
+
+  /**
   * @return {@link Word2VecModel} from the lines of the file in the text output format of the
   * Word2Vec C open source project.
   */
@@ -155,7 +216,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IO
         filename,
         vocabSize,
         lines.size() - 1
-    );
+        );

     for (int n = 1; n < lines.size(); n++) {
       String[] values = lines.get(n).split(" ");
@@ -169,7 +230,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IO
           n,
           layerSize,
           values.length - 1
-      );
+          );

       for (int d = 1; d < values.length; d++) {
         vectors.add(Double.parseDouble(values[d]));
@@ -182,7 +243,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IO
         .setVectors(vectors);
     return fromThrift(thrift);
   }
-
+
   /** @return {@link Word2VecTrainerBuilder} for training a model */
   public static Word2VecTrainerBuilder trainer() {
     return new Word2VecTrainerBuilder();