Skip to content

Commit

Permalink
Rewrite TensorFlow EditDistance implementation to be based on true GTL; fixes a bug in the calculation.
Browse files Browse the repository at this point in the history

This fixes the EditDistanceOp.
Change: 136788609
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Oct 21, 2016
1 parent 9c145a8 commit 7b4af07
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
63 changes: 36 additions & 27 deletions tensorflow/core/lib/gtl/edit_distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,48 @@ inline int64 LevenshteinDistance(const gtl::ArraySlice<T>& s,
const int64 s_size = s.size();
const int64 t_size = t.size();

if (s_size == 0) return t_size;
if (t_size == 0) return s_size;
if (s == t) return 0;
if (t_size > s_size) return LevenshteinDistance(t, s, cmp);

// Create work vectors
gtl::InlinedVector<int64, 32> scratch0(t_size + 1);
gtl::InlinedVector<int64, 32> scratch1(t_size + 1);

int64* previous = scratch0.data();
int64* current = scratch1.data();

// Initialize previous row of distances
std::iota(scratch0.begin(), scratch0.end(), 0);
const T* s_data = s.data();
const T* t_data = t.data();

for (int64 i = 0; i < s_size; ++i) {
// Swap current and previous rows for next iteration
std::swap(previous, current);

// Calculate current row distances from previous row
current[0] = i + 1;
if (t_size == 0) return s_size;
if (s == t) return 0;

// Fill in the rest of the row
for (int64 j = 0; j < t_size; ++j) {
const int64 cost = cmp(s[i], t[j]) ? 0 : 1;
current[j + 1] =
std::min(current[j] + 1, // deletion cost
std::min(previous[j + 1] + 1, // insertion cost
previous[j] + cost)); // substitution cost
// Create work vector
gtl::InlinedVector<int64, 32> scratch_holder(t_size);

int64* scratch = scratch_holder.data();

// Special case for i = 0: Distance between empty string and string
// of length j is just j.
for (size_t j = 1; j < t_size; ++j) scratch[j - 1] = j;

for (size_t i = 1; i <= s_size; ++i) {
// Invariant: scratch[j - 1] equals cost(i - 1, j).
int substitution_base_cost = i - 1;
int insertion_cost = i + 1;
for (size_t j = 1; j <= t_size; ++j) {
// Invariants:
// scratch[k - 1] = cost(i, k) for 0 < k < j.
// scratch[k - 1] = cost(i - 1, k) for j <= k <= t_size.
// substitution_base_cost = cost(i - 1, j - 1)
// insertion_cost = cost(i, j - 1)
const int replacement_cost = cmp(s_data[i - 1], t_data[j - 1]) ? 0 : 1;
const int substitution_cost = substitution_base_cost + replacement_cost;
const int deletion_cost = scratch[j - 1] + 1;

// Select the cheapest edit.
const int cheapest = // = cost(i, j)
std::min(deletion_cost, std::min(insertion_cost, substitution_cost));

// Restore invariant for the next iteration of the loop.
substitution_base_cost = scratch[j - 1]; // = cost(i - 1, j)
scratch[j - 1] = cheapest; // = cost(i, j)
insertion_cost = cheapest + 1; // = cost(i, j) + 1
}
}

return current[t_size];
return scratch[t_size - 1];
}

template <typename Container1, typename Container2, typename Cmp>
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/lib/gtl/edit_distance_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class LevenshteinDistanceTest : public ::testing::Test {
std::string grandmother_;
std::string lower_;
std::string upper_;
std::vector<char> ebab_;
std::vector<char> abcd_;

void SetUp() override {
s1_ = "1";
Expand All @@ -48,13 +50,20 @@ class LevenshteinDistanceTest : public ::testing::Test {
grandmother_ = "grandmother";
lower_ = "lower case";
upper_ = "UPPER case";
ebab_ = {'e', 'b', 'a', 'b'};
abcd_ = {'a', 'b', 'c', 'd'};
}
};

// Passes the fixture's empty_ member as both arguments and expects a
// distance of 0. (empty_ is presumably an empty sequence; its initialization
// is not visible in this excerpt.)
TEST_F(LevenshteinDistanceTest, BothEmpty) {
ASSERT_EQ(LevenshteinDistance(empty_, empty_, std::equal_to<char>()), 0);
}

// Checks that the result does not depend on argument order: the distance
// between ebab_ ({'e','b','a','b'}) and abcd_ ({'a','b','c','d'}) must be 3
// whichever sequence is passed first. Both sequences have length 4 and differ
// at three positions (e/a, a/c, b/d).
TEST_F(LevenshteinDistanceTest, Symmetry) {
// ebab_ as the first argument, abcd_ as the second.
ASSERT_EQ(LevenshteinDistance(ebab_, abcd_, std::equal_to<char>()), 3);
// Same pair with the arguments swapped; the expected value is unchanged.
ASSERT_EQ(LevenshteinDistance(abcd_, ebab_, std::equal_to<char>()), 3);
}

TEST_F(LevenshteinDistanceTest, OneEmpty) {
ASSERT_EQ(LevenshteinDistance(s1234_, empty_, std::equal_to<char>()), 4);
ASSERT_EQ(LevenshteinDistance(empty_, s567_, std::equal_to<char>()), 3);
Expand Down

0 comments on commit 7b4af07

Please sign in to comment.