Skip to content

Commit

Permalink
[vectorized](udaf) support array type for java-udaf (apache#17351)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 authored Mar 9, 2023
1 parent 06dee69 commit 4ef4615
Show file tree
Hide file tree
Showing 14 changed files with 632 additions and 15 deletions.
4 changes: 2 additions & 2 deletions be/src/util/jni-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class JniUtil {
static jclass jni_util_class() { return jni_util_cl_; }
static jmethodID throwable_to_stack_trace_id() { return throwable_to_stack_trace_id_; }

static const int32_t INITIAL_RESERVED_BUFFER_SIZE = 1024;
static const int64_t INITIAL_RESERVED_BUFFER_SIZE = 1024;
// TODO: we need a heuristic strategy to increase buffer size for variable-size output.
static inline int32_t IncreaseReservedBufferSize(int n) {
static inline int64_t IncreaseReservedBufferSize(int n) {
return INITIAL_RESERVED_BUFFER_SIZE << n;
}

Expand Down
102 changes: 100 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "runtime/user_function_cache.h"
#include "util/jni-util.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_string.h"
#include "vec/common/string_ref.h"
#include "vec/core/field.h"
Expand Down Expand Up @@ -55,11 +56,15 @@ struct AggregateJavaUdafData {
input_values_buffer_ptr.reset(new int64_t[num_args]);
input_nulls_buffer_ptr.reset(new int64_t[num_args]);
input_offsets_ptrs.reset(new int64_t[num_args]);
input_array_nulls_buffer_ptr.reset(new int64_t[num_args]);
input_array_string_offsets_ptrs.reset(new int64_t[num_args]);
input_place_ptrs.reset(new int64_t);
output_value_buffer.reset(new int64_t);
output_null_value.reset(new int64_t);
output_offsets_ptr.reset(new int64_t);
output_intermediate_state_ptr.reset(new int64_t);
output_array_null_ptr.reset(new int64_t);
output_array_string_offsets_ptr.reset(new int64_t);
}

~AggregateJavaUdafData() {
Expand Down Expand Up @@ -92,13 +97,21 @@ struct AggregateJavaUdafData {
ctor_params.__set_input_offsets_ptrs((int64_t)input_offsets_ptrs.get());
ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get());
ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get());
ctor_params.__set_input_array_nulls_buffer_ptr(
(int64_t)input_array_nulls_buffer_ptr.get());
ctor_params.__set_input_array_string_offsets_ptrs(
(int64_t)input_array_string_offsets_ptrs.get());

ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get());
ctor_params.__set_input_places_ptr((int64_t)input_place_ptrs.get());

ctor_params.__set_output_null_ptr((int64_t)output_null_value.get());
ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get());
ctor_params.__set_output_intermediate_state_ptr(
(int64_t)output_intermediate_state_ptr.get());
ctor_params.__set_output_array_null_ptr((int64_t)output_array_null_ptr.get());
ctor_params.__set_output_array_string_offsets_ptr(
(int64_t)output_array_string_offsets_ptr.get());

jbyteArray ctor_params_bytes;

Expand Down Expand Up @@ -140,6 +153,30 @@ struct AggregateJavaUdafData {
} else if (data_col->is_numeric() || data_col->is_column_decimal()) {
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
} else if (data_col->is_column_array()) {
const ColumnArray* array_col = assert_cast<const ColumnArray*>(data_col);
input_offsets_ptrs.get()[arg_idx] = reinterpret_cast<int64_t>(
array_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& array_nested_nullable =
assert_cast<const ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
input_array_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());

//need pass FE, nullamp and offset, chars
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_chars().data());
input_array_string_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDAF doesn't support type is $0 now !",
Expand Down Expand Up @@ -210,7 +247,7 @@ struct AggregateJavaUdafData {
ColumnString::Offsets& offsets = \
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
int increase_buffer_size = 0; \
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
*output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
*output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \
Expand All @@ -219,7 +256,7 @@ struct AggregateJavaUdafData {
executor_result_id, to.size() - 1, place); \
while (res != JNI_TRUE) { \
increase_buffer_size++; \
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
*output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
*output_intermediate_state_ptr = chars.size(); \
Expand All @@ -230,6 +267,63 @@ struct AggregateJavaUdafData {
*output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data); \
env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
to.size() - 1, place); \
} else if (data_col.is_column_array()) { \
ColumnArray& array_col = assert_cast<ColumnArray&>(data_col); \
ColumnNullable& array_nested_nullable = \
assert_cast<ColumnNullable&>(array_col.get_data()); \
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); \
auto data_column = array_nested_nullable.get_nested_column_ptr(); \
auto& offset_column = array_col.get_offsets_column(); \
int increase_buffer_size = 0; \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
*output_offsets_ptr = reinterpret_cast<int64_t>(offset_column.get_raw_data().data); \
data_column_null_map->resize(buffer_size); \
auto& null_map_data = \
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data(); \
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data()); \
*output_intermediate_state_ptr = buffer_size; \
if (data_column->is_column_string()) { \
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get()); \
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars()); \
ColumnString::Offsets& offsets = \
assert_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
chars.resize(buffer_size); \
offsets.resize(buffer_size); \
*output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
*output_array_string_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \
jboolean res = env->CallNonvirtualBooleanMethod( \
executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \
while (res != JNI_TRUE) { \
increase_buffer_size++; \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
null_map_data.resize(buffer_size); \
chars.resize(buffer_size); \
offsets.resize(buffer_size); \
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data()); \
*output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
*output_array_string_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \
*output_intermediate_state_ptr = buffer_size; \
res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \
executor_result_id, to.size() - 1, place); \
} \
} else { \
data_column->resize(buffer_size); \
*output_value_buffer = reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
jboolean res = env->CallNonvirtualBooleanMethod( \
executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \
while (res != JNI_TRUE) { \
increase_buffer_size++; \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
null_map_data.resize(buffer_size); \
data_column->resize(buffer_size); \
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data()); \
*output_value_buffer = \
reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
*output_intermediate_state_ptr = buffer_size; \
res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \
executor_result_id, to.size() - 1, place); \
} \
} \
} else { \
return Status::InvalidArgument(strings::Substitute( \
"Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \
Expand Down Expand Up @@ -286,11 +380,15 @@ struct AggregateJavaUdafData {
std::unique_ptr<int64_t[]> input_values_buffer_ptr;
std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
std::unique_ptr<int64_t[]> input_offsets_ptrs;
std::unique_ptr<int64_t[]> input_array_nulls_buffer_ptr;
std::unique_ptr<int64_t[]> input_array_string_offsets_ptrs;
std::unique_ptr<int64_t> input_place_ptrs;
std::unique_ptr<int64_t> output_value_buffer;
std::unique_ptr<int64_t> output_null_value;
std::unique_ptr<int64_t> output_offsets_ptr;
std::unique_ptr<int64_t> output_intermediate_state_ptr;
std::unique_ptr<int64_t> output_array_null_ptr;
std::unique_ptr<int64_t> output_array_string_offsets_ptr;

int argument_size = 0;
std::string serialize_data;
Expand Down
8 changes: 5 additions & 3 deletions be/src/vec/functions/function_java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
ColumnString::Offsets& offsets = \
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
int increase_buffer_size = 0; \
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
offsets.resize(num_rows); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
Expand All @@ -211,7 +211,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
nullptr); \
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
increase_buffer_size++; \
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
Expand All @@ -232,7 +232,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto data_column = array_nested_nullable.get_nested_column_ptr(); \
auto& offset_column = array_col->get_offsets_column(); \
int increase_buffer_size = 0; \
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
offset_column.resize(num_rows); \
*(jni_ctx->output_offsets_ptr) = \
reinterpret_cast<int64_t>(offset_column.get_raw_data().data); \
Expand Down Expand Up @@ -263,6 +263,8 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
*(jni_ctx->output_array_null_ptr) = \
reinterpret_cast<int64_t>(null_map_data.data()); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
*(jni_ctx->output_array_string_offsets_ptr) = \
reinterpret_cast<int64_t>(offsets.data()); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, \
executor_evaluate_id_, nullptr); \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ public ArrayList<?> arrayTypeInputData(Type type, int argIdx, long row)
}
}

protected abstract long getCurrentOutputOffset(long row);
protected abstract long getCurrentOutputOffset(long row, boolean isArrayType);

/**
* Close the class loader we may have created.
Expand Down Expand Up @@ -615,7 +615,7 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud
case STRING: {
long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr);
byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
long offset = getCurrentOutputOffset(row);
long offset = getCurrentOutputOffset(row, false);
if (offset + bytes.length > bufferSize) {
return false;
}
Expand All @@ -637,7 +637,7 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud
}

public boolean arrayTypeOutputData(Object obj, Type type, long row) throws UdfRuntimeException {
long offset = getCurrentOutputOffset(row);
long offset = getCurrentOutputOffset(row, true);
long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr);
long outputNullMapBase = UdfUtils.UNSAFE.getLong(null, outputArrayNullPtr);
long outputBufferBase = UdfUtils.UNSAFE.getLong(null, outputBufferPtr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,14 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud
}

@Override
protected long getCurrentOutputOffset(long row) {
return Integer.toUnsignedLong(
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud
}

@Override
protected long getCurrentOutputOffset(long row) {
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
return outputOffset;
}

Expand Down
Loading

0 comments on commit 4ef4615

Please sign in to comment.