Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
BiteTheDDDDt committed Mar 7, 2023
1 parent 28d56ef commit a6b16ff
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 61 deletions.
48 changes: 24 additions & 24 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class IAggregateFunctionHelper : public IAggregateFunction {
const size_t num_rows) const noexcept override {
const size_t size_of_data_ = size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
static_cast<const Derived*>(this)->destroy(place + size_of_data_ * i);
assert_cast<const Derived*>(this)->destroy(place + size_of_data_ * i);
}
}

Expand All @@ -235,7 +235,7 @@ class IAggregateFunctionHelper : public IAggregateFunction {
}
auto iter = place_rows.begin();
while (iter != place_rows.end()) {
static_cast<const Derived*>(this)->add_many(iter->first, columns, iter->second,
assert_cast<const Derived*>(this)->add_many(iter->first, columns, iter->second,
arena);
iter++;
}
Expand All @@ -244,23 +244,23 @@ class IAggregateFunctionHelper : public IAggregateFunction {
}

for (size_t i = 0; i < batch_size; ++i) {
static_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
}
}

void add_batch_selected(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
const IColumn** columns, Arena* arena) const override {
for (size_t i = 0; i < batch_size; ++i) {
if (places[i]) {
static_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
}
}
}

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
for (size_t i = 0; i < batch_size; ++i) {
static_cast<const Derived*>(this)->add(place, columns, i, arena);
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
}
}
//now this is use for sum/count/avg/min/max win function, other win function should override this function in class
Expand All @@ -271,28 +271,28 @@ class IAggregateFunctionHelper : public IAggregateFunction {
frame_start = std::max<int64_t>(frame_start, partition_start);
frame_end = std::min<int64_t>(frame_end, partition_end);
for (int64_t i = frame_start; i < frame_end; ++i) {
static_cast<const Derived*>(this)->add(place, columns, i, arena);
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
}
}

void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
const IColumn** columns, Arena* arena, bool has_null) override {
for (size_t i = batch_begin; i <= batch_end; ++i) {
static_cast<const Derived*>(this)->add(place, columns, i, arena);
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
}
}

void insert_result_into_vec(const std::vector<AggregateDataPtr>& places, const size_t offset,
IColumn& to, const size_t num_rows) const override {
for (size_t i = 0; i != num_rows; ++i) {
static_cast<const Derived*>(this)->insert_result_into(places[i] + offset, to);
assert_cast<const Derived*>(this)->insert_result_into(places[i] + offset, to);
}
}

void serialize_vec(const std::vector<AggregateDataPtr>& places, size_t offset,
BufferWritable& buf, const size_t num_rows) const override {
for (size_t i = 0; i != num_rows; ++i) {
static_cast<const Derived*>(this)->serialize(places[i] + offset, buf);
assert_cast<const Derived*>(this)->serialize(places[i] + offset, buf);
buf.commit();
}
}
Expand All @@ -307,35 +307,35 @@ class IAggregateFunctionHelper : public IAggregateFunction {
const size_t num_rows, Arena* arena) const override {
char place[size_of_data()];
for (size_t i = 0; i != num_rows; ++i) {
static_cast<const Derived*>(this)->create(place);
static_cast<const Derived*>(this)->add(place, columns, i, arena);
static_cast<const Derived*>(this)->serialize(place, buf);
assert_cast<const Derived*>(this)->create(place);
assert_cast<const Derived*>(this)->add(place, columns, i, arena);
assert_cast<const Derived*>(this)->serialize(place, buf);
buf.commit();
static_cast<const Derived*>(this)->destroy(place);
assert_cast<const Derived*>(this)->destroy(place);
}
}

void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
const size_t num_rows, Arena* arena) const override {
VectorBufferWriter writter(static_cast<ColumnString&>(*dst));
VectorBufferWriter writter(assert_cast<ColumnString&>(*dst));
streaming_agg_serialize(columns, writter, num_rows, arena);
}

void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
VectorBufferWriter writter(static_cast<ColumnString&>(*dst));
static_cast<const Derived*>(this)->serialize(place, writter);
VectorBufferWriter writter(assert_cast<ColumnString&>(*dst));
assert_cast<const Derived*>(this)->serialize(place, writter);
writter.commit();
}

void deserialize_vec(AggregateDataPtr places, const ColumnString* column, Arena* arena,
size_t num_rows) const override {
const auto size_of_data = static_cast<const Derived*>(this)->size_of_data();
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
auto place = places + size_of_data * i;
VectorBufferReader buffer_reader(column->get_data_at(i));
static_cast<const Derived*>(this)->create(place);
static_cast<const Derived*>(this)->deserialize(place, buffer_reader, arena);
assert_cast<const Derived*>(this)->create(place);
assert_cast<const Derived*>(this)->deserialize(place, buffer_reader, arena);
}
}

Expand All @@ -346,20 +346,20 @@ class IAggregateFunctionHelper : public IAggregateFunction {

void merge_vec(const AggregateDataPtr* places, size_t offset, ConstAggregateDataPtr rhs,
Arena* arena, const size_t num_rows) const override {
const auto size_of_data = static_cast<const Derived*>(this)->size_of_data();
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
static_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
arena);
}
}

void merge_vec_selected(const AggregateDataPtr* places, size_t offset,
ConstAggregateDataPtr rhs, Arena* arena,
const size_t num_rows) const override {
const auto size_of_data = static_cast<const Derived*>(this)->size_of_data();
const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data();
for (size_t i = 0; i != num_rows; ++i) {
if (places[i]) {
static_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
arena);
}
}
Expand Down Expand Up @@ -399,7 +399,7 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived> {
char deserialized_data[size_of_data()];
AggregateDataPtr deserialized_place = (AggregateDataPtr)deserialized_data;

auto derived = static_cast<const Derived*>(this);
auto derived = assert_cast<const Derived*>(this);
derived->create(deserialized_place);
derived->deserialize(deserialized_place, buf, arena);
derived->merge(place, deserialized_place, arena);
Expand Down
14 changes: 2 additions & 12 deletions be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,8 @@ namespace doris::vectorized {
AggregateFunctionPtr create_aggregate_function_avg_weight(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return AggregateFunctionPtr(new AggregateFunctionAvgWeight<TYPE>(argument_types));
FOR_NUMERIC_TYPES(DISPATCH)
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH

LOG(WARNING) << fmt::format("Illegal argument type for aggregate function topn_array is: {}",
argument_types[0]->get_name());
return nullptr;
return AggregateFunctionPtr(creator_with_type::create<AggregateFunctionAvgWeight>(
result_is_nullable, argument_types));
}

void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& factory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct AggregateFunctionAvgWeightedData {
void add(const T& data_val, double weight_val) {
if constexpr (IsDecimalV2<T>) {
DecimalV2Value value = binary_cast<Int128, DecimalV2Value>(data_val);
data_sum = data_sum + (static_cast<double>(value) * weight_val);
data_sum = data_sum + (double(value) * weight_val);
} else {
data_sum = data_sum + (data_val * weight_val);
}
Expand Down Expand Up @@ -81,8 +81,8 @@ class AggregateFunctionAvgWeight final

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
const auto& column = static_cast<const ColVecType&>(*columns[0]);
const auto& weight = static_cast<const ColumnVector<Float64>&>(*columns[1]);
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
const auto& weight = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
this->data(place).add(column.get_data()[row_num], weight.get_element(row_num));
}

Expand All @@ -103,7 +103,7 @@ class AggregateFunctionAvgWeight final
}

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& column = static_cast<ColumnVector<Float64>&>(to);
auto& column = assert_cast<ColumnVector<Float64>&>(to);
column.get_data().push_back(this->data(place).get());
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,13 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri
AggregateFunctionPtr create_aggregate_function_percentile(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentile>(
result_is_nullable, argument_types));
return std::make_shared<AggregateFunctionPercentile>(argument_types);
}

AggregateFunctionPtr create_aggregate_function_percentile_array(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentileArray>(
result_is_nullable, argument_types));
return std::make_shared<AggregateFunctionPercentileArray>(argument_types);
}

void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class AggregateFunctionPercentileApprox
if (std::isnan(result)) {
nullable_column.insert_default();
} else {
auto& col = static_cast<ColumnVector<Float64>&>(nullable_column.get_nested_column());
auto& col = assert_cast<ColumnVector<Float64>&>(nullable_column.get_nested_column());
col.get_data().push_back(result);
nullable_column.get_null_map_data().push_back(0);
}
Expand Down Expand Up @@ -193,11 +193,11 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce
for (int i = 0; i < 2; ++i) {
const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]);
if (nullable_column == nullptr) { //Not Nullable column
const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]);
const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]);
column_data[i] = column.get_float64(row_num);
} else if (!nullable_column->is_null_at(
row_num)) { // Nullable column && Not null data
const auto& column = static_cast<const ColumnVector<Float64>&>(
const auto& column = assert_cast<const ColumnVector<Float64>&>(
nullable_column->get_nested_column());
column_data[i] = column.get_float64(row_num);
} else { // Nullable column && null data
Expand All @@ -211,8 +211,8 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce
this->data(place).add(column_data[0], column_data[1]);

} else {
const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]);
const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]);
const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]);
const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);

this->data(place).init();
this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num));
Expand All @@ -233,11 +233,11 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer
for (int i = 0; i < 3; ++i) {
const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]);
if (nullable_column == nullptr) { //Not Nullable column
const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]);
const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]);
column_data[i] = column.get_float64(row_num);
} else if (!nullable_column->is_null_at(
row_num)) { // Nullable column && Not null data
const auto& column = static_cast<const ColumnVector<Float64>&>(
const auto& column = assert_cast<const ColumnVector<Float64>&>(
nullable_column->get_nested_column());
column_data[i] = column.get_float64(row_num);
} else { // Nullable column && null data
Expand All @@ -251,9 +251,9 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer
this->data(place).add(column_data[0], column_data[1]);

} else {
const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]);
const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]);
const auto& compression = static_cast<const ColumnVector<Float64>&>(*columns[2]);
const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]);
const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
const auto& compression = assert_cast<const ColumnVector<Float64>&>(*columns[2]);

this->data(place).init(compression.get_float64(row_num));
this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num));
Expand Down Expand Up @@ -345,7 +345,7 @@ struct PercentileState {
}

void insert_result_into(IColumn& to) const {
auto& column_data = static_cast<ColumnVector<Float64>&>(to).get_data();
auto& column_data = assert_cast<ColumnVector<Float64>&>(to).get_data();
for (int i = 0; i < vec_counts.size(); ++i) {
column_data.push_back(vec_counts[i].terminate(vec_quantile[i]).val);
}
Expand All @@ -365,8 +365,8 @@ class AggregateFunctionPercentile final

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
const auto& sources = static_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]);
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(),
1);
}
Expand Down Expand Up @@ -410,12 +410,12 @@ class AggregateFunctionPercentileArray final

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
const auto& sources = static_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& quantile_array = static_cast<const ColumnArray&>(*columns[1]);
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]);
const auto& offset_column_data = quantile_array.get_offsets();
const auto& nested_column =
static_cast<const ColumnNullable&>(quantile_array.get_data()).get_nested_column();
const auto& nested_column_data = static_cast<const ColumnVector<Float64>&>(nested_column);
assert_cast<const ColumnNullable&>(quantile_array.get_data()).get_nested_column();
const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column);

AggregateFunctionPercentileArray::data(place).add(
sources.get_int(row_num), nested_column_data.get_data(),
Expand Down

0 comments on commit a6b16ff

Please sign in to comment.