Skip to content

Commit

Permalink
[fix](window_function) min/max/sum/avg should be always nullable
Browse files Browse the repository at this point in the history
Co-authored-by: starocean999 <40539150+starocean999@users.noreply.github.com>
  • Loading branch information
mrhhsg and starocean999 committed Nov 16, 2023
1 parent e29d8cb commit 9b95e15
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct ReaderFirstAndLastData {
using StoreType = std::conditional_t<is_copy, CopiedValue<ColVecType, arg_is_nullable>,
Value<ColVecType, arg_is_nullable>>;
static constexpr bool nullable = arg_is_nullable;
static constexpr bool result_nullable = result_is_nullable;

void reset() {
_data_value.reset();
Expand Down Expand Up @@ -202,7 +203,13 @@ class ReaderFunctionData final

String get_name() const override { return Data::name(); }

DataTypePtr get_return_type() const override { return _argument_type; }
DataTypePtr get_return_type() const override {
if constexpr (Data::result_nullable) {
return make_nullable(_argument_type);
} else {
return _argument_type;
}
}

void insert_result_into(ConstAggregateDataPtr place, IColumn& to) const override {
this->data(place).insert_result_into(to);
Expand Down
9 changes: 8 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_window.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ struct BaseValue : public Value<ColVecType, arg_is_nullable> {
template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
struct LeadLagData {
public:
static constexpr bool result_nullable = result_is_nullable;
void reset() {
_data_value.reset();
_default_value.reset();
Expand Down Expand Up @@ -395,7 +396,13 @@ class WindowFunctionData final

String get_name() const override { return Data::name(); }

DataTypePtr get_return_type() const override { return _argument_type; }
DataTypePtr get_return_type() const override {
if constexpr (Data::result_nullable) {
return make_nullable(_argument_type);
} else {
return _argument_type;
}
}

void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
Expand Down
17 changes: 17 additions & 0 deletions be/src/vec/exec/vanalytic_eval_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,19 @@ void VAnalyticEvalNode::_insert_result_info(int64_t current_block_rows) {

for (int i = 0; i < _agg_functions_size; ++i) {
for (int j = get_result_start; j < _window_end_position; ++j) {
if (!_agg_functions[i]->function()->get_return_type()->is_nullable() &&
_result_window_columns[i]->is_nullable()) {
if (_current_window_empty) {
_result_window_columns[i]->insert_default();
} else {
auto* dst = assert_cast<ColumnNullable*>(_result_window_columns[i].get());
dst->get_null_map_data().push_back(0);
_agg_functions[i]->insert_result_info(
_fn_place_ptr + _offsets_of_aggregate_states[i],
&dst->get_nested_column());
}
continue;
}
_agg_functions[i]->insert_result_info(_fn_place_ptr + _offsets_of_aggregate_states[i],
_result_window_columns[i].get());
}
Expand Down Expand Up @@ -683,6 +696,10 @@ void VAnalyticEvalNode::_execute_for_win_func(int64_t partition_start, int64_t p
partition_start, partition_end, frame_start, frame_end,
_fn_place_ptr + _offsets_of_aggregate_states[i], _agg_columns.data(), nullptr);
}

// If the end is not greater than the start, the current window should be empty.
_current_window_empty =
std::min(frame_end, partition_end) <= std::max(frame_start, partition_start);
}

//binary search for range to calculate peer group
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exec/vanalytic_eval_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class VAnalyticEvalNode : public ExecNode {
int64_t _rows_end_offset = 0;
size_t _agg_functions_size = 0;
bool _agg_functions_created = false;
bool _current_window_empty = false;

/// The offset of the n-th functions.
std::vector<size_t> _offsets_of_aggregate_states;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,15 @@ public void analyzeImpl(Analyzer analyzer) throws AnalysisException {
standardize(analyzer);

setChildren();

String functionName = fn.functionName();
if (functionName.equalsIgnoreCase("sum") || functionName.equalsIgnoreCase("max")
|| functionName.equalsIgnoreCase("min") || functionName.equalsIgnoreCase("avg")) {
// sum, max, min and avg in window function should be always nullable
Function function = fnCall.fn.clone();
function.setNullableMode(Function.NullableMode.ALWAYS_NULLABLE);
fnCall.setFn(function);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand All @@ -44,7 +49,20 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
@Override
public Rule build() {
return logicalProject().when(project -> containsWindowExpression(project.getProjects())).then(project -> {
List<NamedExpression> outputs = project.getProjects();
List<NamedExpression> outputs =
ExpressionUtils.rewriteDownShortCircuit(project.getProjects(), output -> {
if (output instanceof WindowExpression) {
Expression expression = ((WindowExpression) output).getFunction();
if (expression instanceof Sum || expression instanceof Max
|| expression instanceof Min || expression instanceof Avg) {
// sum, max, min and avg in window function should be always nullable
return ((WindowExpression) output)
.withFunction(((NullableAggregateFunction) expression)
.withAlwaysNullable(true));
}
}
return output;
});

// 1. handle bottom projects
Set<Alias> existedAlias = ExpressionUtils.collect(outputs, Alias.class::isInstance);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_default --
21 04-21-11 1 1 1 1 1.0 1 1 1 1
22 04-22-10-21 0 0 1 1 0.5 1 0 0 0
22 04-22-10-21 1 0 1 1 0.5 1 0 1 1
23 04-23-10 1 1 1 1 1.0 1 1 1 1
24 02-24-10-21 1 1 1 1 1.0 1 1 1 1

-- !select_empty_window --
21 04-21-11 1 \N \N \N \N \N \N \N \N
22 04-22-10-21 0 \N \N \N \N \N \N \N \N
22 04-22-10-21 1 0 0 0 0.0 0 0 \N \N
23 04-23-10 1 \N \N \N \N \N \N \N \N
24 02-24-10-21 1 \N \N \N \N \N \N \N \N

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_always_nullable_window_function") {
def tableName = "test_always_nullable_window_function_table"


sql """ DROP TABLE IF EXISTS ${tableName} """
sql """
CREATE TABLE IF NOT EXISTS ${tableName} (
`myday` INT,
`time_col` VARCHAR(40) NOT NULL,
`state` INT
) ENGINE=OLAP
DUPLICATE KEY(`myday`,time_col,state)
COMMENT "OLAP"
DISTRIBUTED BY HASH(`myday`) BUCKETS 2
PROPERTIES (
"replication_num" = "1",
"in_memory" = "false",
"storage_format" = "V2"
);
"""

sql """ INSERT INTO ${tableName} VALUES
(21,"04-21-11",1),
(22,"04-22-10-21",0),
(22,"04-22-10-21",1),
(23,"04-23-10",1),
(24,"02-24-10-21",1); """

qt_select_default """
select *,
first_value(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) f_value,
last_value(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) l_value,
sum(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) sum_value,
avg(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) avg_value,
max(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) max_value,
min(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 following) min_value,
lag(state, 0, null) over (partition by myday order by time_col) lag_value,
lead(state, 0, null) over (partition by myday order by time_col) lead_value
from ${tableName} order by myday, time_col, state;
"""
qt_select_empty_window """
select *,
first_value(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) f_value,
last_value(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) l_value,
sum(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) sum_value,
avg(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) avg_value,
max(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) max_value,
min(state) over(partition by myday order by time_col rows BETWEEN 1 preceding AND 1 preceding) min_value,
lag(state, 2, null) over (partition by myday order by time_col) lag_value,
lead(state, 2, null) over (partition by myday order by time_col) lead_value
from ${tableName} order by myday, time_col, state;
"""

}

0 comments on commit 9b95e15

Please sign in to comment.