Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 167 additions & 7 deletions doc/sphinx/examples/sph/run_sph_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
rout = 10 * rin # [au]
r0 = rin # [au]

H_r_0 = 0.01
H_r_0 = 0.05
q = 0.75
p = 3.0 / 2.0

Expand Down Expand Up @@ -146,12 +146,12 @@ def H_profile(r):
# cfg.add_ext_force_point_mass(center_mass, center_racc)

cfg.add_kill_sphere(center=(0, 0, 0), radius=bsize) # kill particles outside the simulation box
cfg.add_ext_force_lense_thirring(
central_mass=center_mass,
Racc=rin,
a_spin=0.9,
dir_spin=(np.sin(inclination), np.cos(inclination), 0.0),
)
# cfg.add_ext_force_lense_thirring(
# central_mass=center_mass,
# Racc=rin,
# a_spin=0.9,
# dir_spin=(np.sin(inclination), np.cos(inclination), 0.0),
# )

cfg.set_units(codeu)
cfg.set_particle_mass(pmass)
Expand Down Expand Up @@ -482,3 +482,163 @@ def plot_vtheta_slice_cylindrical(metadata, arr_vtheta_pos):
plot_vtheta_slice_cylindrical(metadata, arr_vtheta_pos)

plt.show()

# %%
# Azymuthal rendering

H_r_render = H_r_0 * 4


def make_azymuthal_coords(nr, nz):
"""
Generate a list of positions in cylindrical coordinates (r, theta)
spanning [0, ext*2] x [-pi, pi] for use with the rendering module.

Returns:
list: List of [x, y, z] coordinate lists
"""

# Create the cylindrical coordinate grid
r_vals = np.linspace(0, ext, nr)
z_vals = np.linspace(-ext * H_r_render, ext * H_r_render, nz)

# Create meshgrid
r_grid, z_grid = np.meshgrid(r_vals, z_vals)

# Flatten and stack to create list of positions
positions = np.column_stack([r_grid.ravel(), z_grid.ravel()])

return [tuple(pos) for pos in positions]


def make_ring_rays(positions):
def position_to_ring_ray(position):
r = position[0]
z = position[1]
e_x = (1.0, 0.0, 0.0)
e_y = (0.0, 1.0, 0.0)
center = (0.0, 0.0, z)
return shamrock.math.RingRay_f64_3(center, r, e_x, e_y)

return [position_to_ring_ray(position) for position in positions]


def make_slice_coord_for_azymuthal(positions):
def position_to_ring_ray(position):
r = position[0]
z = position[1]
e_x = (1.0, 0.0, 0.0)
e_y = (0.0, 1.0, 0.0)
center = (0.0, 0.0, z)
return (r, 0.0, z)

return [position_to_ring_ray(position) for position in positions]


nr = 1024
nz = 1024

positions_azymuthal = make_azymuthal_coords(nr, nz)
ring_rays_azymuthal = make_ring_rays(positions_azymuthal)
slice_coords_azymuthal = make_slice_coord_for_azymuthal(positions_azymuthal)

arr_rho_azymuthal = model.render_azymuthal_integ("rho", "f64", ring_rays_azymuthal)
arr_rho_slice_azymuthal = model.render_slice("rho", "f64", slice_coords_azymuthal)

arr_vxyz_azymuthal = model.render_azymuthal_integ("vxyz", "f64_3", ring_rays_azymuthal)
arr_vxyz_slice_azymuthal = model.render_slice("vxyz", "f64_3", slice_coords_azymuthal)


def plot_rho_integ_azymuthal(metadata, arr_rho_azymuthal):
ext = metadata["extent"]

my_cmap = matplotlib.colormaps["gist_heat"].copy() # copy the default cmap
my_cmap.set_bad(color="black")

arr_rho_azymuthal = np.array(arr_rho_azymuthal).reshape(nr, nz)

res = plt.imshow(
arr_rho_azymuthal, cmap=my_cmap, origin="lower", extent=ext, norm="log", vmin=1e-5, vmax=1
)
plt.xlabel("r")
plt.ylabel("z")
plt.title(f"t = {metadata['time']:0.3f} [seconds]")
cbar = plt.colorbar(res, extend="both")
cbar.set_label(r"$\int \rho \, \mathrm{d}\theta$ [code unit]")


def plot_rho_slice_azymuthal(metadata, arr_rho_slice_azymuthal):
ext = metadata["extent"]

my_cmap = matplotlib.colormaps["gist_heat"].copy() # copy the default cmap
my_cmap.set_bad(color="black")

arr_rho_slice_azymuthal = np.array(arr_rho_slice_azymuthal).reshape(nr, nz)

res = plt.imshow(
arr_rho_slice_azymuthal,
cmap=my_cmap,
origin="lower",
extent=ext,
norm="log",
vmin=1e-5,
vmax=1,
)
plt.xlabel("r")
plt.ylabel("z")
plt.title(f"t = {metadata['time']:0.3f} [seconds]")
cbar = plt.colorbar(res, extend="both")
cbar.set_label(r"$\rho$ [code unit]")


def plot_vz_integ_azymuthal(metadata, arr_vxyz_azymuthal):
ext = metadata["extent"]

my_cmap = matplotlib.colormaps["seismic"].copy() # copy the default cmap
my_cmap.set_bad(color="black")

arr_vz_azymuthal = np.array(arr_vxyz_azymuthal).reshape(nr, nz, 3)[:, :, 2]

res = plt.imshow(
arr_vz_azymuthal, cmap=my_cmap, origin="lower", extent=ext, vmin=-1e-6, vmax=1e-6
)
plt.xlabel("r")
plt.ylabel("z")
plt.title(f"t = {metadata['time']:0.3f} [seconds]")
cbar = plt.colorbar(res, extend="both")
cbar.set_label(r"$\int v_z \, \mathrm{d}\theta$ [code unit]")


def plot_vz_slice_azymuthal(metadata, arr_vxyz_slice_azymuthal):
ext = metadata["extent"]

my_cmap = matplotlib.colormaps["seismic"].copy() # copy the default cmap
my_cmap.set_bad(color="black")

arr_vz_slice_azymuthal = np.array(arr_vxyz_slice_azymuthal).reshape(nr, nz, 3)[:, :, 2]

res = plt.imshow(
arr_vz_slice_azymuthal, cmap=my_cmap, origin="lower", extent=ext, vmin=-5e-6, vmax=5e-6
)
plt.xlabel("r")
plt.ylabel("z")
plt.title(f"t = {metadata['time']:0.3f} [seconds]")
cbar = plt.colorbar(res, extend="both")
cbar.set_label(r"$v_z$ [code unit]")


metadata = {"extent": [0, ext, -ext * H_r_render, ext * H_r_render], "time": model.get_time()}
fig_size = (6, 3)
plt.figure(dpi=dpi, figsize=fig_size)
plot_rho_integ_azymuthal(metadata, arr_rho_azymuthal)

plt.figure(dpi=dpi, figsize=fig_size)
plot_rho_slice_azymuthal(metadata, arr_rho_slice_azymuthal)

plt.figure(dpi=dpi, figsize=fig_size)
plot_vz_integ_azymuthal(metadata, arr_vxyz_azymuthal)

plt.figure(dpi=dpi, figsize=fig_size)
plot_vz_slice_azymuthal(metadata, arr_vxyz_slice_azymuthal)

plt.show()
61 changes: 61 additions & 0 deletions src/shammath/include/shammath/AABB.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,35 @@ namespace shammath {
}
};

/**
* @brief Ring ray representation for intersection testing
*
* A = {center + radius * (e_x * cos(theta) + e_y * sin(theta)) | theta in [0, 2*pi]}
*
* @tparam T Vector type for coordinates
*/
template<class T>
struct RingRay {
using T_prop = shambase::VectorProperties<T>;
using Tscal = typename T_prop::component_type;

T center;
Tscal radius;
T e_x;
T e_y;

T get_ez() { return sycl::cross(e_x, e_y); }
/**
* @brief Construct a ring ray from center, e_x, and e_y
*
* @param center Center of the ring
* @param e_x Unit vector along the x-axis of the ring
* @param e_y Unit vector along the y-axis of the ring
*/
inline RingRay(T center, Tscal radius, T e_x, T e_y)
: center(center), radius(radius), e_x(e_x), e_y(e_y) {}
};

/**
* @brief Axis-Aligned bounding box
*
Expand Down Expand Up @@ -121,6 +150,8 @@ namespace shammath {
*/
inline T delt() const { return upper - lower; }

inline Tscal get_radius() const { return sycl::length(delt()) / 2; }

/**
* @brief Returns the volume of the AABB
*
Expand Down Expand Up @@ -300,6 +331,17 @@ namespace shammath {
*/
[[nodiscard]] inline bool intersect_ray(Ray<T> ray) const noexcept;

/**
* @brief Check if the ring ray intersect the AABB
*
* This function perform a ring ray-AABB intersection test.
* It return true if the ring ray intersect the AABB and false otherwise.
*
* @param[in] ring_ray The ring ray to test
* @return true if the ring ray intersect the AABB
*/
[[nodiscard]] inline bool intersect_ring_ray_approx(RingRay<T> ring_ray) const noexcept;

/// equal operator
inline bool operator==(const AABB<T> &other) const noexcept {
return sham::equals(lower, other.lower) && sham::equals(upper, other.upper);
Expand Down Expand Up @@ -334,4 +376,23 @@ namespace shammath {
return tmax >= tmin;
}

template<class T>
[[nodiscard]] inline bool AABB<T>::intersect_ring_ray_approx(
RingRay<T> ring_ray) const noexcept {
T aabb_center = get_center();
Tscal aabb_radius = get_radius();

T r_center = ring_ray.center - aabb_center;

Tscal x_val = sycl::dot(r_center, ring_ray.e_x);
Tscal y_val = sycl::dot(r_center, ring_ray.e_y);
Tscal z_val = sycl::dot(r_center, ring_ray.get_ez());

Tscal r_val = sycl::sqrt(x_val * x_val + y_val * y_val);
Tscal delta_r = r_val - ring_ray.radius;
Tscal rab2_ring = z_val * z_val + delta_r * delta_r;

return rab2_ring <= aabb_radius * aabb_radius;
}

} // namespace shammath
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ namespace shammodels::sph::modules {
std::function<field_getter_t> field_getter,
const sham::DeviceBuffer<shammath::Ray<Tvec>> &rays);

sham::DeviceBuffer<Tfield> compute_azymuthal_integ(
std::function<field_getter_t> field_getter,
const sham::DeviceBuffer<shammath::RingRay<Tvec>> &ring_rays);

sham::DeviceBuffer<Tfield> compute_slice(
std::string field_name,
const sham::DeviceBuffer<Tvec> &positions,
Expand All @@ -66,6 +70,12 @@ namespace shammodels::sph::modules {
std::optional<std::function<pybind11::array_t<Tfield>(size_t, pybind11::dict &)>>
custom_getter);

sham::DeviceBuffer<Tfield> compute_azymuthal_integ(
std::string field_name,
const sham::DeviceBuffer<shammath::RingRay<Tvec>> &ring_rays,
std::optional<std::function<pybind11::array_t<Tfield>(size_t, pybind11::dict &)>>
custom_getter);

sham::DeviceBuffer<Tfield> compute_slice(
std::function<field_getter_t> field_getter,
Tvec center,
Expand Down Expand Up @@ -124,6 +134,17 @@ namespace shammodels::sph::modules {
return compute_column_integ(field_name, rays_buf, custom_getter);
}

inline sham::DeviceBuffer<Tfield> compute_azymuthal_integ(
std::string field_name,
const std::vector<shammath::RingRay<Tvec>> &ring_rays,
std::optional<std::function<pybind11::array_t<Tfield>(size_t, pybind11::dict &)>>
custom_getter) {
sham::DeviceBuffer<shammath::RingRay<Tvec>> ring_rays_buf{
ring_rays.size(), shamsys::instance::get_compute_scheduler_ptr()};
ring_rays_buf.copy_from_stdvec(ring_rays);
return compute_azymuthal_integ(field_name, ring_rays_buf, custom_getter);
}

private:
inline PatchScheduler &scheduler() { return shambase::get_check_ref(context.sched); }
};
Expand Down
Loading
Loading