Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/shamalgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ set(Sources
src/primitives/gen_buffer_index.cpp
src/primitives/segmented_sort_in_place.cpp
src/primitives/append_subset_to.cpp
src/primitives/stream_compact.cpp
)

if(SHAMROCK_USE_SHARED_LIB)
Expand Down
26 changes: 26 additions & 0 deletions src/shamalgs/include/shamalgs/primitives/stream_compact.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2025 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

/**
* @file stream_compact.hpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief Stream compaction primitive
Comment on lines +13 to +15

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The documentation in this file header appears to be copied from sort_by_keys.hpp. Please update it to accurately describe the contents of stream_compact.hpp and the stream_compact function.

 * @file stream_compact.hpp
 * @author Timothée David--Cléris (tim.shamrock@proton.me)
 * @brief Stream compaction primitive

*
*/

#include "shambackends/DeviceBuffer.hpp"

namespace shamalgs::primitives {

/**
 * @brief Stream compaction: list the indices whose flag is set.
 *
 * The flag buffer is consumed: an exclusive sum is performed in place over its
 * first `len + 1` entries, and the scanned value at index `len` gives the number
 * of selected indices. An index `i` in `[0, len)` is selected when the scanned
 * value strictly increases between `i` and `i + 1`, i.e. when its flag is non-zero.
 *
 * @param sched     Device scheduler used to allocate the result and run the kernel.
 * @param buf_flags Flag buffer holding at least `len + 1` entries; it is modified in place.
 * @param len       Number of flags to compact.
 * @return Buffer holding the selected indices, placed at the position given by
 *         their scanned value.
 * @throws std::invalid_argument if `buf_flags.get_size() < len + 1`.
 */
sham::DeviceBuffer<u32> stream_compact(
const sham::DeviceScheduler_ptr &sched, sham::DeviceBuffer<u32> && buf_flags, u32 len);

} // namespace shamalgs::primitives
59 changes: 59 additions & 0 deletions src/shamalgs/src/primitives/stream_compact.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2025 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

/**
* @file stream_compact.cpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief Implementation of the stream compaction primitive
Comment on lines +11 to +13

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The documentation in this file header seems to be copied from sort_by_keys.hpp. Please update it to describe the stream_compact implementation in this file.

Suggested change
* @file sort_by_keys.hpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief Sort by keys algorithms
* @file stream_compact.cpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief Implementation of the stream compaction primitive

*
*/

#include "shamalgs/primitives/stream_compact.hpp"
#include "shamalgs/primitives/scan_exclusive_sum_in_place.hpp"
#include "shambackends/DeviceBuffer.hpp"
#include "shambackends/kernel_call.hpp"

namespace shamalgs::primitives {

/**
 * @brief Stream compaction: list the indices whose flag is set.
 *
 * An exclusive sum is performed in place over the first `len + 1` entries of
 * `buf_flags`. After the scan, the entry at index `len` holds the number of
 * selected indices, and an index `i` in `[0, len)` is selected when the scanned
 * value strictly increases between `i` and `i + 1`.
 *
 * @param sched     Device scheduler used to allocate the result and run the kernel.
 * @param buf_flags Flag buffer holding at least `len + 1` entries; it is modified in place.
 * @param len       Number of flags to compact.
 * @return Buffer holding the selected indices, placed at the position given by
 *         their scanned value.
 * @throws std::invalid_argument if `buf_flags.get_size() < len + 1`.
 */
sham::DeviceBuffer<u32> stream_compact(
    const sham::DeviceScheduler_ptr &sched, sham::DeviceBuffer<u32> &&buf_flags, u32 len) {

    if (buf_flags.get_size() < len + 1) {
        shambase::throw_with_loc<std::invalid_argument>(shambase::format(
            "buf_flags.get_size() < len+1\n buf_flags.get_size() = {}, len = {}",
            buf_flags.get_size(),
            len));
    }

    // Scan one entry past the flags so that index `len` receives the total count.
    shamalgs::primitives::scan_exclusive_sum_in_place(buf_flags, len + 1);

    u32 new_len = buf_flags.get_val_at_idx(len);

    sham::DeviceBuffer<u32> index_map(new_len, sched);

    if (new_len > 0) {
        sham::kernel_call(
            sched->get_queue(),
            sham::MultiRef{buf_flags},
            sham::MultiRef{index_map},
            // Launch exactly `len` work items: each one reads sum_vals[idx + 1], so the
            // highest index touched is `len`, the last entry guaranteed by the size check
            // above. Launching `len + 1` items would read sum_vals[len + 1], out of bounds.
            len,
            [](u32 idx, const u32 *sum_vals, u32 *new_idx) {
                u32 current_val = sum_vals[idx];

                // The scanned value increases across idx only when the flag at idx was set.
                bool should_write = (current_val < sum_vals[idx + 1]);

                if (should_write) {
                    new_idx[current_val] = idx;
                }
            });
    }

    return index_map;
}

} // namespace shamalgs::primitives
69 changes: 60 additions & 9 deletions src/shammodels/sph/src/BasicSPHGhosts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ int main(){
*/

#include "shambase/exception.hpp"
#include "shamalgs/primitives/scan_exclusive_sum_in_place.hpp"
#include "shamalgs/primitives/stream_compact.hpp"
#include "shamcomm/collectives.hpp"
#include "shammodels/sph/BasicSPHGhosts.hpp"
#include <functional>
Expand Down Expand Up @@ -460,25 +462,74 @@ auto BasicSPHGhostHandler<vec>::gen_id_table_interfaces(GeneratorMap &&gen)

std::map<u64, f64> send_count_stats;

auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();

// Per-interface working set collected up front so that the mask kernels can all be
// submitted in a first loop before the compaction is performed in a second loop.
struct MultikernelInput {
// Key used to fetch the source patch data (sched.patch_data.get_pdat(sender)).
u64 sender;
// Second key under which the resulting interface id table is registered.
u64 receiver;
// Interface build infos; its cut_volume lower/upper bounds are used by the mask kernel.
InterfaceBuildInfos &build;
// Field 0 of the sender patch data, read as positions by the mask kernel.
PatchDataField<vec> &xyz;
// Flag buffer of xyz.get_obj_cnt() + 1 entries, later moved into stream_compact.
sham::DeviceBuffer<u32> mask;
};

std::vector<MultikernelInput> multikernel_input;

gen.for_each([&](u64 sender, u64 receiver, InterfaceBuildInfos &build) {
shamrock::patch::PatchDataLayer &src = sched.patch_data.get_pdat(sender);
PatchDataField<vec> &xyz = src.get_field<vec>(0);
if (xyz.get_obj_cnt() == 0) {
return;
}
multikernel_input.push_back(
{sender,
receiver,
build,
xyz,
sham::DeviceBuffer<u32>(xyz.get_obj_cnt() + 1, dev_sched)});
});

sham::DeviceBuffer<u32> idxs_res = xyz.get_ids_where(
[](auto access, u32 id, vec vmin, vec vmax) {
return Patch::is_in_patch_converted(access[id], vmin, vmax);
},
build.cut_volume.lower,
build.cut_volume.upper);
for (auto &input : multikernel_input) {

auto &sender = input.sender;
auto &receiver = input.receiver;
auto &build = input.build;
auto &xyz = input.xyz;
auto &mask = input.mask;
auto obj_cnt = xyz.get_obj_cnt();

sham::kernel_call(
q,
sham::MultiRef{xyz.get_buf()},
sham::MultiRef{mask},
obj_cnt + 1,
[vmin = build.cut_volume.lower, vmax = build.cut_volume.upper, obj_cnt](
u32 id, const vec *__restrict acc, u32 *__restrict acc_mask) {
acc_mask[id]
= (id < obj_cnt) ? Patch::is_in_patch_converted(acc[id], vmin, vmax) : 0;
});
}

for (auto &input : multikernel_input) {

auto &sender = input.sender;
auto &receiver = input.receiver;
auto &build = input.build;
auto &xyz = input.xyz;
auto &mask = input.mask;
auto obj_cnt = xyz.get_obj_cnt();

sham::DeviceBuffer<u32> idxs_res
= shamalgs::primitives::stream_compact(dev_sched, std::move(mask), obj_cnt);

u32 pcnt = idxs_res.get_size();

// prevent sending empty patches
if (pcnt == 0) {
return;
continue;
}

f64 ratio = f64(pcnt) / f64(src.get_obj_cnt());
f64 ratio = f64(pcnt) / f64(xyz.get_obj_cnt());

shamlog_debug_ln(
"InterfaceGen",
Expand All @@ -494,7 +545,7 @@ auto BasicSPHGhostHandler<vec>::gen_id_table_interfaces(GeneratorMap &&gen)
res.add_obj(sender, receiver, InterfaceIdTable{build, std::move(idxs_res), ratio});

send_count_stats[sender] += ratio;
});
}

bool has_warn = false;

Expand Down
Loading