Skip to content

Commit

Permalink
Support Json unquote function (#8407)
Browse files Browse the repository at this point in the history
close #8334
  • Loading branch information
yibin87 authored Nov 28, 2023
1 parent 1c6ea49 commit 4479df8
Show file tree
Hide file tree
Showing 17 changed files with 1,295 additions and 71 deletions.
43 changes: 30 additions & 13 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,27 +281,30 @@ String DAGExpressionAnalyzerHelper::buildCastFunction(
return buildCastFunctionInternal(analyzer, {name, type_expr_name}, false, expr.field_type(), actions);
}

String DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField(
String DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions(
DAGExpressionAnalyzer * analyzer,
const tipb::Expr & expr,
const ExpressionActionsPtr & actions)
{
auto func_name = getFunctionName(expr);
if unlikely (expr.children_size() != 1)
throw TiFlashException("Cast function only support one argument", Errors::Coprocessor::BadRequest);
throw TiFlashException(
fmt::format("{} function only support one argument", func_name),
Errors::Coprocessor::BadRequest);
if unlikely (!exprHasValidFieldType(expr))
throw TiFlashException("CAST function without valid field type", Errors::Coprocessor::BadRequest);
throw TiFlashException(
fmt::format("{} function without valid field type", func_name),
Errors::Coprocessor::BadRequest);

const auto & input_expr = expr.children(0);
auto func_name = getFunctionName(expr);

String arg = analyzer->getActions(input_expr, actions);
const auto & collator = getCollatorFromExpr(expr);
String result_name = genFuncString(func_name, {arg}, {collator});
if (actions->getSampleBlock().has(result_name))
return result_name;

const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, analyzer->getContext());
auto * function_build_ptr = function_builder.get();
const FunctionBuilderPtr & ifunction_builder = FunctionFactory::instance().get(func_name, analyzer->getContext());
auto * function_build_ptr = ifunction_builder.get();
if (auto * function_builder = dynamic_cast<DefaultFunctionBuilder *>(function_build_ptr); function_builder)
{
auto * function_impl = function_builder->getFunctionImpl().get();
Expand All @@ -321,17 +324,29 @@ String DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField(
{
function_cast_time_as_json->setInputTiDBFieldType(input_expr.field_type());
}
else if (auto * function_json_unquote = dynamic_cast<FunctionJsonUnquote *>(function_impl);
function_json_unquote)
{
bool valid_check
= !(isScalarFunctionExpr(input_expr) && input_expr.sig() == tipb::ScalarFuncSig::CastJsonAsString);
function_json_unquote->setNeedValidCheck(valid_check);
}
else if (auto * function_cast_json_as_string = dynamic_cast<FunctionCastJsonAsString *>(function_impl);
function_cast_json_as_string)
{
function_cast_json_as_string->setOutputTiDBFieldType(expr.field_type());
}
else
{
throw Exception(fmt::format("Unexpected func {} in buildCastAsJsonWithInputTiDBField", func_name));
throw Exception(fmt::format("Unexpected func {} in buildSingleParamJsonRelatedFunctions", func_name));
}
}
else
{
throw Exception(fmt::format("Unexpected func {} in buildCastAsJsonWithInputTiDBField", func_name));
throw Exception(fmt::format("Unexpected func {} in buildSingleParamJsonRelatedFunctions", func_name));
}

const ExpressionAction & action = ExpressionAction::applyFunction(function_builder, {arg}, result_name, collator);
const ExpressionAction & action = ExpressionAction::applyFunction(ifunction_builder, {arg}, result_name, collator);
actions->add(action);
return result_name;
}
Expand Down Expand Up @@ -534,9 +549,11 @@ DAGExpressionAnalyzerHelper::FunctionBuilderMap DAGExpressionAnalyzerHelper::fun
{"ifNull", DAGExpressionAnalyzerHelper::buildIfNullFunction},
{"multiIf", DAGExpressionAnalyzerHelper::buildMultiIfFunction},
{"tidb_cast", DAGExpressionAnalyzerHelper::buildCastFunction},
{"cast_int_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField},
{"cast_string_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField},
{"cast_time_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField},
{"cast_int_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions},
{"cast_string_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions},
{"cast_time_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions},
{"cast_json_as_string", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions},
{"json_unquote", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions},
{"and", DAGExpressionAnalyzerHelper::buildLogicalFunction},
{"or", DAGExpressionAnalyzerHelper::buildLogicalFunction},
{"xor", DAGExpressionAnalyzerHelper::buildLogicalFunction},
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DAGExpressionAnalyzerHelper
const tipb::Expr & expr,
const ExpressionActionsPtr & actions);

static String buildCastAsJsonWithInputTiDBField(
static String buildSingleParamJsonRelatedFunctions(
DAGExpressionAnalyzer * analyzer,
const tipb::Expr & expr,
const ExpressionActionsPtr & actions);
Expand Down
134 changes: 108 additions & 26 deletions dbms/src/Functions/FunctionsJson.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
#include <DataTypes/DataTypesNumber.h>
#include <Flash/Coprocessor/DAGUtils.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionsTiDBConversion.h>
#include <Functions/GatherUtils/Sources.h>
#include <Functions/IFunction.h>
#include <Functions/castTypeToEither.h>
#include <Interpreters/Context.h>
#include <TiDB/Decode/JsonBinary.h>
#include <TiDB/Decode/JsonPathExprRef.h>
#include <TiDB/Decode/JsonScanner.h>
#include <TiDB/Schema/TiDB.h>
#include <common/JSON.h>
#include <simdjson.h>
Expand Down Expand Up @@ -301,6 +304,7 @@ class FunctionJsonUnquote : public IFunction

size_t getNumberOfArguments() const override { return 1; }

void setNeedValidCheck(bool need_valid_check_) { need_valid_check = need_valid_check_; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
Expand All @@ -327,16 +331,10 @@ class FunctionJsonUnquote : public IFunction
offsets_to.resize(rows);
ColumnUInt8::MutablePtr col_null_map = ColumnUInt8::create(rows, 0);
JsonBinary::JsonBinaryWriteBuffer write_buffer(data_to);
size_t current_offset = 0;
for (size_t i = 0; i < block.rows(); ++i)
{
size_t next_offset = offsets_from[i];
size_t data_length = next_offset - current_offset - 1;
JsonBinary::unquoteStringInBuffer(StringRef(&data_from[current_offset], data_length), write_buffer);
writeChar(0, write_buffer);
offsets_to[i] = write_buffer.count();
current_offset = next_offset;
}
if (need_valid_check)
doUnquote<true>(block, data_from, offsets_from, offsets_to, write_buffer);
else
doUnquote<false>(block, data_from, offsets_from, offsets_to, write_buffer);
data_to.resize(write_buffer.count());
block.getByPosition(result).column = ColumnNullable::create(std::move(col_to), std::move(col_null_map));
}
Expand All @@ -345,21 +343,69 @@ class FunctionJsonUnquote : public IFunction
fmt::format("Illegal column {} of argument of function {}", column->getName(), getName()),
ErrorCodes::ILLEGAL_COLUMN);
}

template <bool validCheck>
void doUnquote(
const Block & block,
const ColumnString::Chars_t & data_from,
const IColumn::Offsets & offsets_from,
IColumn::Offsets & offsets_to,
JsonBinary::JsonBinaryWriteBuffer & write_buffer) const
{
size_t current_offset = 0;
for (size_t i = 0; i < block.rows(); ++i)
{
size_t next_offset = offsets_from[i];
size_t data_length = next_offset - current_offset - 1;
if constexpr (validCheck)
{
// TODO(hyb): use SIMDJson to check when SIMDJson is proved in practice
if (data_length >= 2 && data_from[current_offset] == '"' && data_from[next_offset - 2] == '"'
&& unlikely(
!checkJsonValid(reinterpret_cast<const char *>(&data_from[current_offset]), data_length)))
{
throw Exception(
"Invalid JSON text: The document root must not be followed by other values.",
ErrorCodes::ILLEGAL_COLUMN);
}
}
JsonBinary::unquoteStringInBuffer(StringRef(&data_from[current_offset], data_length), write_buffer);
writeChar(0, write_buffer);
offsets_to[i] = write_buffer.count();
current_offset = next_offset;
}
}

private:
bool need_valid_check = false;
};


class FunctionCastJsonAsString : public IFunction
{
public:
static constexpr auto name = "cast_json_as_string";
static FunctionPtr create(const Context &) { return std::make_shared<FunctionCastJsonAsString>(); }
static FunctionPtr create(const Context & context)
{
if (!context.getDAGContext())
{
throw Exception("DAGContext should not be nullptr.", ErrorCodes::LOGICAL_ERROR);
}
return std::make_shared<FunctionCastJsonAsString>(context);
}

explicit FunctionCastJsonAsString(const Context & context)
: context(context)
{}

String getName() const override { return name; }

size_t getNumberOfArguments() const override { return 1; }

bool useDefaultImplementationForConstants() const override { return true; }

void setOutputTiDBFieldType(const tipb::FieldType & tidb_tp_) { tidb_tp = &tidb_tp_; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if unlikely (!arguments[0]->isString())
Expand All @@ -386,25 +432,58 @@ class FunctionCastJsonAsString : public IFunction
ColumnUInt8::MutablePtr col_null_map = ColumnUInt8::create(rows, 0);
ColumnUInt8::Container & vec_null_map = col_null_map->getData();
JsonBinary::JsonBinaryWriteBuffer write_buffer(data_to);
size_t current_offset = 0;
for (size_t i = 0; i < block.rows(); ++i)
if likely (tidb_tp->flen() < 0)
{
size_t next_offset = offsets_from[i];
size_t json_length = next_offset - current_offset - 1;
if unlikely (isNullJsonBinary(json_length))
size_t current_offset = 0;
for (size_t i = 0; i < block.rows(); ++i)
{
vec_null_map[i] = 1;
size_t next_offset = offsets_from[i];
size_t json_length = next_offset - current_offset - 1;
if unlikely (isNullJsonBinary(json_length))
vec_null_map[i] = 1;
else
{
JsonBinary json_binary(
data_from[current_offset],
StringRef(&data_from[current_offset + 1], json_length - 1));
json_binary.toStringInBuffer(write_buffer);
}
writeChar(0, write_buffer);
offsets_to[i] = write_buffer.count();
current_offset = next_offset;
}
else
}
else
{
ColumnString::Chars_t container_per_element;
size_t current_offset = 0;
for (size_t i = 0; i < block.rows(); ++i)
{
JsonBinary json_binary(
data_from[current_offset],
StringRef(&data_from[current_offset + 1], json_length - 1));
json_binary.toStringInBuffer(write_buffer);
size_t next_offset = offsets_from[i];
size_t json_length = next_offset - current_offset - 1;
if unlikely (isNullJsonBinary(json_length))
vec_null_map[i] = 1;
else
{
JsonBinary::JsonBinaryWriteBuffer element_write_buffer(container_per_element);
JsonBinary json_binary(
data_from[current_offset],
StringRef(&data_from[current_offset + 1], json_length - 1));
json_binary.toStringInBuffer(element_write_buffer);
size_t orig_length = element_write_buffer.count();
auto byte_length = charLengthToByteLengthFromUTF8(
reinterpret_cast<char *>(container_per_element.data()),
orig_length,
tidb_tp->flen());
if (byte_length < element_write_buffer.count())
context.getDAGContext()->handleTruncateError("Data Too Long");
write_buffer.write(reinterpret_cast<char *>(container_per_element.data()), byte_length);
}

writeChar(0, write_buffer);
offsets_to[i] = write_buffer.count();
current_offset = next_offset;
}
writeChar(0, write_buffer);
offsets_to[i] = write_buffer.count();
current_offset = next_offset;
}
data_to.resize(write_buffer.count());
block.getByPosition(result).column = ColumnNullable::create(std::move(col_to), std::move(col_null_map));
Expand All @@ -414,8 +493,11 @@ class FunctionCastJsonAsString : public IFunction
fmt::format("Illegal column {} of argument of function {}", column->getName(), getName()),
ErrorCodes::ILLEGAL_COLUMN);
}
};

private:
const tipb::FieldType * tidb_tp;
const Context & context;
};

class FunctionJsonLength : public IFunction
{
Expand Down
52 changes: 31 additions & 21 deletions dbms/src/Functions/FunctionsTiDBConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,40 @@ namespace
constexpr static Int64 pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000};
}

ALWAYS_INLINE inline size_t charLengthToByteLengthFromUTF8(const char * data, size_t length, size_t char_length)
{
size_t ret = 0;
for (size_t char_index = 0; char_index < char_length && ret < length; ++char_index)
{
uint8_t c = data[ret];
if (c < 0x80)
ret += 1;
else if (c < 0xE0)
ret += 2;
else if (c < 0xF0)
ret += 3;
else
ret += 4;
}
if unlikely (ret > length)
{
throw Exception(
fmt::format(
"Illegal utf8 byte sequence bytes: {} result_length: {} char_length: {}",
length,
ret,
char_length),
ErrorCodes::ILLEGAL_COLUMN);
}
return ret;
}

/// cast int/real/decimal/time as string
template <typename FromDataType, bool return_nullable>
struct TiDBConvertToString
{
using FromFieldType = typename FromDataType::FieldType;

static size_t charLengthToByteLengthFromUTF8(const char * data, size_t length, size_t char_length)
{
size_t ret = 0;
for (size_t char_index = 0; char_index < char_length && ret < length; char_index++)
{
uint8_t c = data[ret];
if (c < 0x80)
ret += 1;
else if (c < 0xE0)
ret += 2;
else if (c < 0xF0)
ret += 3;
else
ret += 4;
}
return ret;
}

static void execute(
Block & block,
const ColumnNumbers & arguments,
Expand Down Expand Up @@ -148,7 +158,7 @@ struct TiDBConvertToString
size_t next_offset = (*offsets_from)[i];
size_t org_length = next_offset - current_offset - 1;
size_t byte_length = org_length;
if (tp.flen() > 0)
if (tp.flen() >= 0)
{
byte_length = tp.flen();
if (tp.charset() == "utf8" || tp.charset() == "utf8mb4")
Expand Down Expand Up @@ -189,7 +199,7 @@ struct TiDBConvertToString
WriteBufferFromVector<ColumnString::Chars_t> element_write_buffer(container_per_element);
FormatImpl<FromDataType>::execute(vec_from[i], element_write_buffer, &type, nullptr);
size_t byte_length = element_write_buffer.count();
if (tp.flen() > 0)
if (tp.flen() >= 0)
byte_length = std::min(byte_length, tp.flen());
if (byte_length < element_write_buffer.count())
context.getDAGContext()->handleTruncateError("Data Too Long");
Expand Down Expand Up @@ -235,7 +245,7 @@ struct TiDBConvertToString
WriteBufferFromVector<ColumnString::Chars_t> element_write_buffer(container_per_element);
FormatImpl<FromDataType>::execute(vec_from[i], element_write_buffer, &type, nullptr);
size_t byte_length = element_write_buffer.count();
if (tp.flen() > 0)
if (tp.flen() >= 0)
byte_length = std::min(byte_length, tp.flen());
if (byte_length < element_write_buffer.count())
context.getDAGContext()->handleTruncateError("Data Too Long");
Expand Down
6 changes: 5 additions & 1 deletion dbms/src/Functions/tests/gtest_cast_as_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ class TestCastAsJson : public DB::tests::FunctionTest
json_column = executeFunction(func_name, columns);
}
// The `json_binary` should be cast as a string to improve readability.
return executeFunction("cast_json_as_string", {json_column});
tipb::FieldType field_type;
field_type.set_flen(-1);
field_type.set_collate(TiDB::ITiDBCollator::BINARY);
field_type.set_tp(TiDB::TypeString);
return executeCastJsonAsStringFunction(json_column, field_type);
}

template <typename Input, bool is_raw = false>
Expand Down
Loading

0 comments on commit 4479df8

Please sign in to comment.