Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
getNN and getAnalogies functions handle onUnicodeError argument
Browse files Browse the repository at this point in the history
Summary: Fixing the previous pull-requests issues + refactoring

Reviewed By: EdouardGrave

Differential Revision: D20478559

fbshipit-source-id: bc92b40257a74ee548b087740bd81af3886ab1d6
  • Loading branch information
Celebio authored and facebook-github-bot committed Mar 25, 2020
1 parent 3b25b87 commit 5a5b1e6
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 31 deletions.
9 changes: 5 additions & 4 deletions python/fasttext_module/fasttext/FastText.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def get_sentence_vector(self, text):
self.f.getSentenceVector(b, text)
return np.array(b)

def get_nearest_neighbors(self, word, k=10):
return self.f.getNN(word, k)
def get_nearest_neighbors(self, word, k=10, on_unicode_error='strict'):
return self.f.getNN(word, k, on_unicode_error)

def get_analogies(self, wordA, wordB, wordC, k=10):
return self.f.getAnalogies(wordA, wordB, wordC, k)
def get_analogies(self, wordA, wordB, wordC, k=10,
on_unicode_error='strict'):
return self.f.getAnalogies(wordA, wordB, wordC, k, on_unicode_error)

def get_word_id(self, word):
"""
Expand Down
48 changes: 24 additions & 24 deletions python/fasttext_module/fasttext/pybind/fasttext_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
return handle_str;
}

std::vector<std::pair<fasttext::real, py::str>> castToPythonString(
const std::vector<std::pair<fasttext::real, std::string>>& predictions,
const char* onUnicodeError) {
std::vector<std::pair<fasttext::real, py::str>> transformedPredictions;

for (const auto& prediction : predictions) {
transformedPredictions.emplace_back(
prediction.first,
castToPythonString(prediction.second, onUnicodeError));
}

return transformedPredictions;
}

std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
fasttext::FastText& m,
const std::string text,
Expand Down Expand Up @@ -339,16 +353,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
std::vector<std::pair<fasttext::real, std::string>> predictions;
m.predictLine(ioss, predictions, k, threshold);

std::vector<std::pair<fasttext::real, py::str>>
transformedPredictions;

for (const auto& prediction : predictions) {
transformedPredictions.push_back(std::make_pair(
prediction.first,
castToPythonString(prediction.second, onUnicodeError)));
}

return transformedPredictions;
return castToPythonString(predictions, onUnicodeError);
})
.def(
"multilinePredict",
Expand Down Expand Up @@ -427,28 +432,23 @@ PYBIND11_MODULE(fasttext_pybind, m) {
const std::string word) { m.getWordVector(vec, word); })
.def(
"getNN",
[](fasttext::FastText& m, const std::string& word, int32_t k,
[](fasttext::FastText& m,
const std::string& word,
int32_t k,
const char* onUnicodeError) {
std::vector<std::pair<float, std::string>> score_words = m.getNN(
word, k);
std::vector<std::pair<float, py::str>> output_list;
for (uint32_t i = 0; i < score_words.size(); i++) {
float score = score_words[i].first;
py::str word = castToPythonString(
score_words[i].second, onUnicodeError);
std::pair<float, py::str> sw_pair = std::make_pair(score, word);
output_list.push_back(sw_pair);
}

return output_list;
return castToPythonString(m.getNN(word, k), onUnicodeError);
})
.def(
"getAnalogies",
[](fasttext::FastText& m,
const std::string& wordA,
const std::string& wordB,
const std::string& wordC,
int32_t k) { return m.getAnalogies(k, wordA, wordB, wordC); })
int32_t k,
const char* onUnicodeError) {
return castToPythonString(
m.getAnalogies(k, wordA, wordB, wordC), onUnicodeError);
})
.def(
"getSubwords",
[](fasttext::FastText& m,
Expand Down
3 changes: 1 addition & 2 deletions src/autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ void Autotune::train(const Args& autotuneArgs) {
autotuneArgs.getAutotuneMetric(),
autotuneArgs.getAutotuneMetricLabel());

if (bestScore_ == kUnknownBestScore ||
(currentScore > bestScore_)) {
if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
bestTrainArgs = trainArgs;
bestScore_ = currentScore;
strategy_->updateBest(bestTrainArgs);
Expand Down
1 change: 0 additions & 1 deletion src/autotune.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ class Autotune {
TimeoutError() : std::runtime_error("Autotune timed out.") {}
};


public:
Autotune() = delete;
explicit Autotune(const std::shared_ptr<FastText>& fastText);
Expand Down

0 comments on commit 5a5b1e6

Please sign in to comment.