Commit 57bbe5c

Merge pull request #15 from dirkgr/LoadFast
Load binary models faster
2 parents d6e34aa + 17fd83d

1 file changed

src/main/java/com/medallia/word2vec/Word2VecModel.java

Lines changed: 140 additions & 79 deletions
@@ -1,45 +1,49 @@
 package com.medallia.word2vec;

-import java.io.DataInput;
-import java.io.DataInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.ArrayList;
 import java.util.List;

-import org.apache.commons.io.input.SwappedDataInputStream;
-
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.primitives.Doubles;
 import com.medallia.word2vec.thrift.Word2VecModelThrift;
 import com.medallia.word2vec.util.Common;
+import com.medallia.word2vec.util.ProfilingTimer;
+import com.medallia.word2vec.util.AC;
+

 /**
  * Represents the Word2Vec model, containing vectors for each word
- * <p>
+ * <p/>
  * Instances of this class are obtained via:
  * <ul>
  * <li> {@link #trainer()}
  * <li> {@link #fromThrift(Word2VecModelThrift)}
  * </ul>
- *
+ *
  * @see {@link #forSearch()}
  */
 public class Word2VecModel {
   final List<String> vocab;
   final int layerSize;
   final double[] vectors;
-
+  private final static long ONE_GB = 1024 * 1024 * 1024;
+
   Word2VecModel(Iterable<String> vocab, int layerSize, double[] vectors) {
     this.vocab = ImmutableList.copyOf(vocab);
     this.layerSize = layerSize;
     this.vectors = vectors;
   }
-
+
   /** @return Vocabulary */
   public Iterable<String> getVocab() {
     return vocab;
@@ -49,15 +53,15 @@ public Iterable<String> getVocab() {
   public Searcher forSearch() {
     return new SearcherImpl(this);
   }
-
+
   /** @return Serializable thrift representation */
   public Word2VecModelThrift toThrift() {
     return new Word2VecModelThrift()
-      .setVocab(vocab)
-      .setLayerSize(layerSize)
-      .setVectors(Doubles.asList(vectors));
+        .setVocab(vocab)
+        .setLayerSize(layerSize)
+        .setVectors(Doubles.asList(vectors));
   }
-
+
   /** @return {@link Word2VecModel} created from a thrift representation */
   public static Word2VecModel fromThrift(Word2VecModelThrift thrift) {
     return new Word2VecModel(
@@ -76,69 +80,126 @@ public static Word2VecModel fromTextFile(File file) throws IOException {
   }

   /**
-   * Forwards to {@link #fromBinFile(File, ByteOrder)} with the default
-   * ByteOrder.LITTLE_ENDIAN
-   */
-  public static Word2VecModel fromBinFile(File file)
-      throws IOException {
-    return fromBinFile(file, ByteOrder.LITTLE_ENDIAN);
-  }
-
-  /**
-   * @return {@link Word2VecModel} created from the binary representation output
-   * by the open source C version of word2vec using the given byte order.
-   */
-  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
-      throws IOException {
-
-    try (FileInputStream fis = new FileInputStream(file);) {
-      DataInput in = (byteOrder == ByteOrder.BIG_ENDIAN) ?
-          new DataInputStream(fis) : new SwappedDataInputStream(fis);
-
-      StringBuilder sb = new StringBuilder();
-      char c = (char) in.readByte();
-      while (c != '\n') {
-        sb.append(c);
-        c = (char) in.readByte();
-      }
-      String firstLine = sb.toString();
-      int index = firstLine.indexOf(' ');
-      Preconditions.checkState(index != -1,
-          "Expected a space in the first line of file '%s': '%s'",
-          file.getAbsolutePath(), firstLine);
-
-      int vocabSize = Integer.parseInt(firstLine.substring(0, index));
-      int layerSize = Integer.parseInt(firstLine.substring(index + 1));
-
-      List<String> vocabs = Lists.newArrayList();
-      List<Double> vectors = Lists.newArrayList();
-
-      for (int lineno = 0; lineno < vocabSize; lineno++) {
-        sb = new StringBuilder();
-        c = (char) in.readByte();
-        while (c != ' ') {
-          // ignore newlines in front of words (some binary files have newline,
-          // some don't)
-          if (c != '\n') {
-            sb.append(c);
-          }
-          c = (char) in.readByte();
-        }
-        vocabs.add(sb.toString());
-
-        for (int i = 0; i < layerSize; i++) {
-          vectors.add((double) in.readFloat());
-        }
-      }
-
-      return fromThrift(new Word2VecModelThrift()
-          .setLayerSize(layerSize)
-          .setVocab(vocabs)
-          .setVectors(vectors));
-    }
-  }
-
-  /**
+   * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default
+   * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer
+   */
+  public static Word2VecModel fromBinFile(File file)
+      throws IOException {
+    return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);
+  }
+
+  /**
+   * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer
+   */
+  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
+      throws IOException {
+    return fromBinFile(file, byteOrder, ProfilingTimer.NONE);
+  }
+
+  /**
+   * @return {@link Word2VecModel} created from the binary representation output
+   * by the open source C version of word2vec using the given byte order.
+   */
+  public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer)
+      throws IOException {
+
+    try (
+        final FileInputStream fis = new FileInputStream(file);
+        final AC ac = timer.start("Loading vectors from bin file")
+    ) {
+      final FileChannel channel = fis.getChannel();
+      timer.start("Reading gigabyte #1");
+      MappedByteBuffer buffer =
+          channel.map(
+              FileChannel.MapMode.READ_ONLY,
+              0,
+              Math.min(channel.size(), Integer.MAX_VALUE));
+      buffer.order(byteOrder);
+      int bufferCount = 1;
+      // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map
+      // every gigabyte. To calculate offsets correctly, we have to keep track how many gigabytes
+      // we've already skipped. That's what this is for.
+
+      StringBuilder sb = new StringBuilder();
+      char c = (char) buffer.get();
+      while (c != '\n') {
+        sb.append(c);
+        c = (char) buffer.get();
+      }
+      String firstLine = sb.toString();
+      int index = firstLine.indexOf(' ');
+      Preconditions.checkState(index != -1,
+          "Expected a space in the first line of file '%s': '%s'",
+          file.getAbsolutePath(), firstLine);
+
+      final int vocabSize = Integer.parseInt(firstLine.substring(0, index));
+      final int layerSize = Integer.parseInt(firstLine.substring(index + 1));
+      timer.appendToLog(String.format(
+          "Loading %d vectors with dimensionality %d",
+          vocabSize,
+          layerSize));
+
+      List<String> vocabs = new ArrayList<String>(vocabSize);
+      double vectors[] = new double[vocabSize * layerSize];
+
+      long lastLogMessage = System.currentTimeMillis();
+      final float[] floats = new float[layerSize];
+      for (int lineno = 0; lineno < vocabSize; lineno++) {
+        // read vocab
+        sb.setLength(0);
+        c = (char) buffer.get();
+        while (c != ' ') {
+          // ignore newlines in front of words (some binary files have newline,
+          // some don't)
+          if (c != '\n') {
+            sb.append(c);
+          }
+          c = (char) buffer.get();
+        }
+        vocabs.add(sb.toString());
+
+        // read vector
+        final FloatBuffer floatBuffer = buffer.asFloatBuffer();
+        floatBuffer.get(floats);
+        for (int i = 0; i < floats.length; ++i) {
+          vectors[lineno * layerSize + i] = floats[i];
+        }
+        buffer.position(buffer.position() + 4 * layerSize);
+
+        // print log
+        final long now = System.currentTimeMillis();
+        if (now - lastLogMessage > 1000) {
+          final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
+          timer.appendToLog(
+              String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
+          lastLogMessage = now;
+        }
+
+        // remap file
+        if (buffer.position() > ONE_GB) {
+          final int newPosition = (int) (buffer.position() - ONE_GB);
+          final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
+          timer.endAndStart(
+              "Reading gigabyte #%d. Start: %d, size: %d",
+              bufferCount,
+              ONE_GB * bufferCount,
+              size);
+          buffer = channel.map(
+              FileChannel.MapMode.READ_ONLY,
+              ONE_GB * bufferCount,
+              size);
+          buffer.order(byteOrder);
+          buffer.position(newPosition);
+          bufferCount += 1;
+        }
+      }
+      timer.end();
+
+      return new Word2VecModel(vocabs, layerSize, vectors);
+    }
+  }
+
+  /**
    * @return {@link Word2VecModel} from the lines of the file in the text output format of the
    * Word2Vec C open source project.
    */
@@ -155,7 +216,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IOException {
         filename,
         vocabSize,
         lines.size() - 1
-    );
+        );

     for (int n = 1; n < lines.size(); n++) {
       String[] values = lines.get(n).split(" ");
@@ -169,7 +230,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IOException {
           n,
           layerSize,
           values.length - 1
-      );
+          );

       for (int d = 1; d < values.length; d++) {
         vectors.add(Double.parseDouble(values[d]));
@@ -182,7 +243,7 @@ static Word2VecModel fromTextFile(String filename, List<String> lines) throws IOException {
         .setVectors(vectors);
     return fromThrift(thrift);
   }
-
+
   /** @return {@link Word2VecTrainerBuilder} for training a model */
   public static Word2VecTrainerBuilder trainer() {
     return new Word2VecTrainerBuilder();
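
The heart of this change is the workaround described in the comment inside the new fromBinFile: Java NIO can map at most Integer.MAX_VALUE bytes at a time, so the loader re-maps the channel one gigabyte further along whenever the read position crosses 1 GB, carrying the overshoot into the new buffer's position. Below is a minimal standalone sketch of that re-mapping pattern, independent of the word2vec format; the class name, the Paths-based file handling, and the byte-summing payload are illustrative assumptions, not part of this commit.

import java.io.IOException;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

// Sketch only: demonstrates sliding a MappedByteBuffer window over a file of arbitrary size.
public class RemapSketch {
  private static final long ONE_GB = 1024L * 1024 * 1024;

  /** Sums all bytes of a file, never keeping more than ~2 GB mapped at once. */
  public static long sumBytes(String path) throws IOException {
    try (FileChannel channel = FileChannel.open(Paths.get(path), StandardOpenOption.READ)) {
      long mapOffset = 0;  // where in the file the current mapping starts
      MappedByteBuffer buffer = channel.map(
          FileChannel.MapMode.READ_ONLY, mapOffset,
          Math.min(channel.size(), Integer.MAX_VALUE));
      buffer.order(ByteOrder.LITTLE_ENDIAN);  // only relevant for multi-byte reads; re-applied after every re-map, as in the diff

      long sum = 0;
      while (mapOffset + buffer.position() < channel.size()) {
        sum += buffer.get() & 0xFF;

        // Once more than 1 GB into the current mapping, slide the window forward by 1 GB
        // and carry the overshoot into the new buffer's position.
        if (buffer.position() > ONE_GB) {
          int carryOver = (int) (buffer.position() - ONE_GB);
          mapOffset += ONE_GB;
          buffer = channel.map(
              FileChannel.MapMode.READ_ONLY, mapOffset,
              Math.min(channel.size() - mapOffset, Integer.MAX_VALUE));
          buffer.order(ByteOrder.LITTLE_ENDIAN);
          buffer.position(carryOver);
        }
      }
      return sum;
    }
  }
}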

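A detail in the vector-reading loop that is easy to miss: buffer.asFloatBuffer() returns a view with its own independent position, so pulling layerSize floats through the view does not advance the underlying byte buffer. That is why the loop then calls buffer.position(buffer.position() + 4 * layerSize) by hand. A small self-contained illustration of that behaviour (class name and sample values are made up):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

public class FloatViewSketch {
  public static void main(String[] args) {
    ByteBuffer bytes = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
    bytes.putFloat(1.0f).putFloat(2.0f).putFloat(3.0f).putFloat(4.0f);
    bytes.flip();                              // rewind to position 0 for reading

    float[] firstTwo = new float[2];
    FloatBuffer view = bytes.asFloatBuffer();  // view starts at bytes' current position
    view.get(firstTwo);                        // reads 2 floats through the view

    System.out.println(bytes.position());      // prints 0: the view has its own position
    bytes.position(bytes.position() + 4 * firstTwo.length);  // advance by hand, as in the diff
    System.out.println(bytes.position());      // prints 8
  }
}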
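For callers, the public surface after this commit is the three fromBinFile overloads shown above. The following is a hypothetical usage sketch; the model path is a placeholder, while the overload signatures and ProfilingTimer.NONE come from the diff itself.

import java.io.File;
import java.io.IOException;
import java.nio.ByteOrder;

import com.medallia.word2vec.Word2VecModel;
import com.medallia.word2vec.util.ProfilingTimer;

// Hypothetical caller; "vectors.bin" is a placeholder path.
public class LoadBinExample {
  public static void main(String[] args) throws IOException {
    File modelFile = new File("vectors.bin");

    // Little-endian input, no profiling (the defaults the one-argument overload forwards to).
    Word2VecModel model = Word2VecModel.fromBinFile(modelFile);

    // Explicit byte order, still without profiling.
    Word2VecModel bigEndian = Word2VecModel.fromBinFile(modelFile, ByteOrder.BIG_ENDIAN);

    // Full control: byte order plus a ProfilingTimer (here the no-op timer used by the forwards).
    Word2VecModel timed =
        Word2VecModel.fromBinFile(modelFile, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);

    int words = 0;
    for (String word : model.getVocab()) {
      words++;
    }
    System.out.println("Loaded " + words + " word vectors");
  }
}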