Skip to content

Commit

Permalink
Support REGEXP_COUNT string function and fix string function null literal issues
Browse files Browse the repository at this point in the history

Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
tmostak authored and misiugodfrey committed Aug 26, 2024
1 parent 06647d7 commit 7057d1f
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 18 deletions.
5 changes: 5 additions & 0 deletions Analyzer/Analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3766,6 +3766,11 @@ std::shared_ptr<Analyzer::Expr> RegexpSubstrStringOper::deep_copy() const {
std::dynamic_pointer_cast<Analyzer::StringOper>(StringOper::deep_copy()));
}

std::shared_ptr<Analyzer::Expr> RegexpCountStringOper::deep_copy() const {
return makeExpr<Analyzer::RegexpCountStringOper>(
std::dynamic_pointer_cast<Analyzer::StringOper>(StringOper::deep_copy()));
}

std::shared_ptr<Analyzer::Expr> JsonValueStringOper::deep_copy() const {
return makeExpr<Analyzer::JsonValueStringOper>(
std::dynamic_pointer_cast<Analyzer::StringOper>(StringOper::deep_copy()));
Expand Down
40 changes: 40 additions & 0 deletions Analyzer/Analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2359,6 +2359,46 @@ class RegexpSubstrStringOper : public StringOper {
"sub-match group index"};
}
};

// Analyzer expression node for the REGEXP_COUNT string function.
// The node is typed as BIGINT (not a string) and carries four operands:
// the input string, the regex pattern, a start position and regex parameters.
// NOTE(review): the constructors call getMinArgs()/getExpectedTypeFamilies()/
// getArgNames() inside the base-class initializer, i.e. before StringOper is
// constructed. They only return constants here, but confirm this matches the
// pattern used by the sibling *StringOper classes.
class RegexpCountStringOper : public StringOper {
public:
// Builds the node from the four individual operand expressions.
RegexpCountStringOper(const std::shared_ptr<Analyzer::Expr>& operand,
const std::shared_ptr<Analyzer::Expr>& regex_pattern,
const std::shared_ptr<Analyzer::Expr>& start_pos,
const std::shared_ptr<Analyzer::Expr>& regex_params)
: StringOper(SqlStringOpKind::REGEXP_COUNT,
SQLTypeInfo(kBIGINT),
{operand, regex_pattern, start_pos, regex_params},
getMinArgs(),
getExpectedTypeFamilies(),
getArgNames()) {}

// Builds the node from an already-assembled operand list.
RegexpCountStringOper(const std::vector<std::shared_ptr<Analyzer::Expr>>& operands)
: StringOper(SqlStringOpKind::REGEXP_COUNT,
SQLTypeInfo(kBIGINT),
operands,
getMinArgs(),
getExpectedTypeFamilies(),
getArgNames()) {}

// Wraps an existing StringOper; this is the constructor deep_copy() uses.
RegexpCountStringOper(const std::shared_ptr<Analyzer::StringOper>& string_oper)
: StringOper(string_oper) {}

// Defined out of line; returns a copy of this node as a RegexpCountStringOper.
std::shared_ptr<Analyzer::Expr> deep_copy() const override;

// Argument count handed to StringOper as the minimum; equals the number of
// entries in getExpectedTypeFamilies() and getArgNames() below.
size_t getMinArgs() const override { return 4UL; }

// Expected type family of each operand, in positional order.
std::vector<OperandTypeFamily> getExpectedTypeFamilies() const override {
return {OperandTypeFamily::STRING_FAMILY,
OperandTypeFamily::STRING_FAMILY,
OperandTypeFamily::INT_FAMILY,
OperandTypeFamily::STRING_FAMILY};
}
// Human-readable operand names, in positional order (presumably used in
// validation error messages by StringOper -- confirm).
std::vector<std::string> getArgNames() const override {
return {"operand", "regex pattern", "start position", "regex parameters"};
}
};

class JsonValueStringOper : public StringOper {
public:
JsonValueStringOper(const std::shared_ptr<Analyzer::Expr>& operand,
Expand Down
1 change: 1 addition & 0 deletions Parser/ReservedKeywords.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ static std::set<std::string> reserved_keywords{
"REF",
"REFERENCES",
"REFERENCING",
"REGEXP_COUNT",
"REGEXP_REPLACE",
"REGEXP_SUBSTR",
"REGEXP_MATCH"
Expand Down
3 changes: 3 additions & 0 deletions QueryEngine/RelAlgTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,8 @@ std::shared_ptr<Analyzer::Expr> RelAlgTranslator::translateStringOper(
return makeExpr<Analyzer::RegexpReplaceStringOper>(args);
case SqlStringOpKind::REGEXP_SUBSTR:
return makeExpr<Analyzer::RegexpSubstrStringOper>(args);
case SqlStringOpKind::REGEXP_COUNT:
return makeExpr<Analyzer::RegexpCountStringOper>(args);
case SqlStringOpKind::JSON_VALUE:
return makeExpr<Analyzer::JsonValueStringOper>(args);
case SqlStringOpKind::BASE64_ENCODE:
Expand Down Expand Up @@ -1804,6 +1806,7 @@ std::shared_ptr<Analyzer::Expr> RelAlgTranslator::translateFunction(
"REGEXP_REPLACE"sv,
"REGEXP_SUBSTR"sv,
"REGEXP_MATCH"sv,
"REGEXP_COUNT"sv,
"JSON_VALUE"sv,
"BASE64_ENCODE"sv,
"BASE64_DECODE"sv,
Expand Down
9 changes: 9 additions & 0 deletions Shared/sqldefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ enum class SqlStringOpKind {
/* 6 args */
REGEXP_REPLACE,
REGEXP_SUBSTR,
REGEXP_COUNT,
JSON_VALUE,
BASE64_ENCODE,
BASE64_DECODE,
Expand Down Expand Up @@ -425,6 +426,8 @@ inline std::ostream& operator<<(std::ostream& os, const SqlStringOpKind kind) {
return os << "REGEXP_REPLACE";
case SqlStringOpKind::REGEXP_SUBSTR:
return os << "REGEXP_SUBSTR";
case SqlStringOpKind::REGEXP_COUNT:
return os << "REGEXP_COUNT";
case SqlStringOpKind::JSON_VALUE:
return os << "JSON_VALUE";
case SqlStringOpKind::BASE64_ENCODE:
Expand Down Expand Up @@ -506,6 +509,9 @@ inline SqlStringOpKind name_to_string_op_kind(const std::string& func_name) {
if (func_name == "REGEXP_MATCH") {
return SqlStringOpKind::REGEXP_SUBSTR;
}
if (func_name == "REGEXP_COUNT") {
return SqlStringOpKind::REGEXP_COUNT;
}
if (func_name == "JSON_VALUE") {
return SqlStringOpKind::JSON_VALUE;
}
Expand Down Expand Up @@ -541,6 +547,9 @@ inline bool string_op_returns_string(const SqlStringOpKind kind) {
switch (kind) {
case SqlStringOpKind::TRY_STRING_CAST:
case SqlStringOpKind::POSITION:
case SqlStringOpKind::JAROWINKLER_SIMILARITY:
case SqlStringOpKind::LEVENSHTEIN_DISTANCE:
case SqlStringOpKind::REGEXP_COUNT:
return false;
default:
return true;
Expand Down
46 changes: 40 additions & 6 deletions StringOps/StringOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,31 @@ std::pair<bool, int64_t> RegexpSubstr::set_sub_match_info(
true, sub_match_group_idx > 0L ? sub_match_group_idx - 1 : sub_match_group_idx);
}

// REGEXP_COUNT produces a BIGINT count, so the string-returning overload is
// not a valid way to evaluate it; results come from numericEval() instead.
NullableStrType RegexpCount::operator()(const std::string& str) const {
  UNREACHABLE() << "Invalid string output for RegexpCount";
  // Null string, present only so the function has a return value.
  return NullableStrType();
}

// Counts the matches of regex_pattern_ found in str_view, scanning from the
// configured start position to the end of the string.
//
// An empty input yields the null value for the return type. A negative
// start_pos_ is interpreted relative to the end of the string; the resulting
// offset is clamped into [0, length] before the scan begins.
Datum RegexpCount::numericEval(const std::string_view str_view) const {
  if (str_view.empty()) {
    return NullDatum(return_ti_);
  }

  const int64_t length = str_view.size();

  // Resolve the (possibly negative) start position into a valid offset.
  int64_t offset = start_pos_;
  if (offset < 0) {
    offset += length;
  }
  offset = std::clamp(offset, int64_t(0), length);

  const char* const scan_begin = str_view.data() + offset;
  const char* const scan_end = str_view.data() + length;

  // A default-constructed regex iterator is the end-of-sequence sentinel.
  const boost::cregex_iterator first_match(scan_begin, scan_end, regex_pattern_);
  const boost::cregex_iterator no_more_matches;

  Datum result;
  result.bigintval = std::distance(first_match, no_more_matches);
  return result;
}

// json_path must start with "lax $", "strict $" or "$" (case-insensitive).
JsonValue::JsonParseMode JsonValue::parse_json_parse_mode(std::string_view json_path) {
size_t const string_pos = json_path.find('$');
Expand Down Expand Up @@ -997,7 +1022,8 @@ std::unique_ptr<const StringOp> gen_string_op(const StringOpInfo& string_op_info
const auto& return_ti = string_op_info.getReturnType();

if (string_op_info.hasNullLiteralArg()) {
return std::make_unique<const NullOp>(var_string_optional_literal, op_kind);
return std::make_unique<const NullOp>(
return_ti, var_string_optional_literal, op_kind);
}

const auto num_non_variable_literals = string_op_info.numNonVariableLiterals();
Expand Down Expand Up @@ -1139,6 +1165,17 @@ std::unique_ptr<const StringOp> gen_string_op(const StringOpInfo& string_op_info
regex_params_literal,
sub_match_idx_literal);
}
case SqlStringOpKind::REGEXP_COUNT: {
CHECK_GE(num_non_variable_literals, 3UL);
CHECK_LE(num_non_variable_literals, 3UL);
const auto pattern_literal = string_op_info.getStringLiteral(1);
const auto start_pos_literal = string_op_info.getIntLiteral(2);
const auto regex_params_literal = string_op_info.getStringLiteral(3);
return std::make_unique<const RegexpCount>(var_string_optional_literal,
pattern_literal,
start_pos_literal,
regex_params_literal);
}
case SqlStringOpKind::JSON_VALUE: {
CHECK_EQ(num_non_variable_literals, 1UL);
const auto json_path_literal = string_op_info.getStringLiteral(1);
Expand Down Expand Up @@ -1202,13 +1239,10 @@ std::unique_ptr<const StringOp> gen_string_op(const StringOpInfo& string_op_info
return std::make_unique<const LevenshteinDistance>(var_string_optional_literal);
}
}
default: {
default:
UNREACHABLE();
return std::make_unique<NullOp>(var_string_optional_literal, op_kind);
}
return {};
}
// Make compiler happy
return std::make_unique<NullOp>(var_string_optional_literal, op_kind);
}

std::pair<std::string, bool /* is null */> apply_string_op_to_literals(
Expand Down
41 changes: 32 additions & 9 deletions StringOps/StringOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct StringOp {
StringOp(const SqlStringOpKind op_kind,
const std::optional<std::string>& var_str_optional_literal)
: op_kind_(op_kind)
, return_ti_(SQLTypeInfo(kTEXT))
, return_ti_(SQLTypeInfo(kTEXT, false, kENCODING_DICT))
, has_var_str_literal_(var_str_optional_literal.has_value())
, var_str_literal_(!var_str_optional_literal.has_value()
? NullableStrType()
Expand Down Expand Up @@ -80,10 +80,10 @@ struct StringOp {
}

virtual NullableStrType operator()() const {
CHECK(hasVarStringLiteral());
if (var_str_literal_.is_null) {
return var_str_literal_;
}
CHECK(hasVarStringLiteral());
return operator()(var_str_literal_.str);
}

Expand All @@ -101,10 +101,10 @@ struct StringOp {
}

virtual Datum numericEval() const {
CHECK(hasVarStringLiteral());
if (var_str_literal_.is_null) {
return NullDatum(return_ti_);
}
CHECK(hasVarStringLiteral());
return numericEval(var_str_literal_.str);
}

Expand Down Expand Up @@ -439,8 +439,7 @@ struct RegexpSubstr : public StringOp {
const std::string& regex_params,
const int64_t sub_match_group_idx)
: StringOp(SqlStringOpKind::REGEXP_SUBSTR, var_str_optional_literal)
, regex_pattern_str_(
regex_pattern) // for toString() as std::regex does not have str() method
, regex_pattern_str_(regex_pattern)
, regex_pattern_(
StringOp::generateRegex("REGEXP_SUBSTR", regex_pattern, regex_params, true))
, start_pos_(start_pos > 0 ? start_pos - 1 : start_pos)
Expand Down Expand Up @@ -472,8 +471,7 @@ struct RegexpReplace : public StringOp {
const int64_t occurrence,
const std::string& regex_params)
: StringOp(SqlStringOpKind::REGEXP_REPLACE, var_str_optional_literal)
, regex_pattern_str_(
regex_pattern) // for toString() as std::regex does not have str() method
, regex_pattern_str_(regex_pattern)
, regex_pattern_(
StringOp::generateRegex("REGEXP_REPLACE", regex_pattern, regex_params, false))
, replacement_(replacement)
Expand All @@ -495,6 +493,29 @@ struct RegexpReplace : public StringOp {
const int64_t occurrence_;
};

// String operator implementing REGEXP_COUNT: reports, as a BIGINT, how many
// times a regex pattern matches within the input string.
// The result is produced by numericEval(); the string-returning operator()
// is overridden only because it must be and is not a valid evaluation path.
struct RegexpCount : public StringOper {

// We currently do not allow strict mode JSON parsing per the SQL standard, as
// 1) We can't throw run-time errors in the case that the string operator
// is evaluated in an actual kernel, which is the case for none-encoded text
Expand Down Expand Up @@ -592,9 +613,11 @@ struct UrlDecode : public StringOp {
};

struct NullOp : public StringOp {
NullOp(const std::optional<std::string>& var_str_optional_literal,
NullOp(const SQLTypeInfo& return_ti,
const std::optional<std::string>& var_str_optional_literal,
const SqlStringOpKind op_kind)
: StringOp(SqlStringOpKind::INVALID, var_str_optional_literal), op_kind_(op_kind) {}
: StringOp(SqlStringOpKind::INVALID, return_ti, var_str_optional_literal)
, op_kind_(op_kind) {}

NullableStrType operator()(const std::string& str) const override {
return NullableStrType(); // null string
Expand Down
Loading

0 comments on commit 7057d1f

Please sign in to comment.