Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support udf like with 3 arguments #212

Merged
merged 3 commits into from
Sep 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
//{tipb::ScalarFuncSig::IsIPv6, "cast"},
//{tipb::ScalarFuncSig::UUID, "cast"},

//{tipb::ScalarFuncSig::LikeSig, "cast"},
{tipb::ScalarFuncSig::LikeSig, "like3Args"},
//{tipb::ScalarFuncSig::RegexpBinarySig, "cast"},
//{tipb::ScalarFuncSig::RegexpSig, "cast"},

Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Functions/FunctionsStringSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,10 @@ struct NameLike
{
static constexpr auto name = "like";
};
struct NameLike3Args
{
static constexpr auto name = "like3Args";
};
struct NameNotLike
{
static constexpr auto name = "notLike";
Expand Down Expand Up @@ -1058,6 +1062,7 @@ using FunctionPositionCaseInsensitiveUTF8

using FunctionMatch = FunctionsStringSearch<MatchImpl<false>, NameMatch>;
using FunctionLike = FunctionsStringSearch<MatchImpl<true>, NameLike>;
using FunctionLike3Args = FunctionsStringSearch<MatchImpl<true>, NameLike3Args, 3>;
using FunctionNotLike = FunctionsStringSearch<MatchImpl<true, true>, NameNotLike>;
using FunctionExtract = FunctionsStringSearchToString<ExtractImpl, NameExtract>;
using FunctionReplaceOne = FunctionStringReplace<ReplaceStringImpl<true>, NameReplaceOne>;
Expand All @@ -1078,6 +1083,7 @@ void registerFunctionsStringSearch(FunctionFactory & factory)
factory.registerFunction<FunctionPositionCaseInsensitiveUTF8>();
factory.registerFunction<FunctionMatch>();
factory.registerFunction<FunctionLike>();
factory.registerFunction<FunctionLike3Args>();
factory.registerFunction<FunctionNotLike>();
factory.registerFunction<FunctionExtract>();
}
Expand Down
102 changes: 98 additions & 4 deletions dbms/src/Functions/FunctionsStringSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ namespace DB
* Warning! At this point, the arguments needle, pattern, n, replacement must be constants.
*/

static const UInt8 CH_ESCAPE_CHAR = '\\';

template <typename Impl, typename Name>
template <typename Impl, typename Name, size_t num_args = 2>
class FunctionsStringSearch : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto has_3_args = (num_args == 3);
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionsStringSearch>();
Expand All @@ -56,7 +58,7 @@ class FunctionsStringSearch : public IFunction

size_t getNumberOfArguments() const override
{
return 2;
return num_args;
}

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
Expand All @@ -68,10 +70,60 @@ class FunctionsStringSearch : public IFunction
if (!arguments[1]->isString())
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_3_args && !arguments[2]->isInteger())
throw Exception(
"Illegal type " + arguments[2]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

return std::make_shared<DataTypeNumber<typename Impl::ResultType>>();
}

// replace the escape_char in orig_string with '\\'
// this function does not check the validation of the orig_string
// for example, for string "abcd" and escape char 'd', it will
// return "abc\\"
String replaceEscapeChar(String & orig_string, UInt8 escape_char)
{
std::stringstream ss;
for (size_t i = 0; i < orig_string.size(); i++)
{
auto c = orig_string[i];
if (c == escape_char)
{
if (i+1 != orig_string.size() && orig_string[i+1] == escape_char)
{
// two successive escape char, which means it is trying to escape itself, just remove one
i++;
ss << escape_char;
}
else
{
// https://github.com/pingcap/tidb/blob/master/util/stringutil/string_util.go#L154
// if any char following escape char that is not [escape_char,'_','%'], it is invalid escape.
// mysql will treat escape character as the origin value even
// the escape sequence is invalid in Go or C.
// e.g., \m is invalid in Go, but in MySQL we will get "m" for select '\m'.
// Following case is correct just for escape \, not for others like +.
// TODO: Add more checks for other escapes.
if (i+1 != orig_string.size() && orig_string[i+1] == CH_ESCAPE_CHAR)
{
continue;
}
ss << CH_ESCAPE_CHAR;
}
}
else if (c == CH_ESCAPE_CHAR)
{
// need to escape this '\\'
ss << CH_ESCAPE_CHAR << CH_ESCAPE_CHAR;
}
else
{
ss << c;
}
}
return ss.str();
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
using ResultType = typename Impl::ResultType;
Expand All @@ -82,10 +134,44 @@ class FunctionsStringSearch : public IFunction
const ColumnConst * col_haystack_const = typeid_cast<const ColumnConst *>(&*column_haystack);
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);

UInt8 escape_char = CH_ESCAPE_CHAR;
if (has_3_args)
{
auto * col_escape_const = typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments[2]).column);
bool valid_args = true;
if (col_needle_const == nullptr || col_escape_const == nullptr)
{
valid_args = false;
}
else
{
auto c = col_escape_const->getValue<Int32>();
if (c < 0 || c > 255)
{
// todo maybe use more strict constraint
valid_args = false;
}
else
{
escape_char = (UInt8) c;
}
}
if (!valid_args)
{
throw Exception("2nd and 3rd arguments of function " + getName() + " must "
"be constants, and the 3rd argument must between 0 and 255.");
}
}

if (col_haystack_const && col_needle_const)
{
ResultType res{};
Impl::constant_constant(col_haystack_const->getValue<String>(), col_needle_const->getValue<String>(), res);
String needle_string = col_needle_const->getValue<String>();
if (has_3_args && escape_char != CH_ESCAPE_CHAR)
{
needle_string = replaceEscapeChar(needle_string, escape_char);
}
Impl::constant_constant(col_haystack_const->getValue<String>(), needle_string, res);
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(col_haystack_const->size(), toField(res));
return;
}
Expand All @@ -105,7 +191,15 @@ class FunctionsStringSearch : public IFunction
col_needle_vector->getOffsets(),
vec_res);
else if (col_haystack_vector && col_needle_const)
Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(), col_needle_const->getValue<String>(), vec_res);
{
String needle_string = col_needle_const->getValue<String>();
if (has_3_args && escape_char != CH_ESCAPE_CHAR)
{
needle_string = replaceEscapeChar(needle_string, escape_char);
}
Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(),
needle_string, vec_res);
}
else if (col_haystack_const && col_needle_vector)
Impl::constant_vector(col_haystack_const->getValue<String>(), col_needle_vector->getChars(), col_needle_vector->getOffsets(), vec_res);
else
Expand Down