Skip to content

Commit

Permalink
Fix ST_Intersects join query crash (#7542)
Browse files Browse the repository at this point in the history
Throw an exception for ST_Intersects hash join queries that exceed
the maximum supported number of bounding box overlaps.

Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
paul-aiyedun authored and misiugodfrey committed Aug 26, 2024
1 parent 4c75ddb commit 0f531f5
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 41 deletions.
6 changes: 4 additions & 2 deletions QueryEngine/Execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3934,7 +3934,8 @@ int32_t Executor::executePlanWithoutGroupBy(
ErrorCode::INTERRUPTED,
ErrorCode::SINGLE_VALUE_FOUND_MULTIPLE_VALUES,
ErrorCode::GEOS,
ErrorCode::WIDTH_BUCKET_INVALID_ARGUMENT>::check(error_code)) {
ErrorCode::WIDTH_BUCKET_INVALID_ARGUMENT,
ErrorCode::BBOX_OVERLAPS_LIMIT_EXCEEDED>::check(error_code)) {
return error_code;
}
if (ra_exe_unit.estimator) {
Expand Down Expand Up @@ -4208,7 +4209,8 @@ int32_t Executor::executePlanWithGroupBy(
ErrorCode::INTERRUPTED,
ErrorCode::SINGLE_VALUE_FOUND_MULTIPLE_VALUES,
ErrorCode::GEOS,
ErrorCode::WIDTH_BUCKET_INVALID_ARGUMENT>::check(error_code)) {
ErrorCode::WIDTH_BUCKET_INVALID_ARGUMENT,
ErrorCode::BBOX_OVERLAPS_LIMIT_EXCEEDED>::check(error_code)) {
return error_code;
}

Expand Down
1 change: 1 addition & 0 deletions QueryEngine/IRCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ std::vector<JoinLoop> Executor::buildJoinLoops(
co, current_hash_table_idx);
domain.values_buffer = matching_set.elements;
domain.element_count = matching_set.count;
domain.error_code = matching_set.error_code;
return domain;
},
/*outer_condition_match=*/
Expand Down
44 changes: 23 additions & 21 deletions QueryEngine/JoinHashTable/BoundingBoxIntersectJoinHashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "QueryEngine/JoinHashTable/RangeJoinHashTable.h"
#include "QueryEngine/JoinHashTable/Runtime/HashJoinKeyHandlers.h"
#include "QueryEngine/JoinHashTable/Runtime/JoinHashTableGpuUtils.h"
#include "QueryEngine/enums.h"

std::unique_ptr<HashtableRecycler> BoundingBoxIntersectJoinHashTable::hash_table_cache_ =
std::make_unique<HashtableRecycler>(CacheItemType::BBOX_INTERSECT_HT,
Expand Down Expand Up @@ -1669,11 +1670,7 @@ HashJoinMatchingSet BoundingBoxIntersectJoinHashTable::codegenMatchingSet(
one_to_many_ptr =
LL_BUILDER.CreateAdd(one_to_many_ptr, LL_INT(composite_key_dict_size));

// NOTE(jclay): A fixed array of size 200 is allocated on the stack.
// this is likely the maximum value we can do that is safe to use across
// all supported GPU architectures.
const int max_array_size = 200;
const auto arr_type = get_int_array_type(32, max_array_size, LL_CONTEXT);
const auto arr_type = get_int_array_type(32, kMaxBBoxOverlapsCount, LL_CONTEXT);
const auto out_arr_lv = LL_BUILDER.CreateAlloca(arr_type);
out_arr_lv->setName("out_arr");

Expand All @@ -1685,27 +1682,32 @@ HashJoinMatchingSet BoundingBoxIntersectJoinHashTable::codegenMatchingSet(
auto rowid_ptr_i32 =
LL_BUILDER.CreatePointerCast(element_ptr, llvm::Type::getInt32PtrTy(LL_CONTEXT));

const auto error_code_ptr = LL_BUILDER.CreateAlloca(
get_int_type(32, LL_CONTEXT), nullptr, "candidate_rows_error_code");
LL_BUILDER.CreateStore(LL_INT(int32_t(0)), error_code_ptr);

const auto candidate_count_lv = executor_->cgen_state_->emitExternalCall(
"get_candidate_rows",
llvm::Type::getInt64Ty(LL_CONTEXT),
{
rowid_ptr_i32,
LL_INT(max_array_size),
many_to_many_args[1],
LL_INT(0),
LL_FP(inverse_bucket_sizes_for_dimension_[0]),
LL_FP(inverse_bucket_sizes_for_dimension_[1]),
many_to_many_args[0],
LL_INT(key_component_count), // key_component_count
composite_key_dict, // ptr to hash table
LL_INT(getEntryCount()), // entry_count
LL_INT(composite_key_dict_size), // offset_buffer_ptr_offset
LL_INT(getEntryCount() * sizeof(int32_t)) // sub_buff_size
});
{rowid_ptr_i32,
error_code_ptr,
LL_INT(kMaxBBoxOverlapsCount),
many_to_many_args[1],
LL_INT(0),
LL_FP(inverse_bucket_sizes_for_dimension_[0]),
LL_FP(inverse_bucket_sizes_for_dimension_[1]),
many_to_many_args[0],
LL_INT(key_component_count), // key_component_count
composite_key_dict, // ptr to hash table
LL_INT(getEntryCount()), // entry_count
LL_INT(composite_key_dict_size), // offset_buffer_ptr_offset
LL_INT(getEntryCount() * sizeof(int32_t)), // sub_buff_size
LL_INT(int32_t(heavyai::ErrorCode::BBOX_OVERLAPS_LIMIT_EXCEEDED))});

const auto slot_lv = LL_INT(int64_t(0));

return {rowid_ptr_i32, candidate_count_lv, slot_lv};
auto error_code_lv = LL_BUILDER.CreateLoad(
error_code_ptr->getType()->getPointerElementType(), error_code_ptr);
return {rowid_ptr_i32, candidate_count_lv, slot_lv, error_code_lv};
} else {
VLOG(1) << "Building codegenMatchingSet for Baseline";
// TODO: duplicated w/ BaselineJoinHashTable -- push into the hash table builder?
Expand Down
5 changes: 5 additions & 0 deletions QueryEngine/JoinHashTable/BoundingBoxIntersectJoinHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include "QueryEngine/JoinHashTable/BaselineJoinHashTable.h"
#include "QueryEngine/JoinHashTable/HashJoin.h"

// NOTE(jclay): Upper bound on the number of candidate row ids collected for a
// single bounding box intersect probe. A fixed array of this size (200) is
// allocated on the stack; this is likely the maximum value that is safe to use
// across all supported GPU architectures.
constexpr int32_t kMaxBBoxOverlapsCount{200};

class BoundingBoxIntersectJoinHashTable : public HashJoin {
public:
BoundingBoxIntersectJoinHashTable(const std::shared_ptr<Analyzer::BinOper> condition,
Expand Down
2 changes: 1 addition & 1 deletion QueryEngine/JoinHashTable/HashJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ HashJoinMatchingSet HashJoin::codegenMatchingSet(
rowid_base_i32->getType()->getScalarType()->getPointerElementType(),
rowid_base_i32,
slot_lv);
return {rowid_ptr_i32, row_count_lv, slot_lv};
return {rowid_ptr_i32, row_count_lv, slot_lv, nullptr};
}

llvm::Value* HashJoin::codegenHashTableLoad(const size_t table_idx, Executor* executor) {
Expand Down
1 change: 1 addition & 0 deletions QueryEngine/JoinHashTable/HashJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ struct HashJoinMatchingSet {
llvm::Value* elements;
llvm::Value* count;
llvm::Value* slot;
llvm::Value* error_code;
};

struct CompositeKeyInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ struct Bounds {
/// The number of row ids in this array is returned.
extern "C" RUNTIME_EXPORT NEVER_INLINE DEVICE int64_t
get_candidate_rows(int32_t* out_arr,
int32_t* error_code,
const uint32_t max_arr_size,
const int8_t* range_bytes,
const int32_t range_component_index,
Expand All @@ -298,7 +299,8 @@ get_candidate_rows(int32_t* out_arr,
int64_t* hash_table_ptr,
const int64_t entry_count,
const int64_t offset_buffer_ptr_offset,
const int64_t sub_buff_size) {
const int64_t sub_buff_size,
const int32_t max_bbox_overlaps_error_code) {
const auto range = reinterpret_cast<const double*>(range_bytes);

size_t elem_count = 0;
Expand All @@ -325,7 +327,10 @@ get_candidate_rows(int32_t* out_arr,
for (int64_t j = 0; j < buffer_range.element_count; j++) {
const auto rowid = buffer_range.buffer[j];
elem_count += insert_sorted(out_arr, elem_count, rowid);
assert(max_arr_size >= elem_count);
if (elem_count > max_arr_size) {
*error_code = max_bbox_overlaps_error_code;
return 0;
}
}
}
}
Expand Down
18 changes: 17 additions & 1 deletion QueryEngine/LoopControlFlow/JoinLoop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,25 @@ llvm::BasicBlock* JoinLoop::codegen(
}
builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);

const auto head_bb = llvm::BasicBlock::Create(
context, "ub_iter_head_" + join_loop.name_, parent_func);
builder.CreateBr(head_bb);

if (iteration_domain.error_code) {
cgen_state->needs_error_check_ = true;
auto ub_iter_success_code = ll_int(int32_t(0), context);
const auto ub_iter_error_condition =
builder.CreateICmpEQ(iteration_domain.error_code, ub_iter_success_code);
auto error_bb =
llvm::BasicBlock::Create(context, "ub_iter_error_exit", parent_func);
builder.CreateCondBr(ub_iter_error_condition, head_bb, error_bb);

builder.SetInsertPoint(error_bb);
builder.CreateRet(iteration_domain.error_code);
} else {
builder.CreateBr(head_bb);
}

builder.SetInsertPoint(head_bb);
llvm::Value* iteration_counter =
builder.CreateLoad(iteration_counter_ptr->getType()->getPointerElementType(),
Expand Down
1 change: 1 addition & 0 deletions QueryEngine/LoopControlFlow/JoinLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct JoinLoopDomain {
llvm::Value* slot_lookup_result; // for Singleton
};
llvm::Value* values_buffer; // used for Set
llvm::Value* error_code;
};

// Any join is logically a loop. Hash joins just limit the domain of iteration,
Expand Down
2 changes: 1 addition & 1 deletion QueryEngine/NativeCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ declare i64 @get_composite_key_index_64(i64*, i64, i64*, i64);
declare i64 @get_bucket_key_for_range_compressed(i8*, i64, double);
declare i64 @get_bucket_key_for_range_double(i8*, i64, double);
declare i32 @get_num_buckets_for_bounds(i8*, i32, double, double);
declare i64 @get_candidate_rows(i32*, i32, i8*, i32, double, double, i32, i64, i64*, i64, i64, i64);
declare i64 @get_candidate_rows(i32*, i32*, i32, i8*, i32, double, double, i32, i64, i64*, i64, i64, i64, i32);
declare i64 @agg_count_shared(i64*, i64);
declare i64 @agg_count_skip_val_shared(i64*, i64, i64);
declare i32 @agg_count_int32_shared(i32*, i32);
Expand Down
25 changes: 15 additions & 10 deletions QueryEngine/RuntimeFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ extern "C" RUNTIME_EXPORT ALWAYS_INLINE void record_error_code(const int32_t err
}
}

// Reads the error code slot selected by pos_start_impl(nullptr).
// On GPU error_codes is an array; on CPU it points to a single value.
extern "C" RUNTIME_EXPORT ALWAYS_INLINE int32_t get_error_code(int32_t* error_codes) {
const auto error_slot = pos_start_impl(nullptr);
return error_codes[error_slot];
}
Expand Down Expand Up @@ -2382,7 +2383,7 @@ extern "C" RUNTIME_EXPORT NEVER_INLINE void linear_probabilistic_count(

// First 3 parameters are output, the rest are input.
extern "C" RUNTIME_EXPORT NEVER_INLINE void query_stub_hoisted_literals(
int32_t* error_code,
int32_t* error_codes,
int32_t* total_matched,
int64_t** out,
const uint32_t frag_idx,
Expand All @@ -2396,15 +2397,15 @@ extern "C" RUNTIME_EXPORT NEVER_INLINE void query_stub_hoisted_literals(
const int64_t* join_hash_tables,
const int8_t* row_func_mgr) {
#ifndef _WIN32
assert(error_code || total_matched || out || frag_idx || row_index_resume ||
assert(error_codes || total_matched || out || frag_idx || row_index_resume ||
col_buffers || literals || num_rows || frag_row_offsets || max_matched ||
init_agg_value || join_hash_tables || row_func_mgr);
#endif
}

// First 3 parameters are output, the rest are input.
extern "C" RUNTIME_EXPORT void multifrag_query_hoisted_literals(
int32_t* error_code,
int32_t* error_codes,
int32_t* total_matched,
int64_t** out,
const uint32_t* num_fragments_ptr,
Expand All @@ -2421,8 +2422,10 @@ extern "C" RUNTIME_EXPORT void multifrag_query_hoisted_literals(
uint32_t const num_fragments = *num_fragments_ptr;
uint32_t const num_tables = *num_tables_ptr;
// num_fragments_ptr and num_tables_ptr are replaced by frag_idx when passed below.
for (uint32_t frag_idx = 0; frag_idx < num_fragments; ++frag_idx) {
query_stub_hoisted_literals(error_code,
for (uint32_t frag_idx = 0;
frag_idx < num_fragments && get_error_code(error_codes) == 0;
++frag_idx) {
query_stub_hoisted_literals(error_codes,
total_matched,
out,
frag_idx,
Expand All @@ -2439,7 +2442,7 @@ extern "C" RUNTIME_EXPORT void multifrag_query_hoisted_literals(
}

// First 3 parameters are output, the rest are input.
extern "C" RUNTIME_EXPORT NEVER_INLINE void query_stub(int32_t* error_code,
extern "C" RUNTIME_EXPORT NEVER_INLINE void query_stub(int32_t* error_codes,
int32_t* total_matched,
int64_t** out,
const uint32_t frag_idx,
Expand All @@ -2452,14 +2455,14 @@ extern "C" RUNTIME_EXPORT NEVER_INLINE void query_stub(int32_t* error_code,
const int64_t* join_hash_tables,
const int8_t* row_func_mgr) {
#ifndef _WIN32
assert(error_code || total_matched || out || frag_idx || row_index_resume ||
assert(error_codes || total_matched || out || frag_idx || row_index_resume ||
col_buffers || num_rows || frag_row_offsets || max_matched || init_agg_value ||
join_hash_tables || row_func_mgr);
#endif
}

// First 3 parameters are output, the rest are input.
extern "C" RUNTIME_EXPORT void multifrag_query(int32_t* error_code,
extern "C" RUNTIME_EXPORT void multifrag_query(int32_t* error_codes,
int32_t* total_matched,
int64_t** out,
const uint32_t* num_fragments_ptr,
Expand All @@ -2475,8 +2478,10 @@ extern "C" RUNTIME_EXPORT void multifrag_query(int32_t* error_code,
uint32_t const num_fragments = *num_fragments_ptr;
uint32_t const num_tables = *num_tables_ptr;
// num_fragments_ptr and num_tables_ptr are replaced by frag_idx when passed below.
for (uint32_t frag_idx = 0; frag_idx < num_fragments; ++frag_idx) {
query_stub(error_code,
for (uint32_t frag_idx = 0;
frag_idx < num_fragments && get_error_code(error_codes) == 0;
++frag_idx) {
query_stub(error_codes,
total_matched,
out,
frag_idx,
Expand Down
4 changes: 3 additions & 1 deletion QueryEngine/enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ HEAVYAI_DEFINE_ENUM_CLASS_WITH_DESCRIPTIONS(
(SINGLE_VALUE_FOUND_MULTIPLE_VALUES, "Multiple distinct values encountered"),
(GEOS, "Geo-related error"),
(WIDTH_BUCKET_INVALID_ARGUMENT,
"Arguments of WIDTH_BUCKET function does not satisfy the condition"))
"Arguments of WIDTH_BUCKET function does not satisfy the condition"),
(BBOX_OVERLAPS_LIMIT_EXCEEDED,
"Maximum supported number of bounding box overlaps exceeded"))

HEAVYAI_DEFINE_ENUM_CLASS(QueryDescriptionType,
GroupByPerfectHash,
Expand Down
Loading

0 comments on commit 0f531f5

Please sign in to comment.