Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions src/gravity/octree.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../struct/particle.h"
#include <vector>
#include <cmath>
#include <memory> // Required for unique_ptr
#include "dt/softening.h"
#include "floatdef.h"

Expand All @@ -22,35 +23,50 @@ struct Octree {
real x, y, z; // node center
real size; // half-width
bool leaf = true;
Particle* body = nullptr;
Octree* child[8] = { nullptr };
Particle* body = nullptr;

// Ownership: unique_ptr handles memory automatically
std::unique_ptr<Octree> child[8] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };

// Quadrupole tensor
real Qxx = 0, Qyy = 0, Qzz = 0;
real Qxy = 0, Qxz = 0, Qyz = 0;

Octree(real X, real Y, real Z, real S) : x(X), y(Y), z(Z), size(S), m(0), cx(0), cy(0), cz(0) {}

~Octree() { for (auto c : child) delete c; }
// Destructor is now empty; unique_ptr cleans up children automatically
~Octree() = default;

int index(const Particle& p) const {
return (p.x > x) * 1 + (p.y > y) * 2 + (p.z > z) * 4;
}

Octree* createChild(int idx) {
// Returns unique_ptr to take ownership
std::unique_ptr<Octree> createChild(int idx) {
real hs = size * real(0.5);
return new Octree(x + ((idx & 1) ? hs : -hs), y + ((idx & 2) ? hs : -hs), z + ((idx & 4) ? hs : -hs), hs);
return std::make_unique<Octree>(
x + ((idx & 1) ? hs : -hs),
y + ((idx & 2) ? hs : -hs),
z + ((idx & 4) ? hs : -hs),
hs
);
}

void insert(Particle* p) {
if (leaf && body == nullptr) { body = p; return; }
if (leaf && body == nullptr) {
body = p;
return;
}

if (leaf) {
leaf = false;
Particle* old = body; body = nullptr;
Particle* old = body;
body = nullptr;
int idx = index(*old);
if (!child[idx]) child[idx] = createChild(idx);
child[idx]->insert(old);
}

int idx = index(*p);
if (!child[idx]) child[idx] = createChild(idx);
child[idx]->insert(p);
Expand All @@ -65,7 +81,7 @@ struct Octree {
}

m = 0; cx = cy = cz = 0;
for (auto c : child) {
for (auto& c : child) { // Use reference to unique_ptr
if (!c) continue;
c->computeMass();
if (c->m == 0) continue;
Expand All @@ -75,10 +91,9 @@ struct Octree {
if (m > 0) { cx /= m; cy /= m; cz /= m; }

Qxx = Qyy = Qzz = Qxy = Qxz = Qyz = 0;
for (auto c : child) {
for (auto& c : child) {
if (!c || c->m == 0) continue;
real rx = c->cx - cx; real ry = c->cy - cy; real rz = c->cz - cz;
// Internal node softening to match force calculation
real r2 = rx * rx + ry * ry + rz * rz + (size * size * real(0.01));
real mchild = c->m;
Qxx += mchild * (3 * rx * rx - r2);
Expand All @@ -91,6 +106,7 @@ struct Octree {
}
};

// Traverse using raw pointers (non-owning observer)
inline void bhAccel(Octree* node, const Particle& p, real theta, real& ax, real& ay, real& az) {
if (!node || node->m == 0) return;
if (node->leaf && node->body == &p) return;
Expand Down Expand Up @@ -131,5 +147,7 @@ inline void bhAccel(Octree* node, const Particle& p, real theta, real& ax, real&
return;
}

for (auto c : node->child) if (c) bhAccel(c, p, theta, ax, ay, az);
for (auto& c : node->child) {
if (c) bhAccel(c.get(), p, theta, ax, ay, az); // Use .get() to pass raw pointer
}
}
60 changes: 22 additions & 38 deletions src/gravity/step.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,95 +14,79 @@
#include "floatdef.h"
#include "octree.h"
#include <vector>
#include <memory>
#include <algorithm>

inline void Step(std::vector<Particle> &p, real dt) {
if (p.empty())
return;
if (p.empty()) return;

real theta = 0.5;
real half = dt * real(0.5);

auto buildTree = [&](Octree *&root) {
// Compute bounding box
// Helper lambda that returns a unique_ptr
auto buildTree = [&]() -> std::unique_ptr<Octree> {
real minx = +1e30, miny = +1e30, minz = +1e30;
real maxx = -1e30, maxy = -1e30, maxz = -1e30;

for (auto &a : p) {
minx = std::min(minx, a.x);
miny = std::min(miny, a.y);
minz = std::min(minz, a.z);
maxx = std::max(maxx, a.x);
maxy = std::max(maxy, a.y);
maxz = std::max(maxz, a.z);
for (const auto &a : p) {
minx = std::min(minx, a.x); miny = std::min(miny, a.y); minz = std::min(minz, a.z);
maxx = std::max(maxx, a.x); maxy = std::max(maxy, a.y); maxz = std::max(maxz, a.z);
}

real cx = (minx + maxx) * 0.5;
real cy = (miny + maxy) * 0.5;
real cz = (minz + maxz) * 0.5;
real dx = maxx - minx;
real dy = maxy - miny;
real dz = maxz - minz;
real size = std::max({maxx - minx, maxy - miny, maxz - minz}) * real(0.5);

real size = std::max(dx, std::max(dy, dz)) * real(0.5);
if (size <= 0) size = 1;

if (size <= 0)
size = 1; // safety

root = new Octree(cx, cy, cz, size);
// Create the owned root
auto root = std::make_unique<Octree>(cx, cy, cz, size);

for (auto &a : p)
root->insert(&a);

root->computeMass();
return root;
};

// =========================
// First Kick (dt/2)
// =========================
// --- First Kick (dt/2) ---
{
Octree *root = nullptr;
buildTree(root);
std::unique_ptr<Octree> root = buildTree();

#pragma omp parallel for schedule(static)
for (int i = 0; i < (int)p.size(); i++) {
real ax = 0, ay = 0, az = 0;
bhAccel(root, p[i], theta, ax, ay, az);
// Pass the raw pointer via .get() for the traversal
bhAccel(root.get(), p[i], theta, ax, ay, az);

p[i].vx += ax * half;
p[i].vy += ay * half;
p[i].vz += az * half;
}

delete root;
// No 'delete root' needed! It happens automatically here.
}

// =========================
// Drift (dt)
// =========================
// --- Drift (dt) ---
#pragma omp parallel for schedule(static)
for (int i = 0; i < (int)p.size(); i++) {
p[i].x += p[i].vx * dt;
p[i].y += p[i].vy * dt;
p[i].z += p[i].vz * dt;
}

// =========================
// Second Kick (dt/2)
// =========================
// --- Second Kick (dt/2) ---
{
Octree *root = nullptr;
buildTree(root);
std::unique_ptr<Octree> root = buildTree();

#pragma omp parallel for schedule(static)
for (int i = 0; i < (int)p.size(); i++) {
real ax = 0, ay = 0, az = 0;
bhAccel(root, p[i], theta, ax, ay, az);
bhAccel(root.get(), p[i], theta, ax, ay, az);

p[i].vx += ax * half;
p[i].vy += ay * half;
p[i].vz += az * half;
}

delete root;
}
}