Skip to content

Commit

Permalink
Expose threshold argument of Jaro-Winkler similarity (duckdb#12079)
Browse files Browse the repository at this point in the history
Following up on duckdb#10345, but starting with Jaro-Winkler similarity. This
PR adds an optional third argument to the Jaro and Jaro-Winkler
functions that acts as a "threshold" -- similarities below the threshold
are reported as zero. This was already implemented in the vendored
implementation of Jaro-Winkler, just not exposed to the DuckDB user.

If this is received positively, I'd like to update the vendored
RapidFuzz and use it for all string comparisons, which would allow
exposing this argument for those as well.

**NOTE: I am not great at C++. I expect this will need a lot of
cleanup.**
  • Loading branch information
Mytherin authored Oct 28, 2024
2 parents 895a496 + 5f929c2 commit baf4304
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 32 deletions.
4 changes: 2 additions & 2 deletions extension/core_functions/function_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ static const StaticFunctionDefinition core_functions[] = {
DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun),
DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun),
DUCKDB_SCALAR_FUNCTION(JaccardFun),
DUCKDB_SCALAR_FUNCTION(JaroSimilarityFun),
DUCKDB_SCALAR_FUNCTION(JaroWinklerSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JaroSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JaroWinklerSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun),
DUCKDB_AGGREGATE_FUNCTION(KahanSumFun),
DUCKDB_AGGREGATE_FUNCTION(KurtosisFun),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,20 @@ struct JaccardFun {

struct JaroSimilarityFun {
static constexpr const char *Name = "jaro_similarity";
static constexpr const char *Parameters = "str1,str2";
static constexpr const char *Parameters = "str1,str2,score_cutoff";
static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
static constexpr const char *Example = "jaro_similarity('duck','duckdb')";
static constexpr const char *Example = "jaro_similarity('duck', 'duckdb', 0.5)";

static ScalarFunction GetFunction();
static ScalarFunctionSet GetFunctions();
};

struct JaroWinklerSimilarityFun {
static constexpr const char *Name = "jaro_winkler_similarity";
static constexpr const char *Parameters = "str1,str2";
static constexpr const char *Parameters = "str1,str2,score_cutoff";
static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
static constexpr const char *Example = "jaro_winkler_similarity('duck','duckdb')";
static constexpr const char *Example = "jaro_winkler_similarity('duck', 'duckdb', 0.5)";

static ScalarFunction GetFunction();
static ScalarFunctionSet GetFunctions();
};

struct LeftFun {
Expand Down
12 changes: 6 additions & 6 deletions extension/core_functions/scalar/string/functions.json
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,17 @@
},
{
"name": "jaro_similarity",
"parameters": "str1,str2",
"parameters": "str1,str2,score_cutoff",
"description": "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1",
"example": "jaro_similarity('duck','duckdb')",
"type": "scalar_function"
"example": "jaro_similarity('duck', 'duckdb', 0.5)",
"type": "scalar_function_set"
},
{
"name": "jaro_winkler_similarity",
"parameters": "str1,str2",
"parameters": "str1,str2,score_cutoff",
"description": "The Jaro-Winkler similarity between two strings. Different case is considered different. Returns a number between 0 and 1",
"example": "jaro_winkler_similarity('duck','duckdb')",
"type": "scalar_function"
"example": "jaro_winkler_similarity('duck', 'duckdb', 0.5)",
"type": "scalar_function_set"
},
{
"name": "left",
Expand Down
77 changes: 59 additions & 18 deletions extension/core_functions/scalar/string/jaro_winkler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@

namespace duckdb {

static inline double JaroScalarFunction(const string_t &s1, const string_t &s2) {
static inline double JaroScalarFunction(const string_t &s1, const string_t &s2, const double_t &score_cutoff = 0.0) {
auto s1_begin = s1.GetData();
auto s2_begin = s2.GetData();
return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize());
return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize(),
score_cutoff);
}

static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2) {
static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2,
const double_t &score_cutoff = 0.0) {
auto s1_begin = s1.GetData();
auto s2_begin = s2.GetData();
return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin,
s2_begin + s2.GetSize());
s2_begin + s2.GetSize(), 0.1, score_cutoff);
}

template <class CACHED_SIMILARITY>
static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_t count) {
static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) {
auto val = constant.GetValue(0);
idx_t count = args.size();
if (val.IsNull()) {
auto &result_validity = FlatVector::Validity(result);
result_validity.SetAllInvalid(count);
Expand All @@ -28,26 +31,46 @@ static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_

auto str_val = StringValue::Get(val);
auto cached = CACHED_SIMILARITY(str_val);
UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
});

D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
if (args.ColumnCount() == 2) {
UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
});
} else {
auto score_cutoff = args.data[2];
BinaryExecutor::Execute<string_t, double_t, double>(
other, score_cutoff, result, count, [&](const string_t &other_str, const double_t score_cutoff) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff);
});
}
}

template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION = std::function<double(string_t, string_t)>>
template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION>
static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) {
bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR;
bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR;
if (!(arg0_constant ^ arg1_constant)) {
// We can't optimize by caching one of the two strings
BinaryExecutor::Execute<string_t, string_t, double>(args.data[0], args.data[1], result, args.size(), fun);
return;
D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
if (args.ColumnCount() == 2) {
BinaryExecutor::Execute<string_t, string_t, double>(
args.data[0], args.data[1], result, args.size(),
[&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); });
return;
} else {
TernaryExecutor::Execute<string_t, string_t, double_t, double>(args.data[0], args.data[1], args.data[2],
result, args.size(), fun);
return;
}
}

if (arg0_constant) {
CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args.size());
CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args);
} else {
CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args.size());
CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args);
}
}

Expand All @@ -60,12 +83,30 @@ static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector
JaroWinklerScalarFunction);
}

ScalarFunction JaroSimilarityFun::GetFunction() {
return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
ScalarFunctionSet JaroSimilarityFun::GetFunctions() {
ScalarFunctionSet jaro;

const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
jaro.AddFunction(fun);

fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
JaroFunction);
jaro.AddFunction(fun);
return jaro;
}

ScalarFunction JaroWinklerSimilarityFun::GetFunction() {
return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
ScalarFunctionSet JaroWinklerSimilarityFun::GetFunctions() {
ScalarFunctionSet jaroWinkler;

const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
jaroWinkler.AddFunction(fun);

fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
JaroWinklerFunction);
jaroWinkler.AddFunction(fun);
return jaroWinkler;
}

} // namespace duckdb
21 changes: 21 additions & 0 deletions test/sql/function/string/test_jaro_winkler.test
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,27 @@ select jaro_winkler_similarity('PENNSYLVANIA', 'PENNCISYLVNIA')
----
0.8980186480186481

# test score cutoff
query T
select jaro_winkler_similarity('CRATE', 'TRACE', 0.7)
----
0.733333

query T
select jaro_winkler_similarity('CRATE', 'TRACE', 0.75)
----
0.0

query T
select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.9)
----
0.9938

query T
select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.995)
----
0.0

# test with table just in case
statement ok
create table test as select '0000' || range::varchar s from range(10000);
Expand Down

0 comments on commit baf4304

Please sign in to comment.