Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
#include "shamrock/solvergraph/FieldRefs.hpp"
#include "shamrock/solvergraph/Indexes.hpp"
#include "shamrock/solvergraph/ScalarsEdge.hpp"
#include "shamrock/solvergraph/SolverGraph.hpp"
#include "shamsys/legacy/log.hpp"

// GSPH solvergraph edges
#include "shammodels/gsph/solvergraph/MergedPatchDataEdge.hpp"
#include "shamtree/CompressedLeafBVH.hpp"
#include "shamtree/KarrasRadixTreeField.hpp"
#include "shamtree/RadixTree.hpp"
Expand Down Expand Up @@ -70,6 +74,8 @@ namespace shammodels::gsph {

using RTree = shamtree::CompressedLeafBVH<Tmorton, Tvec, 3>;

shamrock::solvergraph::SolverGraph solver_graph;

/// Particle counts per patch
std::shared_ptr<shamrock::solvergraph::Indexes<u32>> part_counts;
std::shared_ptr<shamrock::solvergraph::Indexes<u32>> part_counts_with_ghost;
Expand All @@ -91,8 +97,8 @@ namespace shammodels::gsph {
Component<GhostHandle> ghost_handler;
Component<GhostHandleCache> ghost_patch_cache;

/// Merged position-h data for neighbor search
Component<shambase::DistributedData<shamrock::patch::PatchDataLayer>> merged_xyzh;
/// Merged position-h data for neighbor search - managed via SolverGraph
std::shared_ptr<solvergraph::MergedPatchDataEdge> merged_xyzh;

/// Radix trees for neighbor search
Component<shambase::DistributedData<RTree>> merged_pos_trees;
Expand All @@ -105,8 +111,9 @@ namespace shammodels::gsph {
/// Ghost data layout and merged data
std::shared_ptr<shamrock::patch::PatchDataLayerLayout> xyzh_ghost_layout;
std::shared_ptr<shamrock::patch::PatchDataLayerLayout> ghost_layout;
Component<shambase::DistributedData<shamrock::patch::PatchDataLayer>>
merged_patchdata_ghost;

/// Merged patchdata including all ghost fields - managed via SolverGraph
std::shared_ptr<solvergraph::MergedPatchDataEdge> merged_patchdata_ghost;

/// Density field computed via SPH summation
std::shared_ptr<shamrock::solvergraph::Field<Tscal>> density;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

/**
* @file MergedPatchDataEdge.hpp
* @author Guo Yansong (guo.yansong.ngy@gmail.com)
* @brief SolverGraph edge for merged PatchDataLayer
*/

#include "shambase/DistributedData.hpp"
#include "shambase/memory.hpp"
#include "shamrock/patch/PatchDataLayer.hpp"
#include "shamrock/solvergraph/IEdgeNamed.hpp"

namespace shammodels::gsph::solvergraph {

/// SolverGraph edge for merged PatchDataLayer storage (local + ghost particles)
class MergedPatchDataEdge : public shamrock::solvergraph::IEdgeNamed {
public:
using IEdgeNamed::IEdgeNamed;

shambase::DistributedData<shamrock::patch::PatchDataLayer> data;

shamrock::patch::PatchDataLayer &get(u64 id) { return data.get(id); }
const shamrock::patch::PatchDataLayer &get(u64 id) const { return data.get(id); }

shambase::DistributedData<shamrock::patch::PatchDataLayer> &get_data() { return data; }
const shambase::DistributedData<shamrock::patch::PatchDataLayer> &get_data() const {
return data;
}

inline virtual void free_alloc() override { data = {}; }
};

} // namespace shammodels::gsph::solvergraph
192 changes: 105 additions & 87 deletions src/shammodels/gsph/src/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ void shammodels::gsph::Solver<Tvec, Kern>::init_solver_graph() {
storage.neigh_cache
= std::make_shared<shammodels::sph::solvergraph::NeighCache>(edges::neigh_cache, "neigh");

// Register merged patchdata edges for dependency tracking
storage.merged_xyzh = storage.solver_graph.register_edge(
"merged_xyzh", solvergraph::MergedPatchDataEdge("merged_xyzh", "\\mathbf{xyzh}_{\\rm m}"));

storage.merged_patchdata_ghost = storage.solver_graph.register_edge(
"merged_patchdata_ghost",
solvergraph::MergedPatchDataEdge("merged_patchdata_ghost", "\\mathbb{U}_{\\rm ghost}"));

storage.omega = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "omega", "\\Omega");
storage.density = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "density", "\\rho");
storage.pressure = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "pressure", "P");
Expand Down Expand Up @@ -190,8 +198,8 @@ template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::merge_position_ghost() {
StackEntry stack_loc{};

storage.merged_xyzh.set(
storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get()));
shambase::get_check_ref(storage.merged_xyzh).data
= (storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get()));

// Get field indices from xyzh_ghost_layout
const u32 ixyz_ghost
Expand All @@ -201,39 +209,45 @@ void shammodels::gsph::Solver<Tvec, Kern>::merge_position_ghost() {

// Set element counts
shambase::get_check_ref(storage.part_counts).indexes
= storage.merged_xyzh.get().template map<u32>(
[&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return scheduler().patch_data.get_pdat(id).get_obj_cnt();
});
= shambase::get_check_ref(storage.merged_xyzh)
.get_data()
.template map<u32>([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return scheduler().patch_data.get_pdat(id).get_obj_cnt();
});

// Set element counts with ghost
shambase::get_check_ref(storage.part_counts_with_ghost).indexes
= storage.merged_xyzh.get().template map<u32>(
[&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return mpdat.get_obj_cnt();
});
= shambase::get_check_ref(storage.merged_xyzh)
.get_data()
.template map<u32>([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return mpdat.get_obj_cnt();
});

// Attach spans to block coords
shambase::get_check_ref(storage.positions_with_ghosts)
.set_refs(
storage.merged_xyzh.get().template map<std::reference_wrapper<PatchDataField<Tvec>>>(
[&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return std::ref(mpdat.get_field<Tvec>(ixyz_ghost));
}));
shambase::get_check_ref(storage.merged_xyzh)
.get_data()
.template map<std::reference_wrapper<PatchDataField<Tvec>>>(
[&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return std::ref(mpdat.get_field<Tvec>(ixyz_ghost));
}));

shambase::get_check_ref(storage.hpart_with_ghosts)
.set_refs(
storage.merged_xyzh.get().template map<std::reference_wrapper<PatchDataField<Tscal>>>(
[&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return std::ref(mpdat.get_field<Tscal>(ihpart_ghost));
}));
shambase::get_check_ref(storage.merged_xyzh)
.get_data()
.template map<std::reference_wrapper<PatchDataField<Tscal>>>(
[&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
return std::ref(mpdat.get_field<Tscal>(ihpart_ghost));
}));
}

template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::build_merged_pos_trees() {
StackEntry stack_loc{};

auto &merged_xyzh = storage.merged_xyzh.get();
auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data();
auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();

// Get field index from xyzh_ghost_layout
Expand Down Expand Up @@ -278,7 +292,7 @@ template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::compute_presteps_rint() {
StackEntry stack_loc{};

auto &xyzh_merged = storage.merged_xyzh.get();
auto &xyzh_merged = shambase::get_check_ref(storage.merged_xyzh).get_data();
auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();

storage.rtree_rint_field.set(
Expand Down Expand Up @@ -324,7 +338,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::start_neighbors_cache() {

// Build neighbor cache using tree traversal - same approach as SPH module
auto build_neigh_cache = [&](u64 patch_id) -> shamrock::tree::ObjectCache {
auto &mfield = storage.merged_xyzh.get().get(patch_id);
auto &mfield = shambase::get_check_ref(storage.merged_xyzh).get_data().get(patch_id);

sham::DeviceBuffer<Tvec> &buf_xyz = mfield.template get_field_buf_ref<Tvec>(0);
sham::DeviceBuffer<Tscal> &buf_hpart = mfield.template get_field_buf_ref<Tscal>(1);
Expand Down Expand Up @@ -757,8 +771,8 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {
});

// Merge local and ghost data
storage.merged_patchdata_ghost.set(
ghost_handle.template merge_native<PatchDataLayer, PatchDataLayer>(
shambase::get_check_ref(storage.merged_patchdata_ghost).data
= (ghost_handle.template merge_native<PatchDataLayer, PatchDataLayer>(
std::move(interf_pdat),
[&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) {
PatchDataLayer pdat_new(ghost_layout_ptr);
Expand Down Expand Up @@ -803,7 +817,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {

template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::reset_merge_ghosts_fields() {
storage.merged_patchdata_ghost.reset();
shambase::get_check_ref(storage.merged_patchdata_ghost).free_alloc();
}

template<class Tvec, template<class> class Kern>
Expand Down Expand Up @@ -856,7 +870,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::compute_omega() {
// 3. If h grows beyond tolerance, signal for cache rebuild
// =========================================================================

auto &merged_xyzh = storage.merged_xyzh.get();
auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data();

// Create field references for the iteration module
// Position spans (from merged xyzh)
Expand Down Expand Up @@ -1129,73 +1143,77 @@ void shammodels::gsph::Solver<Tvec, Kern>::compute_eos_fields() {
soundspeed_field.ensure_sizes(counts_with_ghosts);

// Iterate over merged_patchdata_ghost (includes local + ghost particles)
storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
u32 total_elements
= shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
if (total_elements == 0)
return;

// Use SPH-summation density from communicated ghost data
sham::DeviceBuffer<Tscal> &buf_density = mpdat.get_field_buf_ref<Tscal>(idensity_interf);
auto &pressure_buf = pressure_field.get_field(id).get_buf();
auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf();
shambase::get_check_ref(storage.merged_patchdata_ghost)
.get_data()
.for_each([&](u64 id, PatchDataLayer &mpdat) {
u32 total_elements
= shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
if (total_elements == 0)
return;

sham::DeviceQueue &q = dev_sched->get_queue();
sham::EventList depends_list;
// Use SPH-summation density from communicated ghost data
sham::DeviceBuffer<Tscal> &buf_density
= mpdat.get_field_buf_ref<Tscal>(idensity_interf);
auto &pressure_buf = pressure_field.get_field(id).get_buf();
auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf();

auto density = buf_density.get_read_access(depends_list);
auto pressure = pressure_buf.get_write_access(depends_list);
auto soundspeed = soundspeed_buf.get_write_access(depends_list);
sham::DeviceQueue &q = dev_sched->get_queue();
sham::EventList depends_list;

const Tscal *uint_ptr = nullptr;
if (has_uint) {
uint_ptr = mpdat.get_field_buf_ref<Tscal>(iuint_interf).get_read_access(depends_list);
}
auto density = buf_density.get_read_access(depends_list);
auto pressure = pressure_buf.get_write_access(depends_list);
auto soundspeed = soundspeed_buf.get_write_access(depends_list);

auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) {
u32 i = (u32) gid;
const Tscal *uint_ptr = nullptr;
if (has_uint) {
uint_ptr
= mpdat.get_field_buf_ref<Tscal>(iuint_interf).get_read_access(depends_list);
}

// Use SPH-summation density (from compute_omega, communicated to ghosts)
Tscal rho = density[i];
rho = sycl::max(rho, Tscal(1e-30));

if (has_uint && uint_ptr != nullptr) {
// Adiabatic EOS (reference: g_pre_interaction.cpp line 107)
// P = (\gamma - 1) * \rho * u
Tscal u = uint_ptr[i];
u = sycl::max(u, Tscal(1e-30));
Tscal P = (gamma - Tscal(1.0)) * rho * u;

// Sound speed from internal energy (reference: solver.cpp line 2661)
// c = sqrt(\gamma * (\gamma - 1) * u)
Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u);

// Clamp to reasonable values
P = sycl::clamp(P, Tscal(1e-30), Tscal(1e30));
cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10));

pressure[i] = P;
soundspeed[i] = cs;
} else {
// Isothermal case
Tscal cs = Tscal(1.0);
Tscal P = cs * cs * rho;

pressure[i] = P;
soundspeed[i] = cs;
}
auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) {
u32 i = (u32) gid;

// Use SPH-summation density (from compute_omega, communicated to ghosts)
Tscal rho = density[i];
rho = sycl::max(rho, Tscal(1e-30));

if (has_uint && uint_ptr != nullptr) {
// Adiabatic EOS (reference: g_pre_interaction.cpp line 107)
// P = (\gamma - 1) * \rho * u
Tscal u = uint_ptr[i];
u = sycl::max(u, Tscal(1e-30));
Tscal P = (gamma - Tscal(1.0)) * rho * u;

// Sound speed from internal energy (reference: solver.cpp line 2661)
// c = sqrt(\gamma * (\gamma - 1) * u)
Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u);

// Clamp to reasonable values
P = sycl::clamp(P, Tscal(1e-30), Tscal(1e30));
cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10));

pressure[i] = P;
soundspeed[i] = cs;
} else {
// Isothermal case
Tscal cs = Tscal(1.0);
Tscal P = cs * cs * rho;

pressure[i] = P;
soundspeed[i] = cs;
}
});
});
});

// Complete all buffer event states
buf_density.complete_event_state(e);
if (has_uint) {
mpdat.get_field_buf_ref<Tscal>(iuint_interf).complete_event_state(e);
}
pressure_buf.complete_event_state(e);
soundspeed_buf.complete_event_state(e);
});
// Complete all buffer event states
buf_density.complete_event_state(e);
if (has_uint) {
mpdat.get_field_buf_ref<Tscal>(iuint_interf).complete_event_state(e);
}
pressure_buf.complete_event_state(e);
soundspeed_buf.complete_event_state(e);
});
}

template<class Tvec, template<class> class Kern>
Expand Down Expand Up @@ -1309,7 +1327,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::compute_gradients() {
grad_vy_field.ensure_sizes(counts);
grad_vz_field.ensure_sizes(counts);

auto &merged_xyzh = storage.merged_xyzh.get();
auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data();
auto &neigh_cache = storage.neigh_cache->neigh_cache;

static constexpr Tscal Rkern = Kernel::Rkern;
Expand Down Expand Up @@ -1824,7 +1842,7 @@ shammodels::gsph::TimestepLog shammodels::gsph::Solver<Tvec, Kern>::evolve_once(
reset_presteps_rint();
clear_merged_pos_trees();
reset_merge_ghosts_fields();
storage.merged_xyzh.reset();
shambase::get_check_ref(storage.merged_xyzh).free_alloc();
clear_ghost_cache();
reset_serial_patch_tree();
reset_ghost_handler();
Expand Down
Loading