
Commit 1309668

Celebio authored and facebook-github-bot committed

WebAssembly

Summary: This commit introduces a WebAssembly module for fastText.

Reviewed By: EdouardGrave

Differential Revision: D19021740

fbshipit-source-id: e378f0bb70c0e1f4d6382e1e45af03d1e6ddb4f1

1 parent 316b4c9 · commit 1309668

16 files changed: +1260 −55 lines

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -2,8 +2,11 @@
 *.o
 *.bin
 *.vec
+*.bc
+.DS_Store
 data
 fasttext
 result
 website/node_modules/
-package-lock.json
+package-lock.json
+node_modules/
Makefile

Lines changed: 60 additions & 1 deletion

@@ -20,6 +20,12 @@ coverage: fasttext
 debug: CXXFLAGS += -g -O0 -fno-inline
 debug: fasttext
 
+wasm: webassembly/fasttext_wasm.js
+
+wasmdebug: export EMCC_DEBUG=1
+wasmdebug: webassembly/fasttext_wasm.js
+
+
 args.o: src/args.cc src/args.h
 	$(CXX) $(CXXFLAGS) -c src/args.cc
 
@@ -63,4 +69,57 @@ fasttext: $(OBJS) src/fasttext.cc
 	$(CXX) $(CXXFLAGS) $(OBJS) src/main.cc -o fasttext
 
 clean:
-	rm -rf *.o *.gcno *.gcda fasttext
+	rm -rf *.o *.gcno *.gcda fasttext *.bc webassembly/fasttext_wasm.js webassembly/fasttext_wasm.wasm
+
+
+EMCXX = em++
+EMCXXFLAGS = --bind --std=c++11 -s WASM=1 -s ALLOW_MEMORY_GROWTH=1 -s "EXTRA_EXPORTED_RUNTIME_METHODS=['addOnPostRun', 'FS']" -s "DISABLE_EXCEPTION_CATCHING=0" -s "EXCEPTION_DEBUG=1" -s "FORCE_FILESYSTEM=1" -s "MODULARIZE=1" -s "EXPORT_ES6=1" -s 'EXPORT_NAME="FastTextModule"' -Isrc/
+EMOBJS = args.bc autotune.bc matrix.bc dictionary.bc loss.bc productquantizer.bc densematrix.bc quantmatrix.bc vector.bc model.bc utils.bc meter.bc fasttext.bc main.bc
+
+
+main.bc: webassembly/fasttext_wasm.cc
+	$(EMCXX) $(EMCXXFLAGS) webassembly/fasttext_wasm.cc -o main.bc
+
+args.bc: src/args.cc src/args.h
+	$(EMCXX) $(EMCXXFLAGS) src/args.cc -o args.bc
+
+autotune.bc: src/autotune.cc src/autotune.h
+	$(EMCXX) $(EMCXXFLAGS) src/autotune.cc -o autotune.bc
+
+matrix.bc: src/matrix.cc src/matrix.h
+	$(EMCXX) $(EMCXXFLAGS) src/matrix.cc -o matrix.bc
+
+dictionary.bc: src/dictionary.cc src/dictionary.h src/args.h
+	$(EMCXX) $(EMCXXFLAGS) src/dictionary.cc -o dictionary.bc
+
+loss.bc: src/loss.cc src/loss.h src/matrix.h src/real.h
+	$(EMCXX) $(EMCXXFLAGS) src/loss.cc -o loss.bc
+
+productquantizer.bc: src/productquantizer.cc src/productquantizer.h src/utils.h
+	$(EMCXX) $(EMCXXFLAGS) src/productquantizer.cc -o productquantizer.bc
+
+densematrix.bc: src/densematrix.cc src/densematrix.h src/utils.h src/matrix.h
+	$(EMCXX) $(EMCXXFLAGS) src/densematrix.cc -o densematrix.bc
+
+quantmatrix.bc: src/quantmatrix.cc src/quantmatrix.h src/utils.h src/matrix.h
+	$(EMCXX) $(EMCXXFLAGS) src/quantmatrix.cc -o quantmatrix.bc
+
+vector.bc: src/vector.cc src/vector.h src/utils.h
+	$(EMCXX) $(EMCXXFLAGS) src/vector.cc -o vector.bc
+
+model.bc: src/model.cc src/model.h src/args.h
+	$(EMCXX) $(EMCXXFLAGS) src/model.cc -o model.bc
+
+utils.bc: src/utils.cc src/utils.h
+	$(EMCXX) $(EMCXXFLAGS) src/utils.cc -o utils.bc
+
+meter.bc: src/meter.cc src/meter.h
+	$(EMCXX) $(EMCXXFLAGS) src/meter.cc -o meter.bc
+
+fasttext.bc: src/fasttext.cc src/*.h
+	$(EMCXX) $(EMCXXFLAGS) src/fasttext.cc -o fasttext.bc
+
+webassembly/fasttext_wasm.js: $(EMOBJS) webassembly/fasttext_wasm.cc Makefile
+	$(EMCXX) $(EMCXXFLAGS) $(EMOBJS) -o webassembly/fasttext_wasm.js
+
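Note on the new build flags: --bind enables Emscripten's embind, which is how webassembly/fasttext_wasm.cc (added by this commit but not shown in this excerpt) exposes C++ classes to JavaScript, while MODULARIZE and EXPORT_ES6 wrap the output in an ES6 module whose factory is named FastTextModule. A rough, hypothetical sketch of what such a binding file can look like — the wrapper and names below are illustrative, not the commit's actual bindings:

// Hypothetical embind sketch; the real webassembly/fasttext_wasm.cc
// in this commit binds far more of the API than shown here.
#include <emscripten/bind.h>
#include <string>

#include "fasttext.h"

namespace {
// Thin wrapper so we bind a single, non-overloaded loadModel signature.
struct FastTextWasm {
  fasttext::FastText ft;
  void loadModel(const std::string& path) { ft.loadModel(path); }
};
} // namespace

EMSCRIPTEN_BINDINGS(fasttext_wasm) {
  emscripten::class_<FastTextWasm>("FastText")
      .constructor<>()
      .function("loadModel", &FastTextWasm::loadModel);
}

On the JavaScript side, calling the exported FastTextModule() factory instantiates the module, and the FS runtime method exported via EXTRA_EXPORTED_RUNTIME_METHODS lets callers write model files into the virtual filesystem that FORCE_FILESYSTEM compiles in.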

src/args.cc

Lines changed: 21 additions & 15 deletions

@@ -262,43 +262,49 @@ void Args::printTrainingHelp() {
   std::cerr
       << "\nThe following arguments for training are optional:\n"
       << " -lr learning rate [" << lr << "]\n"
-      << " -lrUpdateRate change the rate of updates for the learning rate ["
+      << " -lrUpdateRate change the rate of updates for the learning "
+         "rate ["
       << lrUpdateRate << "]\n"
       << " -dim size of word vectors [" << dim << "]\n"
       << " -ws size of the context window [" << ws << "]\n"
       << " -epoch number of epochs [" << epoch << "]\n"
       << " -neg number of negatives sampled [" << neg << "]\n"
       << " -loss loss function {ns, hs, softmax, one-vs-all} ["
       << lossToString(loss) << "]\n"
-      << " -thread number of threads (set to 1 to ensure reproducible results) ["
+      << " -thread number of threads (set to 1 to ensure "
+         "reproducible results) ["
       << thread << "]\n"
-      << " -pretrainedVectors pretrained word vectors for supervised learning ["
+      << " -pretrainedVectors pretrained word vectors for supervised "
+         "learning ["
       << pretrainedVectors << "]\n"
       << " -saveOutput whether output params should be saved ["
       << boolToString(saveOutput) << "]\n"
       << " -seed random generator seed [" << seed << "]\n";
 }
 
 void Args::printAutotuneHelp() {
-  std::cerr
-      << "\nThe following arguments are for autotune:\n"
-      << " -autotune-validation validation file to be used for evaluation\n"
-      << " -autotune-metric metric objective {f1, f1:labelname} ["
-      << autotuneMetric << "]\n"
-      << " -autotune-predictions number of predictions used for evaluation ["
-      << autotunePredictions << "]\n"
-      << " -autotune-duration maximum duration in seconds ["
-      << autotuneDuration << "]\n"
-      << " -autotune-modelsize constraint model file size ["
-      << autotuneModelSize << "] (empty = do not quantize)\n";
+  std::cerr << "\nThe following arguments are for autotune:\n"
+            << " -autotune-validation validation file to be used "
+               "for evaluation\n"
+            << " -autotune-metric metric objective {f1, "
+               "f1:labelname} ["
+            << autotuneMetric << "]\n"
+            << " -autotune-predictions number of predictions used "
+               "for evaluation ["
+            << autotunePredictions << "]\n"
+            << " -autotune-duration maximum duration in seconds ["
+            << autotuneDuration << "]\n"
+            << " -autotune-modelsize constraint model file size ["
+            << autotuneModelSize << "] (empty = do not quantize)\n";
 }
 
 void Args::printQuantizationHelp() {
   std::cerr
       << "\nThe following arguments for quantization are optional:\n"
       << " -cutoff number of words and ngrams to retain ["
       << cutoff << "]\n"
-      << " -retrain whether embeddings are finetuned if a cutoff is applied ["
+      << " -retrain whether embeddings are finetuned if a cutoff "
+         "is applied ["
       << boolToString(retrain) << "]\n"
       << " -qnorm whether the norm is quantized separately ["
       << boolToString(qnorm) << "]\n"
src/autotune.cc

Lines changed: 8 additions & 6 deletions

@@ -416,10 +416,10 @@ void Autotune::train(const Args& autotuneArgs) {
       if (!sizeConstraintWarning && trials_ > 10 &&
           sizeConstraintFailed_ > (trials_ / 2)) {
         sizeConstraintWarning = true;
-        std::cerr
-            << std::endl
-            << "Warning : requested model size is probably too small. You may want to increase `autotune-modelsize`."
-            << std::endl;
+        std::cerr << std::endl
+                  << "Warning : requested model size is probably too small. "
+                     "You may want to increase `autotune-modelsize`."
+                  << std::endl;
       }
     }
   } catch (DenseMatrix::EncounteredNaNError&) {

@@ -442,10 +442,12 @@ void Autotune::train(const Args& autotuneArgs) {
     std::string errorMessage;
     if (sizeConstraintWarning) {
       errorMessage =
-          "Couldn't fulfil model size constraint: please increase `autotune-modelsize`.";
+          "Couldn't fulfil model size constraint: please increase "
+          "`autotune-modelsize`.";
     } else {
       errorMessage =
-          "Didn't have enough time to train once: please increase `autotune-duration`.";
+          "Didn't have enough time to train once: please increase "
+          "`autotune-duration`.";
     }
     throw std::runtime_error(errorMessage);
   } else {
src/densematrix.cc

Lines changed: 11 additions & 6 deletions

@@ -43,12 +43,17 @@ void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
 }
 
 void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
-  std::vector<std::thread> threads;
-  for (int i = 0; i < thread; i++) {
-    threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
-  }
-  for (int32_t i = 0; i < threads.size(); i++) {
-    threads[i].join();
+  if (thread > 1) {
+    std::vector<std::thread> threads;
+    for (int i = 0; i < thread; i++) {
+      threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
+    }
+    for (int32_t i = 0; i < threads.size(); i++) {
+      threads[i].join();
+    }
+  } else {
+    // webassembly can't instantiate `std::thread`
+    uniformThread(a, 0, seed);
   }
 }
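Context for the change above: this commit's wasm build is single-threaded, and constructing a std::thread in an Emscripten build without pthreads support fails at runtime, so uniform() now runs the one-thread case inline on the calling thread; startThreads() in src/fasttext.cc below gets the same guard. The pattern, extracted as a standalone sketch (an illustration, not code from the commit):

#include <thread>
#include <vector>

// Run work(block) over n blocks; spawn threads only when n > 1, since a
// wasm build without pthreads cannot instantiate std::thread.
template <typename Fn>
void runBlocks(unsigned n, Fn work) {
  if (n > 1) {
    std::vector<std::thread> threads;
    for (unsigned i = 0; i < n; i++) {
      threads.emplace_back([=]() { work(i); });
    }
    for (auto& t : threads) {
      t.join();
    }
  } else {
    work(0); // single-threaded fallback, as used by the wasm build
  }
}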

src/fasttext.cc

Lines changed: 35 additions & 15 deletions

@@ -263,22 +263,30 @@ void FastText::loadModel(std::istream& in) {
   buildModel();
 }
 
-void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
+std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
   double t = utils::getDuration(start_, std::chrono::steady_clock::now());
   double lr = args_->lr * (1.0 - progress);
   double wst = 0;
 
   int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
 
   if (progress > 0 && t >= 0) {
-    progress = progress * 100;
-    eta = t * (100 - progress) / progress;
+    eta = t * (1 - progress) / progress;
     wst = double(tokenCount_) / t / args_->thread;
   }
 
+  return std::tuple<double, double, int64_t>(wst, lr, eta);
+}
+
+void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
+  double wst;
+  double lr;
+  int64_t eta;
+  std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
+
   log_stream << std::fixed;
   log_stream << "Progress: ";
-  log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
+  log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
   log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
   log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
   log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
@@ -304,7 +312,7 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
   return idx;
 }
 
-void FastText::quantize(const Args& qargs) {
+void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
   if (args_->model != model_name::sup) {
     throw std::invalid_argument(
         "For now we only support quantization of supervised models");

@@ -336,18 +344,16 @@ void FastText::quantize(const Args& qargs) {
       args_->verbose = qargs.verbose;
       auto loss = createLoss(output_);
       model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
-      startThreads();
+      startThreads(callback);
     }
   }
-
   input_ = std::make_shared<QuantMatrix>(
       std::move(*(input.get())), qargs.dsub, qargs.qnorm);
 
   if (args_->qout) {
     output_ = std::make_shared<QuantMatrix>(
         std::move(*(output.get())), 2, qargs.qnorm);
   }
-
   quant_ = true;
   auto loss = createLoss(output_);
   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);

@@ -615,7 +621,7 @@ bool FastText::keepTraining(const int64_t ntokens) const {
   return tokenCount_ < args_->epoch * ntokens && !trainException_;
 }
 
-void FastText::trainThread(int32_t threadId) {
+void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

@@ -624,9 +630,18 @@
   const int64_t ntokens = dict_->ntokens();
   int64_t localTokenCount = 0;
   std::vector<int32_t> line, labels;
+  uint64_t callbackCounter = 0;
   try {
     while (keepTraining(ntokens)) {
       real progress = real(tokenCount_) / (args_->epoch * ntokens);
+      if (callback && ((callbackCounter++ % 64) == 0)) {
+        double wst;
+        double lr;
+        int64_t eta;
+        std::tie<double, double, int64_t>(wst, lr, eta) =
+            progressInfo(progress);
+        callback(progress, loss_, wst, lr, eta);
+      }
       real lr = args_->lr * (1.0 - progress);
       if (args_->model == model_name::sup) {
         localTokenCount += dict_->getLine(ifs, line, labels);

@@ -717,7 +732,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
   return output;
 }
 
-void FastText::train(const Args& args) {
+void FastText::train(const Args& args, const TrainCallback& callback) {
   args_ = std::make_shared<Args>(args);
   dict_ = std::make_shared<Dictionary>(args_);
   if (args_->input == "-") {

@@ -742,7 +757,7 @@ void FastText::train(const Args& args) {
   auto loss = createLoss(output_);
   bool normalizeGradient = (args_->model == model_name::sup);
   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
-  startThreads();
+  startThreads(callback);
 }
 
 void FastText::abort() {

@@ -753,14 +768,19 @@ void FastText::abort() {
   }
 }
 
-void FastText::startThreads() {
+void FastText::startThreads(const TrainCallback& callback) {
   start_ = std::chrono::steady_clock::now();
   tokenCount_ = 0;
   loss_ = -1;
   trainException_ = nullptr;
   std::vector<std::thread> threads;
-  for (int32_t i = 0; i < args_->thread; i++) {
-    threads.push_back(std::thread([=]() { trainThread(i); }));
+  if (args_->thread > 1) {
+    for (int32_t i = 0; i < args_->thread; i++) {
+      threads.push_back(std::thread([=]() { trainThread(i, callback); }));
+    }
+  } else {
+    // webassembly can't instantiate `std::thread`
+    trainThread(0, callback);
   }
   const int64_t ntokens = dict_->ntokens();
   // Same condition as trainThread

@@ -772,7 +792,7 @@
       printInfo(progress, loss_, std::cerr);
     }
   }
-  for (int32_t i = 0; i < args_->thread; i++) {
+  for (int32_t i = 0; i < threads.size(); i++) {
     threads[i].join();
  }
   if (trainException_) {