Skip to content

Add Ray tracing method for RIR (#2850) #3234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
if _mod_utils.is_module_available("pyroomacoustics"):
import pyroomacoustics as pra

import numpy as np
import torch
import torchaudio.prototype.functional as F
from parameterized import param, parameterized
Expand Down Expand Up @@ -545,3 +546,303 @@ def test_simulate_rir_ism_multi_band(self, channel):
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)

@parameterized.expand(
[
(0.1, 0.2, (2, 1, 2500)), # both float
# Per-wall
(torch.rand(4), 0.2, (2, 1, 2500)),
(0.1, torch.rand(4), (2, 1, 2500)),
(torch.rand(4), torch.rand(4), (2, 1, 2500)),
# Per-band and per-wall
(torch.rand(6, 4), 0.2, (2, 6, 2500)),
(0.1, torch.rand(6, 4), (2, 6, 2500)),
(torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)),
]
)
def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
room_dim = torch.tensor([20, 25], dtype=self.dtype)
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
source = torch.tensor([7, 6], dtype=self.dtype)
num_rays = 100

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

assert hist.shape == expected_shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use self.assertEqual. (Buck does not report numbers with assert)


def test_ray_tracing_input_errors(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need more of these. currently the implementation does not reject the invalid shapes like empty tensors.

with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
F.ray_tracing(
room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10
)
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
F.ray_tracing(
room=torch.tensor([4, 5, 4, 5]),
source=torch.tensor([0, 0]),
mic_array=torch.tensor([[3, 4]]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"):
F.ray_tracing(
room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10
)
with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"):
F.ray_tracing(
room=torch.tensor([4, 5]).to(torch.int),
source=torch.tensor([0, 0]),
mic_array=torch.tensor([3, 4]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"):
F.ray_tracing(
room=torch.tensor([4, 5]).to(torch.float64),
source=torch.tensor([0, 0]).to(torch.float32),
mic_array=torch.tensor([3, 4]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5, 10], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5, 10], dtype=torch.float),
source=torch.tensor([0, 0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
time_thres=10,
hist_bin_size=11,
)
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
scattering=torch.rand(5, 5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, 5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
scattering=torch.rand(5, dtype=torch.float),
)
with self.assertRaisesRegex(
ValueError, "absorption and scattering must have the same number of bands and walls"
):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(6, 4, dtype=torch.float),
scattering=torch.rand(5, 4, dtype=torch.float),
)

# Make sure passing different shapes for absorption or scattering doesn't raise an error
# float and tensor
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=0.1,
scattering=torch.rand(5, 4, dtype=torch.float),
)
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, 4, dtype=torch.float),
scattering=0.1,
)
# per-wall only and per-band + per-wall
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(4, dtype=torch.float),
scattering=torch.rand(6, 4, dtype=torch.float),
)
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(6, 4, dtype=torch.float),
scattering=torch.rand(4, dtype=torch.float),
)

def test_ray_tracing_per_band_per_wall_absorption(self):
"""Check that when the value of absorption and scattering are the same
across walls and frequency bands, the output histograms are:
- all equal across frequency bands
- equal to simply passing a float value instead of a (num_bands, D) or
(D,) tensor.
"""

room_dim = torch.tensor([20, 25], dtype=self.dtype)
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
source = torch.tensor([7, 6], dtype=self.dtype)
num_rays = 1_000
ABS, SCAT = 0.1, 0.2

absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype)
hist_per_band_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype)
hist_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

absorption = ABS
scattering = SCAT
hist_single = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
assert hist_per_band_per_wall.shape == (2, 6, 2500)
assert hist_per_wall.shape == (2, 1, 2500)
assert hist_single.shape == (2, 1, 2500)
torch.testing.assert_close(hist_single, hist_per_wall)

hist_single = hist_single.expand(2, 6, 2500)
torch.testing.assert_close(hist_single, hist_per_band_per_wall)

@parameterized.expand(
[
([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
]
)
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):

walls = ["west", "east", "south", "north"]
if len(room_dim) == 3:
walls += ["floor", "ceiling"]
num_walls = len(walls)
num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7

absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0

room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)

room = pra.ShoeBox(
room_dim.tolist(),
ray_tracing=True,
materials={
walls[i]: pra.Material(
energy_absorption={
"coeffs": absorption[:, i].reshape(-1).detach().numpy(),
"center_freqs": 125 * 2 ** np.arange(num_bands),
},
scattering={
"coeffs": scattering[:, i].reshape(-1).detach().numpy(),
"center_freqs": 125 * 2 ** np.arange(num_bands),
},
)
for i in range(num_walls)
},
air_absorption=False,
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
)
room.add_microphone_array(mic_array.T.tolist())
room.add_source(source.tolist())
room.set_ray_tracing(
n_rays=num_rays,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
receiver_radius=mic_radius,
)
room.set_sound_speed(sound_speed)

room.compute_rir()
hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0]

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
sound_speed=sound_speed,
mic_radius=mic_radius,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
)

assert hist.ndim == 3
assert hist.shape == hist_pra.shape
self.assertEqual(hist.to(torch.float32), hist_pra)
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,43 @@ def test_simulate_rir_ism_multi_band(self, channel):
F.simulate_rir_ism,
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
)

@parameterized.expand(
[
([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
]
)
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
num_walls = 4 if len(room_dim) == 2 else 6
num_bands = 3

absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)

energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0

room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)

self._assert_consistency(
F.ray_tracing,
(
room_dim,
source,
mic_array,
num_rays,
absorption,
scattering,
mic_radius,
sound_speed,
energy_thres,
time_thres,
hist_bin_size,
),
)
2 changes: 1 addition & 1 deletion torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ if(BUILD_RNNT)
endif()

if(BUILD_RIR)
list(APPEND sources rir.cpp)
list(APPEND sources rir.cpp ray_tracing.cpp)
list(APPEND compile_definitions INCLUDE_RIR)
endif()

Expand Down
Loading