Skip to content

Commit

Permalink
use hasAny to replace arraysOverlap
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinyhZou authored and zouyunhe committed Sep 19, 2024
1 parent 8cf6a58 commit df7331a
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator {
}
}

/**
 * Validator for `array_join`: the expression is rejected when the first child of the
 * [[ArrayJoin]] node is a [[Literal]]; every other expression is accepted.
 */
case class ArrayJoinValidator() extends FunctionValidator {
  override def doValidate(expr: Expression): Boolean =
    expr match {
      case join: ArrayJoin =>
        // Only the first child is inspected.
        join.children.head match {
          case _: Literal => false
          case _ => true
        }
      case _ =>
        // Anything that is not an ArrayJoin is not this validator's concern.
        true
    }
}

case class FormatStringValidator() extends FunctionValidator {
override def doValidate(expr: Expression): Boolean = {
val formatString = expr.asInstanceOf[FormatString]
Expand All @@ -181,13 +174,11 @@ object CHExpressionUtil {
)

final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
ARRAY_JOIN -> ArrayJoinValidator(),
SPLIT_PART -> DefaultValidator(),
TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
UNIX_TIMESTAMP -> UnixTimeStampValidator(),
SEQUENCE -> SequenceValidator(),
GET_JSON_OBJECT -> GetJsonObjectValidator(),
ARRAYS_OVERLAP -> DefaultValidator(),
SPLIT -> StringSplitValidator(),
SUBSTRING_INDEX -> SubstringIndexValidator(),
LPAD -> StringLPadValidator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1416,4 +1416,14 @@ class GlutenClickHouseHiveTableSuite
spark.sql("DROP TABLE test_tbl_7054")
}

// Creates a table over the ORC file test_reader_time_zone.snappy.orc, then checks that
// selecting from it gives the same result as vanilla Spark.
test("GLUTEN-6506: Orc read time zone") {
  val dataPath = s"$basePath/orc-data/test_reader_time_zone.snappy.orc"
  val createTableSql = ("create table test_tbl_6506(" +
    "id bigint, t timestamp) stored as orc location '%s'")
    .format(dataPath)
  val selectSql = "select * from test_tbl_6506"
  // Created outside the try so the cleanup below only runs once the table exists.
  spark.sql(createTableSql)
  try {
    compareResultsAgainstVanillaSpark(selectSql, true, _ => {})
  } finally {
    // Always drop the table, even when the comparison fails, so a failing run does not
    // leave test_tbl_6506 behind.
    spark.sql("drop table test_tbl_6506")
  }
}
}
43 changes: 31 additions & 12 deletions cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,39 @@ class SparkFunctionArrayJoin : public IFunction
return makeNullable(data_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName());

const auto * arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
const ColumnArray * array_col;
if (!arg_null_col)
array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
auto res_col = ColumnString::create();
auto null_col = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
if (input_rows_count == 0)
return ColumnNullable::create(std::move(res_col), std::move(null_col));

const auto * arg_const_col = checkAndGetColumn<ColumnConst>(arguments[0].column.get());
const ColumnArray * array_col = nullptr;
const ColumnNullable * arg_null_col = nullptr;
if (arg_const_col)
{
if (arg_const_col->onlyNull())
{
null_result[0] = 1;
return ColumnNullable::create(std::move(res_col), std::move(null_col));
}
array_col = checkAndGetColumn<ColumnArray>(arg_const_col->getDataColumnPtr().get());
}
else
array_col = checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
{
arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
if (!arg_null_col)
array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
else
array_col = checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
}
if (!array_col)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());

auto res_col = ColumnString::create();
auto null_col = ColumnUInt8::create(array_col->size(), 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
std::pair<bool, StringRef> delim_p, null_replacement_p;
bool return_result = false;
auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair<bool, StringRef>
Expand Down Expand Up @@ -145,7 +161,7 @@ class SparkFunctionArrayJoin : public IFunction
}
}
};
if (arg_null_col->isNullAt(i))
if (arg_null_col && arg_null_col->isNullAt(i))
{
setResultNull();
continue;
Expand All @@ -166,9 +182,9 @@ class SparkFunctionArrayJoin : public IFunction
continue;
}
}

size_t array_size = array_offsets[i] - current_offset;
size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1];
size_t last_not_null_pos = 0;
for (size_t j = 0; j < array_size; ++j)
{
if (array_nested_col && array_nested_col->isNullAt(j + array_pos))
Expand All @@ -179,11 +195,14 @@ class SparkFunctionArrayJoin : public IFunction
if (j != array_size - 1)
res += delim.toString();
}
else if (j == array_size - 1)
res = res.substr(0, last_not_null_pos);
}
else
{
const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1);
res += s.toString();
last_not_null_pos = res.size();
if (j != array_size - 1)
res += delim.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, arrayZipUnaligned);
// NOTE(review): scalar_function_parser/arraysOverlap.cpp also defines and registers a
// FunctionParserArraysOverlap named "arrays_overlap". Presumably this macro registers the
// same Spark function name (and possibly the same class name) -- confirm that only one of
// the two registrations is meant to remain.
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap);

// map functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map);
Expand Down
130 changes: 130 additions & 0 deletions cpp-ch/local-engine/Parser/scalar_function_parser/arraysOverlap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Core/Field.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeArray.h>
#include <Parser/FunctionParser.h>
#include <Common/CHUtil.h>

namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{

/// Parses the Spark function arrays_overlap(arr1, arr2) into an ActionsDAG expression
/// composed of isNull / empty / has / hasAny / arrayIntersect / if / multiIf nodes.
class FunctionParserArraysOverlap : public FunctionParser
{
public:
    explicit FunctionParserArraysOverlap(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { }
    ~FunctionParserArraysOverlap() override = default;

    static constexpr auto name = "arrays_overlap";

    String getName() const override { return name; }

    /// Builds the expression for arrays_overlap and returns its result node.
    ///
    /// @param substrait_func the substrait scalar function call; must carry exactly two arguments.
    /// @param actions_dag    the DAG that receives all intermediate nodes and constants.
    /// @return the result node, passed through convertNodeTypeIfNeeded.
    /// @throws Exception with NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the argument count is not two.
    const ActionsDAG::Node * parse(
        const substrait::Expression_ScalarFunction & substrait_func,
        ActionsDAG & actions_dag) const override
    {
        /**
            parse arrays_overlap(arr1, arr2) as
            if (isNull(arr1) || isNull(arr2))
                return NULL
            else if (empty(arr1) || empty(arr2))
                return false
            else if (hasAny(arr1, arr2))
            {
                if (!has(arr1, NULL) || !has(arr2, NULL))
                    return true
                else if (length(arrayIntersect(arr1, arr2)) == 1 && has(arrayIntersect(arr1, arr2), NULL))
                    return NULL
                else
                    return true
            }
            else if (has(arr1, NULL) || has(arr2, NULL))
                return NULL
            else
                return false
        */

        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
        if (parsed_args.size() != 2)
            throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

        const auto * arr1_node = parsed_args[0];
        const auto * arr2_node = parsed_args[1];

        /// Branch 1 condition: either argument is NULL as a whole.
        const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_node});
        const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", {arr2_node});
        const auto * arrs_or_null_node = toFunctionNode(actions_dag, "or", {arr1_is_null_node, arr2_is_null_node});

        /// All array functions below operate on the assumeNotNull'ed arguments.
        const auto * arr1_not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {arr1_node});
        const auto * arr2_not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {arr2_node});

        /// NOTE(review): the cast is unchecked; it assumes the first argument's type is an array
        /// once assumeNotNull has been applied -- confirm the planner guarantees this.
        const DataTypeArray * arr_type = static_cast<const DataTypeArray *>(arr1_not_null_node->result_type.get());

        /// Constants: a NULL of the (nullable) element type used as the needle for has(),
        /// a NULL of Nullable(UInt8) used as the NULL result, and the UInt8 / UInt64 literals.
        const auto * null_type_node = addColumnToActionsDAG(actions_dag, makeNullable(arr_type->getNestedType()), Field{});
        const auto * null_const_node = addColumnToActionsDAG(actions_dag, makeNullable(std::make_shared<DataTypeUInt8>()), Field{});
        const auto * true_const_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt8>(), 1);
        const auto * false_const_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt8>(), 0);
        const auto * one_const_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt64>(), 1);

        /// Whether each array contains a NULL element, and the negations of that.
        const auto * arr1_has_null_node = toFunctionNode(actions_dag, "has", {arr1_not_null_node, null_type_node});
        const auto * arr2_has_null_node = toFunctionNode(actions_dag, "has", {arr2_not_null_node, null_type_node});
        const auto * arr1_not_has_null_node = toFunctionNode(actions_dag, "not", {arr1_has_null_node});
        const auto * arr2_not_has_null_node = toFunctionNode(actions_dag, "not", {arr2_has_null_node});

        const auto * arrs_or_has_null_node = toFunctionNode(actions_dag, "or", {arr1_has_null_node, arr2_has_null_node});
        const auto * arrs_one_not_has_null_node = toFunctionNode(actions_dag, "or", {arr1_not_has_null_node, arr2_not_has_null_node});

        /// Branch 2 condition: either array is empty.
        const auto * arr1_is_empty_node = toFunctionNode(actions_dag, "empty", {arr1_not_null_node});
        const auto * arr2_is_empty_node = toFunctionNode(actions_dag, "empty", {arr2_not_null_node});
        const auto * arrs_or_empty_node = toFunctionNode(actions_dag, "or", {arr1_is_empty_node, arr2_is_empty_node});

        /// Branch 3 condition: hasAny reports a common element. When both arrays contain NULL,
        /// the intersection is inspected: a single-element intersection that contains NULL
        /// yields NULL instead of true.
        const auto * arrs_has_any_node = toFunctionNode(actions_dag, "hasAny", {arr1_not_null_node, arr2_not_null_node});
        const auto * arrs_intersect_node = toFunctionNode(actions_dag, "arrayIntersect", {arr1_not_null_node, arr2_not_null_node});
        const auto * arrs_intersect_len_node = toFunctionNode(actions_dag, "length", {arrs_intersect_node});
        const auto * arrs_intersect_is_single_node = toFunctionNode(actions_dag, "equals", {arrs_intersect_len_node, one_const_node});
        const auto * arrs_intersect_has_null_node = toFunctionNode(actions_dag, "has", {arrs_intersect_node, null_type_node});
        const auto * arrs_intersect_single_has_null = toFunctionNode(actions_dag, "and", {arrs_intersect_is_single_node, arrs_intersect_has_null_node});

        const auto * arrs_intersect_single_has_null_result = toFunctionNode(actions_dag, "if", {arrs_intersect_single_has_null, null_const_node, true_const_node});
        const auto * arrs_if_has_null_node = toFunctionNode(actions_dag, "if", {arrs_one_not_has_null_node, true_const_node, arrs_intersect_single_has_null_result});

        /// multiIf evaluates the (condition, value) pairs in order; the last argument is the default.
        const auto * arrs_overlap_node = toFunctionNode(actions_dag, "multiIf", {
            arrs_or_null_node, null_const_node,
            arrs_or_empty_node, false_const_node,
            arrs_has_any_node, arrs_if_has_null_node,
            arrs_or_has_null_node, null_const_node,
            false_const_node});
        return convertNodeTypeIfNeeded(substrait_func, arrs_overlap_node, actions_dag);
    }

    /// Always reports "hasAny", regardless of the substrait function passed in.
    String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override
    {
        return "hasAny";
    }
};

/// Static registration object for the arrays_overlap parser.
static FunctionParserRegister<FunctionParserArraysOverlap> register_arrays_overlap;
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he
const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone);
format_settings.orc.reader_time_zone_name = mapped_timezone;
}

auto input_format = std::make_shared<DB::NativeORCBlockInputFormat>(*file_format->read_buffer, header, format_settings);
file_format->input = input_format;
return file_format;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,10 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("cast from struct III")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("Sequence of numbers")
.exclude("elementAt")
.exclude("Shuffle")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,10 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("cast from struct III")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("Sequence of numbers")
.exclude("elementAt")
.exclude("Shuffle")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("Sequence of numbers")
.exclude("elementAt")
.exclude("Shuffle")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("Sequence of numbers")
.exclude("elementAt")
.exclude("Shuffle")
Expand Down

0 comments on commit df7331a

Please sign in to comment.