Skip to content

Conversation

@Guo-astro
Copy link
Contributor

Summary

Add MergedPatchDataEdge for merged PatchDataLayer storage management via solvergraph.

Part of GSPH solvergraph migration (1 object per PR per maintainer request).

Changes:

  • Add MergedPatchDataEdge.hpp for merged patchdata edge
  • Add solver_graph and merged patchdata edges to SolverStorage
  • Register edges in init_solver_graph()
  • Update all merged_xyzh usages to solvergraph pattern
  • Update all merged_patchdata_ghost usages to solvergraph pattern
  • Update UpdateDerivs.cpp to use solvergraph pattern

Test plan

  • Build passes
  • GSPH Sod shock tube test passes

Add MergedPatchDataEdge for merged PatchDataLayer storage management
via solvergraph.

Changes:
- Add MergedPatchDataEdge.hpp for merged patchdata edge
- Add solver_graph and merged patchdata edges to SolverStorage
- Register edges in init_solver_graph()
- Update all merged_xyzh usages to solvergraph pattern
- Update all merged_patchdata_ghost usages to solvergraph pattern
- Update UpdateDerivs.cpp to use solvergraph pattern

Part of GSPH solvergraph migration (1 object per PR).
@gemini-code-assist
Copy link

Summary of Changes

Hello @Guo-astro, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the handling of merged particle data (merged_xyzh and merged_patchdata_ghost) within the GSPH module. By integrating these data structures into the solvergraph system via a new MergedPatchDataEdge, the PR aims to centralize storage management and dependency tracking, contributing to the ongoing solvergraph migration effort.

Highlights

  • New SolverGraph Edge: Introduced MergedPatchDataEdge to encapsulate DistributedData<PatchDataLayer>, providing a standardized way to manage this data within the solvergraph system.
  • SolverStorage Migration: Migrated merged_xyzh and merged_patchdata_ghost in SolverStorage to utilize the new MergedPatchDataEdge type, placing them under the management of the SolverGraph.
  • SolverGraph Integration: Registered the new MergedPatchDataEdge instances for merged_xyzh and merged_patchdata_ghost with the SolverGraph during the init_solver_graph() process.
  • Codebase Adaptation: Updated all references and operations on merged_xyzh and merged_patchdata_ghost across Solver.cpp and UpdateDerivs.cpp to conform to the solvergraph pattern, including changes to data access and memory management (free_alloc).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request successfully migrates merged_xyzh and merged_patchdata_ghost to the solvergraph pattern. This involves introducing a new MergedPatchDataEdge class to encapsulate the DistributedData<PatchDataLayer>, updating SolverStorage to use std::shared_ptr for these edges, and registering them with the SolverGraph. All usage sites have been correctly updated to reflect the new access patterns, and resource cleanup now properly utilizes the free_alloc() method of the new edge type. The changes are well-implemented and consistent with the stated goal of integrating these data structures into the SolverGraph framework. No issues of medium, high, or critical severity were found during the review.

@github-actions
Copy link
Contributor

Thanks @Guo-astro for opening this PR!

You can do multiple things directly here:
1 - Comment pre-commit.ci run to run pre-commit checks.
2 - Comment pre-commit.ci autofix to apply fixes.
3 - Add label autofix.ci to fix authorship & pre-commit for every commit made.
4 - Add label light-ci to only trigger a reduced & faster version of the CI (need the full one before merge).
5 - Add label trigger-ci to create an empty commit to trigger the CI.

Once the workflow completes a message will appear displaying informations related to the run.

Also the PR gets automatically reviewed by gemini, you can:
1 - Comment /gemini review to trigger a review
2 - Comment /gemini summary for a summary
3 - Tag it using @gemini-code-assist either in the PR or in review comments on files

@github-actions
Copy link
Contributor

Workflow report

workflow report corresponding to commit d93cebd
Commiter email is guo.yansong.ngy@gmail.com

Pre-commit check report

Some failures were detected in base source checks checks.
Check the On PR / Linting / Base source checks (pull_request) job in the tests for more detailled output

Suggested changes

Detailed changes :
diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp
index 8cab7f87..6c9c09f8 100644
--- a/src/shammodels/gsph/src/Solver.cpp
+++ b/src/shammodels/gsph/src/Solver.cpp
@@ -89,8 +89,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::init_solver_graph() {
 
     // Register merged patchdata edges for dependency tracking
     storage.merged_xyzh = storage.solver_graph.register_edge(
-        "merged_xyzh",
-        solvergraph::MergedPatchDataEdge("merged_xyzh", "\\mathbf{xyzh}_{\\rm m}"));
+        "merged_xyzh", solvergraph::MergedPatchDataEdge("merged_xyzh", "\\mathbf{xyzh}_{\\rm m}"));
 
     storage.merged_patchdata_ghost = storage.solver_graph.register_edge(
         "merged_patchdata_ghost",
@@ -199,8 +198,8 @@ template<class Tvec, template<class> class Kern>
 void shammodels::gsph::Solver<Tvec, Kern>::merge_position_ghost() {
     StackEntry stack_loc{};
 
-    shambase::get_check_ref(storage.merged_xyzh).data = (
-        storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get()));
+    shambase::get_check_ref(storage.merged_xyzh).data
+        = (storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get()));
 
     // Get field indices from xyzh_ghost_layout
     const u32 ixyz_ghost
@@ -210,32 +209,38 @@ void shammodels::gsph::Solver<Tvec, Kern>::merge_position_ghost() {
 
     // Set element counts
     shambase::get_check_ref(storage.part_counts).indexes
-        = shambase::get_check_ref(storage.merged_xyzh).get_data().template map<u32>(
-            [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
-                return scheduler().patch_data.get_pdat(id).get_obj_cnt();
-            });
+        = shambase::get_check_ref(storage.merged_xyzh)
+              .get_data()
+              .template map<u32>([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
+                  return scheduler().patch_data.get_pdat(id).get_obj_cnt();
+              });
 
     // Set element counts with ghost
     shambase::get_check_ref(storage.part_counts_with_ghost).indexes
-        = shambase::get_check_ref(storage.merged_xyzh).get_data().template map<u32>(
-            [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
-                return mpdat.get_obj_cnt();
-            });
+        = shambase::get_check_ref(storage.merged_xyzh)
+              .get_data()
+              .template map<u32>([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
+                  return mpdat.get_obj_cnt();
+              });
 
     // Attach spans to block coords
     shambase::get_check_ref(storage.positions_with_ghosts)
         .set_refs(
-            shambase::get_check_ref(storage.merged_xyzh).get_data().template map<std::reference_wrapper<PatchDataField<Tvec>>>(
-                [&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
-                    return std::ref(mpdat.get_field<Tvec>(ixyz_ghost));
-                }));
+            shambase::get_check_ref(storage.merged_xyzh)
+                .get_data()
+                .template map<std::reference_wrapper<PatchDataField<Tvec>>>(
+                    [&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
+                        return std::ref(mpdat.get_field<Tvec>(ixyz_ghost));
+                    }));
 
     shambase::get_check_ref(storage.hpart_with_ghosts)
         .set_refs(
-            shambase::get_check_ref(storage.merged_xyzh).get_data().template map<std::reference_wrapper<PatchDataField<Tscal>>>(
-                [&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
-                    return std::ref(mpdat.get_field<Tscal>(ihpart_ghost));
-                }));
+            shambase::get_check_ref(storage.merged_xyzh)
+                .get_data()
+                .template map<std::reference_wrapper<PatchDataField<Tscal>>>(
+                    [&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
+                        return std::ref(mpdat.get_field<Tscal>(ihpart_ghost));
+                    }));
 }
 
 template<class Tvec, template<class> class Kern>
@@ -766,8 +771,8 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {
     });
 
     // Merge local and ghost data
-    shambase::get_check_ref(storage.merged_patchdata_ghost).data = (
-        ghost_handle.template merge_native<PatchDataLayer, PatchDataLayer>(
+    shambase::get_check_ref(storage.merged_patchdata_ghost).data
+        = (ghost_handle.template merge_native<PatchDataLayer, PatchDataLayer>(
             std::move(interf_pdat),
             [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) {
                 PatchDataLayer pdat_new(ghost_layout_ptr);
@@ -1138,73 +1143,77 @@ void shammodels::gsph::Solver<Tvec, Kern>::compute_eos_fields() {
     soundspeed_field.ensure_sizes(counts_with_ghosts);
 
     // Iterate over merged_patchdata_ghost (includes local + ghost particles)
-    shambase::get_check_ref(storage.merged_patchdata_ghost).get_data().for_each([&](u64 id, PatchDataLayer &mpdat) {
-        u32 total_elements
-            = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
-        if (total_elements == 0)
-            return;
-
-        // Use SPH-summation density from communicated ghost data
-        sham::DeviceBuffer<Tscal> &buf_density = mpdat.get_field_buf_ref<Tscal>(idensity_interf);
-        auto &pressure_buf                     = pressure_field.get_field(id).get_buf();
-        auto &soundspeed_buf                   = soundspeed_field.get_field(id).get_buf();
+    shambase::get_check_ref(storage.merged_patchdata_ghost)
+        .get_data()
+        .for_each([&](u64 id, PatchDataLayer &mpdat) {
+            u32 total_elements
+                = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
+            if (total_elements == 0)
+                return;
 
-        sham::DeviceQueue &q = dev_sched->get_queue();
-        sham::EventList depends_list;
+            // Use SPH-summation density from communicated ghost data
+            sham::DeviceBuffer<Tscal> &buf_density
+                = mpdat.get_field_buf_ref<Tscal>(idensity_interf);
+            auto &pressure_buf   = pressure_field.get_field(id).get_buf();
+            auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf();
 
-        auto density    = buf_density.get_read_access(depends_list);
-        auto pressure   = pressure_buf.get_write_access(depends_list);
-        auto soundspeed = soundspeed_buf.get_write_access(depends_list);
+            sham::DeviceQueue &q = dev_sched->get_queue();
+            sham::EventList depends_list;
 
-        const Tscal *uint_ptr = nullptr;
-        if (has_uint) {
-            uint_ptr = mpdat.get_field_buf_ref<Tscal>(iuint_interf).get_read_access(depends_list);
-        }
+            auto density    = buf_density.get_read_access(depends_list);
+            auto pressure   = pressure_buf.get_write_access(depends_list);
+            auto soundspeed = soundspeed_buf.get_write_access(depends_list);
 
-        auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
-            shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) {
-                u32 i = (u32) gid;
+            const Tscal *uint_ptr = nullptr;
+            if (has_uint) {
+                uint_ptr
+                    = mpdat.get_field_buf_ref<Tscal>(iuint_interf).get_read_access(depends_list);
+            }
 
-                // Use SPH-summation density (from compute_omega, communicated to ghosts)
-                Tscal rho = density[i];
-                rho       = sycl::max(rho, Tscal(1e-30));
-
-                if (has_uint && uint_ptr != nullptr) {
-                    // Adiabatic EOS (reference: g_pre_interaction.cpp line 107)
-                    // P = (\gamma - 1) * \rho * u
-                    Tscal u = uint_ptr[i];
-                    u       = sycl::max(u, Tscal(1e-30));
-                    Tscal P = (gamma - Tscal(1.0)) * rho * u;
-
-                    // Sound speed from internal energy (reference: solver.cpp line 2661)
-                    // c = sqrt(\gamma * (\gamma - 1) * u)
-                    Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u);
-
-                    // Clamp to reasonable values
-                    P  = sycl::clamp(P, Tscal(1e-30), Tscal(1e30));
-                    cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10));
-
-                    pressure[i]   = P;
-                    soundspeed[i] = cs;
-                } else {
-                    // Isothermal case
-                    Tscal cs = Tscal(1.0);
-                    Tscal P  = cs * cs * rho;
-
-                    pressure[i]   = P;
-                    soundspeed[i] = cs;
-                }
+            auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
+                shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) {
+                    u32 i = (u32) gid;
+
+                    // Use SPH-summation density (from compute_omega, communicated to ghosts)
+                    Tscal rho = density[i];
+                    rho       = sycl::max(rho, Tscal(1e-30));
+
+                    if (has_uint && uint_ptr != nullptr) {
+                        // Adiabatic EOS (reference: g_pre_interaction.cpp line 107)
+                        // P = (\gamma - 1) * \rho * u
+                        Tscal u = uint_ptr[i];
+                        u       = sycl::max(u, Tscal(1e-30));
+                        Tscal P = (gamma - Tscal(1.0)) * rho * u;
+
+                        // Sound speed from internal energy (reference: solver.cpp line 2661)
+                        // c = sqrt(\gamma * (\gamma - 1) * u)
+                        Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u);
+
+                        // Clamp to reasonable values
+                        P  = sycl::clamp(P, Tscal(1e-30), Tscal(1e30));
+                        cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10));
+
+                        pressure[i]   = P;
+                        soundspeed[i] = cs;
+                    } else {
+                        // Isothermal case
+                        Tscal cs = Tscal(1.0);
+                        Tscal P  = cs * cs * rho;
+
+                        pressure[i]   = P;
+                        soundspeed[i] = cs;
+                    }
+                });
             });
-        });
 
-        // Complete all buffer event states
-        buf_density.complete_event_state(e);
-        if (has_uint) {
-            mpdat.get_field_buf_ref<Tscal>(iuint_interf).complete_event_state(e);
-        }
-        pressure_buf.complete_event_state(e);
-        soundspeed_buf.complete_event_state(e);
-    });
+            // Complete all buffer event states
+            buf_density.complete_event_state(e);
+            if (has_uint) {
+                mpdat.get_field_buf_ref<Tscal>(iuint_interf).complete_event_state(e);
+            }
+            pressure_buf.complete_event_state(e);
+            soundspeed_buf.complete_event_state(e);
+        });
 }
 
 template<class Tvec, template<class> class Kern>
diff --git a/src/shammodels/gsph/src/modules/UpdateDerivs.cpp b/src/shammodels/gsph/src/modules/UpdateDerivs.cpp
index 217e11e2..c1efbb2b 100644
--- a/src/shammodels/gsph/src/modules/UpdateDerivs.cpp
+++ b/src/shammodels/gsph/src/modules/UpdateDerivs.cpp
@@ -84,9 +84,10 @@ void shammodels::gsph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_ite
         = has_uint ? ghost_layout.get_field_idx<Tscal>(gsph::names::newtonian::uint) : 0;
 
     // Get merged data and caches from storage
-    auto &merged_xyzh                                 = shambase::get_check_ref(storage.merged_xyzh).get_data();
-    shamrock::solvergraph::Field<Tscal> &omega_field  = shambase::get_check_ref(storage.omega);
-    shambase::DistributedData<PatchDataLayer> &mpdats = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data();
+    auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data();
+    shamrock::solvergraph::Field<Tscal> &omega_field = shambase::get_check_ref(storage.omega);
+    shambase::DistributedData<PatchDataLayer> &mpdats
+        = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data();
 
     // Get pressure and soundspeed from storage (includes ghosts)
     shamrock::solvergraph::Field<Tscal> &pressure_field = shambase::get_check_ref(storage.pressure);
@@ -300,9 +301,10 @@ void shammodels::gsph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_hll
     u32 iuint_interf
         = has_uint ? ghost_layout.get_field_idx<Tscal>(gsph::names::newtonian::uint) : 0;
 
-    auto &merged_xyzh                                 = shambase::get_check_ref(storage.merged_xyzh).get_data();
-    shamrock::solvergraph::Field<Tscal> &omega_field  = shambase::get_check_ref(storage.omega);
-    shambase::DistributedData<PatchDataLayer> &mpdats = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data();
+    auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data();
+    shamrock::solvergraph::Field<Tscal> &omega_field = shambase::get_check_ref(storage.omega);
+    shambase::DistributedData<PatchDataLayer> &mpdats
+        = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data();
 
     // Get pressure and soundspeed from storage (includes ghosts)
     shamrock::solvergraph::Field<Tscal> &pressure_field = shambase::get_check_ref(storage.pressure);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants