Skip to content

Commit

Permalink
*.: Use String as the common type for Decimal in join when an Excepti…
Browse files Browse the repository at this point in the history
…on occurs (#6179)

close #4519
  • Loading branch information
SeaRise authored Nov 14, 2022
1 parent 7c4b740 commit 25e5c1c
Show file tree
Hide file tree
Showing 7 changed files with 528 additions and 25 deletions.
15 changes: 9 additions & 6 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <Flash/Coprocessor/DAGExpressionAnalyzer.h>
#include <Flash/Coprocessor/DAGExpressionAnalyzerHelper.h>
#include <Flash/Coprocessor/DAGUtils.h>
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionsTiDBConversion.h>
Expand All @@ -38,7 +39,6 @@
#include <Storages/Transaction/TypeMapping.h>
#include <WindowFunctions/WindowFunctionFactory.h>


namespace DB
{
namespace ErrorCodes
Expand Down Expand Up @@ -923,7 +923,7 @@ void DAGExpressionAnalyzer::appendJoin(
std::pair<bool, Names> DAGExpressionAnalyzer::buildJoinKey(
const ExpressionActionsPtr & actions,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
bool left,
bool is_right_out_join)
{
Expand All @@ -939,10 +939,13 @@ std::pair<bool, Names> DAGExpressionAnalyzer::buildJoinKey(

String key_name = getActions(key, actions);
DataTypePtr current_type = actions->getSampleBlock().getByName(key_name).type;
if (!removeNullable(current_type)->equals(*removeNullable(key_types[i])))
const auto & join_key_type = join_key_types[i];
if (!removeNullable(current_type)->equals(*removeNullable(join_key_type.key_type)))
{
/// need to convert to key type
key_name = appendCast(key_types[i], actions, key_name);
key_name = join_key_type.is_incompatible_decimal
? applyFunction("formatDecimal", {key_name}, actions, nullptr)
: appendCast(join_key_type.key_type, actions, key_name);
has_actions = true;
}
if (!has_actions && (!left || is_right_out_join))
Expand Down Expand Up @@ -986,7 +989,7 @@ std::pair<bool, Names> DAGExpressionAnalyzer::buildJoinKey(
bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(
ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
Names & key_names,
bool left,
bool is_right_out_join,
Expand All @@ -997,7 +1000,7 @@ bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(
ExpressionActionsPtr actions = chain.getLastActions();

bool ret = false;
std::tie(ret, key_names) = buildJoinKey(actions, keys, key_types, left, is_right_out_join);
std::tie(ret, key_names) = buildJoinKey(actions, keys, join_key_types, left, is_right_out_join);

if (!filters.empty())
{
Expand Down
7 changes: 5 additions & 2 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ enum class ExtraCastAfterTSMode
AppendDurationCast
};

struct JoinKeyType;
using JoinKeyTypes = std::vector<JoinKeyType>;

class DAGExpressionAnalyzerHelper;
/** Transforms an expression from DAG expression into a sequence of actions to execute it.
*/
Expand Down Expand Up @@ -157,7 +160,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable
bool appendJoinKeyAndJoinFilters(
ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
Names & key_names,
bool left,
bool is_right_out_join,
Expand Down Expand Up @@ -288,7 +291,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable
std::pair<bool, Names> buildJoinKey(
const ExpressionActionsPtr & actions,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
bool left,
bool is_right_out_join);

Expand Down
64 changes: 50 additions & 14 deletions dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <Common/TiFlashException.h>
#include <DataStreams/ExpressionBlockInputStream.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/getLeastSupertype.h>
#include <Flash/Coprocessor/DAGExpressionAnalyzer.h>
#include <Flash/Coprocessor/DAGUtils.h>
Expand All @@ -26,7 +27,14 @@

#include <unordered_map>

namespace DB::JoinInterpreterHelper
namespace DB
{
namespace ErrorCodes
{
extern const int NO_COMMON_TYPE;
} // namespace ErrorCodes

namespace JoinInterpreterHelper
{
namespace
{
Expand Down Expand Up @@ -99,32 +107,59 @@ std::pair<ASTTableJoin::Kind, size_t> getJoinKindAndBuildSideIndex(const tipb::J
return {kind, build_side_index};
}

DataTypes getJoinKeyTypes(const tipb::Join & join)
/// Picks the type that the join keys of both sides are compared as.
///
/// The result of `getLeastSupertype({left_type, right_type})` is returned with
/// `is_incompatible_decimal == false` whenever that call succeeds.
/// If it throws `NO_COMMON_TYPE` and both types (ignoring Nullable) are decimals,
/// String is returned instead with `is_incompatible_decimal == true`; the String is
/// wrapped in Nullable when either input type is nullable.
/// Any other exception is rethrown unchanged.
///
/// NOTE(review): the name looks like a typo for `getCommonTypeForJoinOn`; it is kept
/// because the caller in this file uses the current spelling.
JoinKeyType geCommonTypeForJoinOn(const DataTypePtr & left_type, const DataTypePtr & right_type)
{
    try
    {
        return {getLeastSupertype({left_type, right_type}), false};
    }
    catch (DB::Exception & e)
    {
        // Only the "two decimals without a common type" failure is recoverable here.
        const bool is_decimal_mismatch = e.code() == ErrorCodes::NO_COMMON_TYPE
            && removeNullable(left_type)->isDecimal()
            && removeNullable(right_type)->isDecimal();
        if (!is_decimal_mismatch)
            throw;

        // fix https://github.com/pingcap/tiflash/issues/4519
        // String is the common type for all types, so choosing String is always safe.
        // The decimals must then be formatted with `FunctionFormatDecimal`, because
        // e.g. 0.1000000000 equals 0.10000000000000000000 while their original strings differ.
        RUNTIME_ASSERT(!left_type->onlyNull() || !right_type->onlyNull());

        DataTypePtr string_key_type = std::make_shared<DataTypeString>();
        if (left_type->isNullable() || right_type->isNullable())
            string_key_type = makeNullable(string_key_type);
        return {string_key_type, true};
    }
}

JoinKeyTypes getJoinKeyTypes(const tipb::Join & join)
{
if (unlikely(join.left_join_keys_size() != join.right_join_keys_size()))
throw TiFlashException("size of join.left_join_keys != size of join.right_join_keys", Errors::Coprocessor::BadRequest);
DataTypes key_types;
JoinKeyTypes join_key_types;
for (int i = 0; i < join.left_join_keys_size(); ++i)
{
if (unlikely(!exprHasValidFieldType(join.left_join_keys(i)) || !exprHasValidFieldType(join.right_join_keys(i))))
throw TiFlashException("Join key without field type", Errors::Coprocessor::BadRequest);
DataTypes types;
types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type()));
types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type()));
DataTypePtr common_type = getLeastSupertype(types);
key_types.emplace_back(common_type);
auto left_type = getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type());
auto right_type = getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type());
join_key_types.emplace_back(geCommonTypeForJoinOn(left_type, right_type));
}
return key_types;
return join_key_types;
}

TiDB::TiDBCollators getJoinKeyCollators(const tipb::Join & join, const DataTypes & join_key_types)
TiDB::TiDBCollators getJoinKeyCollators(const tipb::Join & join, const JoinKeyTypes & join_key_types)
{
TiDB::TiDBCollators collators;
size_t join_key_size = join_key_types.size();
if (join.probe_types_size() == static_cast<int>(join_key_size) && join.build_types_size() == join.probe_types_size())
for (size_t i = 0; i < join_key_size; ++i)
{
if (removeNullable(join_key_types[i])->isString())
// Don't need to check the collate for decimal format string.
if (removeNullable(join_key_types[i].key_type)->isString() && !join_key_types[i].is_incompatible_decimal)
{
if (unlikely(join.probe_types(i).collate() != join.build_types(i).collate()))
throw TiFlashException("Join with different collators on the join key", Errors::Coprocessor::BadRequest);
Expand Down Expand Up @@ -330,7 +365,7 @@ std::tuple<ExpressionActionsPtr, Names, String> prepareJoin(
const Context & context,
const Block & input_header,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
bool left,
bool is_right_out_join,
const google::protobuf::RepeatedPtrField<tipb::Expr> & filters)
Expand All @@ -342,7 +377,8 @@ std::tuple<ExpressionActionsPtr, Names, String> prepareJoin(
ExpressionActionsChain chain;
Names key_names;
String filter_column_name;
dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name);
dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, join_key_types, key_names, left, is_right_out_join, filters, filter_column_name);
return {chain.getLastActions(), std::move(key_names), std::move(filter_column_name)};
}
} // namespace DB::JoinInterpreterHelper
} // namespace JoinInterpreterHelper
} // namespace DB
11 changes: 9 additions & 2 deletions dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ namespace DB
{
class Context;

/// Describes how one pair of join keys (left/right) is compared.
struct JoinKeyType
{
/// Type that a join key is converted to when its own type (ignoring Nullable) differs from it.
DataTypePtr key_type;
/// True when the two sides are decimals with no least supertype and `key_type` is the
/// String fallback; such keys are converted with `formatDecimal` rather than a cast.
bool is_incompatible_decimal;
};
using JoinKeyTypes = std::vector<JoinKeyType>;

namespace JoinInterpreterHelper
{
struct TiFlashJoin
Expand All @@ -40,7 +47,7 @@ struct TiFlashJoin
ASTTableJoin::Kind kind;
size_t build_side_index = 0;

DataTypes join_key_types;
JoinKeyTypes join_key_types;
TiDB::TiDBCollators join_key_collators;

ASTTableJoin::Strictness strictness;
Expand Down Expand Up @@ -123,7 +130,7 @@ std::tuple<ExpressionActionsPtr, Names, String> prepareJoin(
const Context & context,
const Block & input_header,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
const JoinKeyTypes & join_key_types,
bool left,
bool is_right_out_join,
const google::protobuf::RepeatedPtrField<tipb::Expr> & filters);
Expand Down
Loading

0 comments on commit 25e5c1c

Please sign in to comment.