Skip to content

WIP: Add Ray Tracing #3604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
8 changes: 6 additions & 2 deletions src/torchaudio/csrc/rir/ray_tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,21 @@ class RayTracer {
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
// The length of this last hop
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
auto bin_idx = get_bin_idx(travel_dist_at_mic);
if (bin_idx >= histograms.size(1)) {
continue;
}
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
auto energy = energies / coeff;
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
histograms[mic_idx][bin_idx] += energy;
}
}
}

travel_dist += hit_distance;
energies *= wall.reflection;

// Let's shoot the scattered ray induced by the rebound on the wall
// Let's shoot the scattered ray induced by the rebound on the wall
if (do_scattering) {
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
energies *= (1. - wall.scattering);
Expand Down
10 changes: 6 additions & 4 deletions src/torchaudio/csrc/rir/wall.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ struct Wall {
const torch::Tensor origin;
const torch::Tensor normal;
const torch::Tensor scattering;

const torch::Tensor reflection;

Wall(
const torch::ArrayRef<scalar_t>& origin,
const torch::ArrayRef<scalar_t>& normal,
const torch::Tensor& absorption,
const torch::Tensor& scattering)
: origin(torch::tensor(origin)),
normal(torch::tensor(normal)),
: origin(torch::tensor(origin).to(scattering.dtype())),
normal(torch::tensor(normal).to(scattering.dtype())),
scattering(scattering),
reflection(1. - absorption) {}
};
Expand Down Expand Up @@ -136,7 +135,6 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
for (unsigned int i = 0; i < 3; ++i) {
auto dir0 = SCALAR(direction[i]);
auto abs_dir0 = std::abs(dir0);

// If the ray is almost parallel to a plane, then we delegate the
// computation to the other planes.
if (abs_dir0 < EPS) {
Expand All @@ -147,6 +145,10 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
scalar_t distance = (dir0 < 0.)
? SCALAR(origin[i]) // Going towards origin
: SCALAR(room[i] - origin[i]); // Going away from origin
// The origin is sometimes slightly outside of the room, which would make
// the distance negative; clamp it to zero in that case.
if (distance < 0) {
distance = 0.;
}
auto ratio = distance / abs_dir0;
int i_increment = dir0 > 0.;

Expand Down
3 changes: 2 additions & 1 deletion src/torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
oscillator_bank,
sinc_impulse_response,
)
from ._rir import simulate_rir_ism
from ._rir import ray_tracing, simulate_rir_ism
from .functional import barkscale_fbanks, chroma_filterbank


Expand All @@ -20,6 +20,7 @@
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"ray_tracing",
"sinc_impulse_response",
"simulate_rir_ism",
]
117 changes: 112 additions & 5 deletions src/torchaudio/prototype/functional/_rir.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor
"""
num_walls = 6
if isinstance(coeffs, float):
if coeffs < 0:
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
return torch.full((1, num_walls), coeffs)
if isinstance(coeffs, Tensor):
if torch.any(coeffs < 0):
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
if coeffs.ndim == 1:
if coeffs.numel() != num_walls:
raise ValueError(
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
f"Found the shape {coeffs.shape}."
)
return coeffs.unsqueeze(0)
if coeffs.ndim == 2:
if coeffs.shape != (7, num_walls):
if coeffs.shape[1] != num_walls:
raise ValueError(
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
f"Found the shape {coeffs.shape}."
f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
f"is a 2D Tensor. Found: {coeffs.shape}."
)
return coeffs
raise TypeError(f"`{name}` must be float or Tensor.")
Expand All @@ -169,7 +173,7 @@ def _validate_inputs(
if not (source.ndim == 1 and source.numel() == 3):
raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")


def simulate_rir_ism(
Expand Down Expand Up @@ -270,3 +274,106 @@ def simulate_rir_ism(
rir = rir[..., :output_length]

return rir


def ray_tracing(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    num_rays: int,
    absorption: Union[float, torch.Tensor] = 0.0,
    scattering: Union[float, torch.Tensor] = 0.0,
    mic_radius: float = 0.5,
    sound_speed: float = 343.0,
    energy_thres: float = 1e-7,
    time_thres: float = 10.0,
    hist_bin_size: float = 0.004,
) -> torch.Tensor:
    r"""Compute energy histogram via ray tracing.

    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    ``num_rays`` rays are cast uniformly in all directions from the source.
    Whenever a ray hits a wall, it is reflected and a part of its energy is absorbed.
    The ray is also scattered (sent directly to the microphone(s)) according to the
    ``scattering`` coefficient.
    Whenever a ray passes close to a microphone, its current energy is recorded in the
    output histogram, in the bin of the corresponding time slot.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        num_rays (int): The number of rays cast from the source.
        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
            (Default: ``0.0``).
            If the type is ``float``, the absorption coefficient is identical to all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
            coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
            ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
            ``num_bands`` is the number of frequency bands (usually 7).
        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials.
            (Default: ``0.0``)
            The shape and type of this parameter is the same as for ``absorption``.
        mic_radius (float, optional): The radius of the microphone in meters. (Default: 0.5)
        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
        energy_thres (float, optional): The energy level below which we stop tracing a ray.
            (Default: ``1e-7``)
            The initial energy of each ray is ``2 / num_rays``.
        time_thres (float, optional): The maximal duration for which rays are traced.
            (Unit: seconds) (Default: 10.0)
        hist_bin_size (float, optional): The size of each bin in the output histogram.
            (Unit: seconds) (Default: 0.004)

    Returns:
        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
            Each bin corresponds to a given time slot.
            The shape is `(channel, num_bands, num_bins)`, where
            ``num_bins = ceil(time_thres / hist_bin_size)``.
            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.

    Raises:
        ValueError: If ``time_thres`` is smaller than ``hist_bin_size``, if the dtypes of
            ``room``, ``source`` and ``mic_array`` differ, or if ``absorption`` and
            ``scattering`` cannot be brought to the same shape.
    """
    # The histogram needs room for at least one full bin.
    if hist_bin_size > time_thres:
        raise ValueError(
            "`time_thres` must be greater than `hist_bin_size`. "
            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
        )

    # All geometry tensors have to share a single dtype.
    if not (room.dtype == source.dtype and source.dtype == mic_array.dtype):
        raise ValueError(
            "dtype of `room`, `source` and `mic_array` must match. "
            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
            f"`mic_array` ({mic_array.dtype})"
        )

    _validate_inputs(room, source, mic_array)

    # Normalize both coefficients to 2D tensors that use the dtype of the room.
    absorption_coeffs = _adjust_coeff(absorption, "absorption").to(room.dtype)
    scattering_coeffs = _adjust_coeff(scattering, "scattering").to(room.dtype)

    # When only one of the two is defined per band, repeat the single-band
    # one along the band axis so that both end up with the same shape.
    num_absorption_bands = absorption_coeffs.shape[0]
    num_scattering_bands = scattering_coeffs.shape[0]
    if num_absorption_bands == 1 and num_scattering_bands > 1:
        absorption_coeffs = absorption_coeffs.expand(scattering_coeffs.shape)
    elif num_scattering_bands == 1 and num_absorption_bands > 1:
        scattering_coeffs = scattering_coeffs.expand(absorption_coeffs.shape)

    if absorption_coeffs.shape != scattering_coeffs.shape:
        raise ValueError(
            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
            f"Inferred shapes absorption={absorption_coeffs.shape} and scattering={scattering_coeffs.shape}"
        )

    # The actual tracing is delegated to the registered torchaudio operator.
    return torch.ops.torchaudio.ray_tracing(
        room,
        source,
        mic_array,
        num_rays,
        absorption_coeffs,
        scattering_coeffs,
        mic_radius,
        sound_speed,
        energy_thres,
        time_thres,
        hist_bin_size,
    )
34 changes: 25 additions & 9 deletions test/cpp/rir/wall_collision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@

using namespace torchaudio::rir;

using DTYPE = double;

struct CollisionTestParam {
// Input
torch::Tensor origin;
torch::Tensor direction;
// Expected
torch::Tensor hit_point;
int next_wall_index;
float hit_distance;
DTYPE hit_distance;
};

CollisionTestParam par(
torch::ArrayRef<float> origin,
torch::ArrayRef<float> direction,
torch::ArrayRef<float> hit_point,
torch::ArrayRef<DTYPE> origin,
torch::ArrayRef<DTYPE> direction,
torch::ArrayRef<DTYPE> hit_point,
int next_wall_index,
float hit_distance) {
DTYPE hit_distance) {
auto dir = torch::tensor(direction);
return {
torch::tensor(origin),
Expand Down Expand Up @@ -50,18 +52,22 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {

auto param = GetParam();
auto [hit_point, next_wall_index, hit_distance] =
find_collision_wall<float>(room, param.origin, param.direction);
find_collision_wall<DTYPE>(room, param.origin, param.direction);

EXPECT_EQ(param.next_wall_index, next_wall_index);
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
EXPECT_TRUE(torch::allclose(
param.hit_point, hit_point, /*rtol*/ 1e-05, /*atol*/ 1e-07));
EXPECT_NEAR(
param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
}

#define ISQRT2 0.70710678118

INSTANTIATE_TEST_CASE_P(
Collision3DTests,
BasicCollisionTests,
Simple3DRoomCollisionTest,
::testing::Values(
// From 0
Expand Down Expand Up @@ -100,3 +106,13 @@ INSTANTIATE_TEST_CASE_P(
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));

INSTANTIATE_TEST_CASE_P(
CornerCollisionTest,
Simple3DRoomCollisionTest,
::testing::Values(
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)));
Loading