Skip to content

Commit

Permalink
Add Hash operator (only supports string inputs for now)
Browse files Browse the repository at this point in the history
Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
tmostak authored and misiugodfrey committed Aug 26, 2024
1 parent c5631a9 commit 47d94e7
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 6 deletions.
5 changes: 5 additions & 0 deletions Analyzer/Analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3816,6 +3816,11 @@ std::shared_ptr<Analyzer::Expr> LevenshteinDistanceStringOper::deep_copy() const
std::dynamic_pointer_cast<Analyzer::StringOper>(StringOper::deep_copy()));
}

std::shared_ptr<Analyzer::Expr> HashStringOper::deep_copy() const {
return makeExpr<Analyzer::HashStringOper>(
std::dynamic_pointer_cast<Analyzer::StringOper>(StringOper::deep_copy()));
}

std::shared_ptr<Analyzer::Expr> FunctionOper::deep_copy() const {
std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
for (size_t i = 0; i < getArity(); ++i) {
Expand Down
31 changes: 31 additions & 0 deletions Analyzer/Analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2703,6 +2703,37 @@ class LevenshteinDistanceStringOper : public StringOper {
const std::vector<std::shared_ptr<Analyzer::Expr>>& operands);
};

class HashStringOper : public StringOper {
public:
HashStringOper(const std::shared_ptr<Analyzer::Expr>& operand)
: StringOper(SqlStringOpKind::HASH,
SQLTypeInfo(kBIGINT),
{operand},
getMinArgs(),
getExpectedTypeFamilies(),
getArgNames()) {}

HashStringOper(const std::vector<std::shared_ptr<Analyzer::Expr>>& operands)
: StringOper(SqlStringOpKind::HASH,
SQLTypeInfo(kBIGINT),
operands,
getMinArgs(),
getExpectedTypeFamilies(),
getArgNames()) {}

HashStringOper(const std::shared_ptr<Analyzer::StringOper>& string_oper)
: StringOper(string_oper) {}

std::shared_ptr<Analyzer::Expr> deep_copy() const override;

size_t getMinArgs() const override { return 1UL; }

std::vector<OperandTypeFamily> getExpectedTypeFamilies() const override {
return {OperandTypeFamily::STRING_FAMILY};
}
std::vector<std::string> getArgNames() const override { return {"operand"}; }
};

class FunctionOper : public Expr {
public:
FunctionOper(const SQLTypeInfo& ti,
Expand Down
1 change: 1 addition & 0 deletions Parser/ReservedKeywords.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ static std::set<std::string> reserved_keywords{
"GROUP",
"GROUPING",
"GROUPS",
"HASH",
"HAVING",
"HOLD",
"HOUR",
Expand Down
5 changes: 4 additions & 1 deletion QueryEngine/RelAlgTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,8 @@ std::shared_ptr<Analyzer::Expr> RelAlgTranslator::translateStringOper(
return makeExpr<Analyzer::JarowinklerSimilarityStringOper>(args);
case SqlStringOpKind::LEVENSHTEIN_DISTANCE:
return makeExpr<Analyzer::LevenshteinDistanceStringOper>(args);
case SqlStringOpKind::HASH:
return makeExpr<Analyzer::HashStringOper>(args);
case SqlStringOpKind::URL_ENCODE:
return makeExpr<Analyzer::UrlEncodeStringOper>(args);
case SqlStringOpKind::URL_DECODE:
Expand Down Expand Up @@ -1815,7 +1817,8 @@ std::shared_ptr<Analyzer::Expr> RelAlgTranslator::translateFunction(
"TRY_CAST"sv,
"POSITION"sv,
"JAROWINKLER_SIMILARITY"sv,
"LEVENSHTEIN_DISTANCE"sv)) {
"LEVENSHTEIN_DISTANCE"sv,
"HASH"sv)) {
return translateStringOper(rex_function);
}
if (func_resolve(rex_function->getName(), "CARDINALITY"sv, "ARRAY_LENGTH"sv)) {
Expand Down
7 changes: 7 additions & 0 deletions Shared/sqldefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ enum class SqlStringOpKind {
POSITION, // string-to-numeric
JAROWINKLER_SIMILARITY, // string-to-numeric
LEVENSHTEIN_DISTANCE, // string-to-numeric
HASH, // string-to-numeric
INVALID
};

Expand Down Expand Up @@ -446,6 +447,8 @@ inline std::ostream& operator<<(std::ostream& os, const SqlStringOpKind kind) {
return os << "JAROWINKLER_SIMILARITY";
case SqlStringOpKind::LEVENSHTEIN_DISTANCE:
return os << "LEVENSHTEIN_DISTANCE";
case SqlStringOpKind::HASH:
return os << "HASH";
case SqlStringOpKind::INVALID:
return os << "INVALID";
}
Expand Down Expand Up @@ -539,6 +542,9 @@ inline SqlStringOpKind name_to_string_op_kind(const std::string& func_name) {
if (func_name == "LEVENSHTEIN_DISTANCE") {
return SqlStringOpKind::LEVENSHTEIN_DISTANCE;
}
if (func_name == "HASH") {
return SqlStringOpKind::HASH;
}
LOG(FATAL) << "Invalid string function " << func_name << ".";
return SqlStringOpKind::INVALID;
}
Expand All @@ -550,6 +556,7 @@ inline bool string_op_returns_string(const SqlStringOpKind kind) {
case SqlStringOpKind::JAROWINKLER_SIMILARITY:
case SqlStringOpKind::LEVENSHTEIN_DISTANCE:
case SqlStringOpKind::REGEXP_COUNT:
case SqlStringOpKind::HASH:
return false;
default:
return true;
Expand Down
24 changes: 24 additions & 0 deletions StringOps/StringOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,26 @@ Datum LevenshteinDistance::numericEval(const std::string_view str1,
return return_datum;
}

NullableStrType Hash::operator()(const std::string& str) const {
UNREACHABLE() << "Invalid string output for Hash";
return {};
}

Datum Hash::numericEval(const std::string_view str) const {
if (str.empty()) {
return NullDatum(return_ti_);
} else {
uint64_t str_hash = 1;
// rely on fact that unsigned overflow is defined and wraps
for (size_t i = 0; i < str.size(); ++i) {
str_hash = str_hash * 997u + static_cast<unsigned char>(str[i]);
}
Datum return_datum;
return_datum.bigintval = static_cast<int64_t>(str_hash);
return return_datum;
}
}

NullableStrType Lower::operator()(const std::string& str) const {
std::string output_str(str);
std::transform(
Expand Down Expand Up @@ -1239,6 +1259,10 @@ std::unique_ptr<const StringOp> gen_string_op(const StringOpInfo& string_op_info
return std::make_unique<const LevenshteinDistance>(var_string_optional_literal);
}
}
case SqlStringOpKind::HASH: {
CHECK_EQ(num_non_variable_literals, 0UL);
return std::make_unique<const Hash>(var_string_optional_literal);
}
default:
UNREACHABLE();
return {};
Expand Down
9 changes: 9 additions & 0 deletions StringOps/StringOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ struct LevenshteinDistance : public StringOp {
const std::string str_literal_;
};

struct Hash : public StringOp {
public:
Hash(const std::optional<std::string>& var_str_optional_literal)
: StringOp(SqlStringOpKind::HASH, SQLTypeInfo(kBIGINT), var_str_optional_literal) {}

NullableStrType operator()(const std::string& str) const override;
Datum numericEval(const std::string_view str) const override;
};

struct Lower : public StringOp {
Lower(const std::optional<std::string>& var_str_optional_literal)
: StringOp(SqlStringOpKind::LOWER, var_str_optional_literal) {}
Expand Down
97 changes: 92 additions & 5 deletions Tests/StringFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,11 @@ class StringFunctionTest : public TestHelpers::TbbPrivateServerKiller {
insert into string_function_test_people values(3, 'JOHN', 'Wilson', 'John WILSON', 20, 'cA', '555-614-9814', null, 'What is the sound of one hand clapping?', 'JOHN.WILSON@geops.net');
insert into string_function_test_people values(4, 'Sue', 'Smith', 'Sue SMITH', 25, 'CA', '555-614-2282', null, 'Nothing exists entirely alone. Everything is always in relation to everything else.', 'Find me at sue4tw@example.com, or reach me at sue.smith@example.com. I''d love to hear from you!');
drop table if exists string_function_test_countries;
create table string_function_test_countries(id int, code text, arrow_code text, name text, short_name text encoding none, capital text, largest_city text encoding none, lang text encoding none, json_data_none text encoding none);
insert into string_function_test_countries values(1, 'US', '>>US<<', 'United States', null, 'Washington', 'New York City', 'en', '{"capital": "Washington D.C.", "pop": 329500000, "independence_day": "1776-07-04", "has_prime_minister": false, "prime_minister": null, "factoids": {"gdp_per_cap_2015_2020": [56863, 58021, 60110, 63064, 65280, 63544], "Last 3 leaders": ["Barack Obama", "Donald Trump", "Joseph Biden"], "most_valuable_crop": "corn"}}');
insert into string_function_test_countries values(2, 'ca', '>>CA<<', 'Canada', 'Canada', 'Ottawa', 'TORONTO', 'EN', '{"capital": "Toronto", "pop": 38010000, "independence_day": "07/01/1867", "exchange_rate_usd": "0.78125", "has_prime_minister": true, "prime_minister": "Justin Trudeau", "factoids": {"gdp_per_cap_2015_2020": [43596, 42316, 45129, 46454, 46327, 43242], "Last 3 leaders": ["Paul Martin", "Stephen Harper", "Justin Trudeau"], "most valuable crop": "wheat"}}');
insert into string_function_test_countries values(3, 'Gb', '>>GB<<', 'United Kingdom', 'UK', 'London', 'LONDON', 'en', '{"capital": "London", "pop": 67220000, "independence_day": "N/A", "exchange_rate_usd": 1.21875, "prime_minister": "Boris Johnson", "has_prime_minister": true, "factoids": {"gdp_per_cap_2015_2020": [45039, 41048, 40306, 42996, 42354, 40285], "most valuable crop": "wheat"}}');
insert into string_function_test_countries values(4, 'dE', '>>DE<<', 'Germany', 'Germany', 'Berlin', 'Berlin', 'de', '{"capital":"Berlin", "independence_day": "1990-10-03", "exchange_rate_usd": 1.015625, "has_prime_minister": false, "prime_minister": null, "factoids": {"gdp_per_cap_2015_2020": [41103, 42136, 44453, 47811, 46468, 45724], "most valuable crop": "wheat"}}');
create table string_function_test_countries(id int, code text, arrow_code text, name text, short_name text encoding none, capital text, capital_none text encoding none, largest_city text encoding none, lang text encoding none, json_data_none text encoding none);
insert into string_function_test_countries values(1, 'US', '>>US<<', 'United States', null, 'Washington', 'Washington', 'New York City', 'en', '{"capital": "Washington D.C.", "pop": 329500000, "independence_day": "1776-07-04", "has_prime_minister": false, "prime_minister": null, "factoids": {"gdp_per_cap_2015_2020": [56863, 58021, 60110, 63064, 65280, 63544], "Last 3 leaders": ["Barack Obama", "Donald Trump", "Joseph Biden"], "most_valuable_crop": "corn"}}');
insert into string_function_test_countries values(2, 'ca', '>>CA<<', 'Canada', 'Canada', 'Ottawa', 'Ottawa', 'TORONTO', 'EN', '{"capital": "Toronto", "pop": 38010000, "independence_day": "07/01/1867", "exchange_rate_usd": "0.78125", "has_prime_minister": true, "prime_minister": "Justin Trudeau", "factoids": {"gdp_per_cap_2015_2020": [43596, 42316, 45129, 46454, 46327, 43242], "Last 3 leaders": ["Paul Martin", "Stephen Harper", "Justin Trudeau"], "most valuable crop": "wheat"}}');
insert into string_function_test_countries values(3, 'Gb', '>>GB<<', 'United Kingdom', 'UK', 'London', 'London', 'LONDON', 'en', '{"capital": "London", "pop": 67220000, "independence_day": "N/A", "exchange_rate_usd": 1.21875, "prime_minister": "Boris Johnson", "has_prime_minister": true, "factoids": {"gdp_per_cap_2015_2020": [45039, 41048, 40306, 42996, 42354, 40285], "most valuable crop": "wheat"}}');
insert into string_function_test_countries values(4, 'dE', '>>DE<<', 'Germany', 'Germany', 'Berlin', 'Berlin', 'Berlin', 'de', '{"capital":"Berlin", "independence_day": "1990-10-03", "exchange_rate_usd": 1.015625, "has_prime_minister": false, "prime_minister": null, "factoids": {"gdp_per_cap_2015_2020": [41103, 42136, 44453, 47811, 46468, 45724], "most valuable crop": "wheat"}}');
drop table if exists numeric_to_string_test;
create table numeric_to_string_test(b boolean, ti tinyint, si smallint, i int, bi bigint, flt float, dbl double, dec_5_2 decimal(5, 2), dec_18_10 decimal(18, 10), dt date, ts_0 timestamp(0), ts_3 timestamp(3), tm time, b_str text, ti_str text, si_str text, i_str text, bi_str text, flt_str text, dbl_str text, dec_5_2_str text, dec_18_10_str text, dt_str text, ts_0_str text, ts_3_str text, tm_str text) with (fragment_size=2);
insert into numeric_to_string_test values (true, 21, 21, 21, 21, 1.25, 1.25, 1.25, 1.25, '2013-09-10', '2013-09-10 12:43:23', '2013-09-10 12:43:23.123', '12:43:23', 'true', '21', '21', '21', '21', '1.250000', '1.250000', ' 1.25', ' 1.2500000000', '2013-09-10', '2013-09-10 12:43:23', '2013-09-10 12:43:23.123', '12:43:23');
Expand Down Expand Up @@ -2385,6 +2385,93 @@ TEST_F(StringFunctionTest, LevenshteinDistance) {
}
}

TEST_F(StringFunctionTest, Hash) {
for (auto dt : {ExecutorDeviceType::CPU, ExecutorDeviceType::GPU}) {
SKIP_NO_GPU();
{
// Literal hash
auto result_set = sql("select hash('hi');", dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{{int64_t(1097802)}};
compare_result_set(expected_result_set, result_set);
}
{
// Literal null
auto result_set = sql("select coalesce(hash(CAST(NULL AS TEXT)), 0);", dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{{int64_t(0)}};
compare_result_set(expected_result_set, result_set);
}
{
// Dictionary-encoded text column
auto result_set = sql(
"select hash(capital) from string_function_test_countries order by id;", dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{
{int64_t(5703505280371710991)},
{int64_t(1060071279222666409)},
{int64_t(1057111063818803959)},
{int64_t(1047250289947889561)}};
compare_result_set(expected_result_set, result_set);
}
{
// None-encoded text column
auto result_set = sql(
"select hash(capital_none) from string_function_test_countries order by id;",
dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{
{int64_t(5703505280371710991)},
{int64_t(1060071279222666409)},
{int64_t(1057111063818803959)},
{int64_t(1047250289947889561)}};
compare_result_set(expected_result_set, result_set);
}
{
// Dictionary-encoded text column with nulls
auto result_set =
sql("select coalesce(hash(zip_plus_4), 0) from string_function_test_people "
"order by id;",
dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{
{int64_t(6345224789068548647)},
{int64_t(-3868673234647279706)},
{int64_t(0)},
{int64_t(0)}};
compare_result_set(expected_result_set, result_set);
}
{
// None-encoded text column with nulls
auto result_set =
sql("select coalesce(hash(short_name), 0) from string_function_test_countries "
"order by id;",
dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{
{int64_t(0)},
{int64_t(1048231423487679005)},
{int64_t(1078829)},
{int64_t(-2445200816347761128)}};
compare_result_set(expected_result_set, result_set);
}
{
// Hash comparison
auto result_set =
sql("select count(*) from string_function_test_countries where "
"hash(capital) = hash(capital_none);",
dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{{int64_t(4)}};
compare_result_set(expected_result_set, result_set);
}
{
auto result_set =
sql("select hash(lower(first_name)), any_value(lower(first_name)), count(*) "
"from string_function_test_people group by hash(lower(first_name)) order "
"by count(*) desc;",
dt);
std::vector<std::vector<ScalarTargetValue>> expected_result_set{
{int64_t(1093213190016), "john", int64_t(3)},
{int64_t(1105454758), "sue", int64_t(1)}};
compare_result_set(expected_result_set, result_set);
}
}
}

TEST_F(StringFunctionTest, NullLiteralTest) {
for (auto dt : {ExecutorDeviceType::CPU, ExecutorDeviceType::GPU}) {
SKIP_NO_GPU();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.mapd.parser.extension.ddl.SqlLeadLag;
import com.mapd.parser.extension.ddl.SqlNthValueInFrame;
import com.mapd.parser.server.ExtensionFunction;
import com.mapd.parser.server.ExtensionFunction.ExtArgumentType;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.rel.metadata.RelColumnMapping;
Expand Down Expand Up @@ -238,6 +239,7 @@ public void addUDF(final Map<String, ExtensionFunction> extSigs) {
addOperator(new UrlDecode());
addOperator(new JarowinklerSimilarity());
addOperator(new LevenshteinDistance());
addOperator(new Hash());
addOperator(new Likely());
addOperator(new Unlikely());
addOperator(new Sign());
Expand Down Expand Up @@ -1746,6 +1748,35 @@ public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec)
}
}

public static class Hash extends SqlFunction {
public Hash() {
super("HASH",
SqlKind.OTHER_FUNCTION,
null,
null,
OperandTypes.family(getSignatureFamilies()),
SqlFunctionCategory.SYSTEM);
}

private static java.util.List<SqlTypeFamily> getSignatureFamilies() {
java.util.ArrayList<SqlTypeFamily> families =
new java.util.ArrayList<SqlTypeFamily>();
// Todo(todd): Support any input type for HASH function
// families.add(SqlTypeFamily.ANY);
families.add(SqlTypeFamily.STRING);
return families;
}

@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
assert opBinding.getOperandCount() == 1;
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT),
opBinding.getOperandType(0).isNullable());
}
}

public static class Likely extends SqlFunction {
public Likely() {
super("LIKELY",
Expand Down

0 comments on commit 47d94e7

Please sign in to comment.