Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash in log-mel frontend when waveform samples are integers. #1017

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion axlearn/audio/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,19 @@ def _pre_emphasis(coeff: float) -> StageFn:
return functools.partial(pre_emphasis, coeff=jnp.array(coeff))


def _fft_dtype(input_dtype: jnp.dtype) -> jnp.dtype:
if input_dtype in (jnp.bfloat16, jnp.float32, jnp.float64):
return input_dtype
elif input_dtype == jnp.int16:
return jnp.bfloat16
elif input_dtype == jnp.int32:
return jnp.float32
elif input_dtype == jnp.int64:
return jnp.float64
else:
raise ValueError(f"{input_dtype=} is not supported.")


class LogMelFrontend(BaseFrontend):
"""Computes Log Mel spectrogram features.

Expand Down Expand Up @@ -224,7 +237,7 @@ def _to_logmel(self, frames: Tensor, *, frames_paddings: Tensor) -> dict[str, Te
frames = windowing(frames, window_type=WindowType.HANN)
# FFT and construct spectrogram.
# [batch_size, num_frames, fft_size] -> [batch, num_frames, num_filters].
outputs = self._spectrogram(self._fft(frames), dtype=frames.dtype)
outputs = self._spectrogram(self._fft(frames), dtype=_fft_dtype(frames.dtype))
if self._output_transformation is not None:
outputs = self._output_transformation(outputs)
outputs = outputs * (1 - einops.rearrange(frames_paddings, "b t -> b t 1"))
Expand Down
11 changes: 6 additions & 5 deletions axlearn/audio/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from axlearn.audio.frontend import LogMelFrontend, normalize_by_mean_std
from axlearn.audio.frontend import LogMelFrontend, _fft_dtype, normalize_by_mean_std
from axlearn.audio.frontend_utils import (
linear_to_log_spectrogram,
magnitude_spectrogram,
Expand Down Expand Up @@ -180,7 +180,8 @@ def test_output_dim(self):
with self.assertRaisesRegex(ValueError, "output_dim"):
cfg.set(name="test").instantiate(parent=None)

def test_small_input(self):
@parameterized.product(input_dtype=[jnp.bfloat16, jnp.float32, jnp.float64, jnp.int32])
def test_small_input(self, input_dtype):
sample_rate, batch_size, max_seconds = 16_000, 4, 13
num_filters = 80

Expand All @@ -189,7 +190,7 @@ def test_small_input(self):
prng_key=jax.random.PRNGKey(123),
batch_size=batch_size,
seq_len=max_seconds * sample_rate,
dtype=jnp.float64,
dtype=input_dtype,
scale=1.0,
)

Expand Down Expand Up @@ -334,7 +335,7 @@ def _log_spectogram(x: Tensor, *, dtype: jnp.dtype) -> Tensor:
output_shape = layer.output_shape(input_shape=inputs.shape)
self.assertSequenceEqual(test_outputs.shape, output_shape)

@parameterized.parameters([(jnp.float32,), (jnp.bfloat16,)])
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int32])
def test_dtype(self, dtype):
# Test that the frontend outputs follow the same dtype as inputs.
sample_rate, batch_size, max_seconds = 16_000, 4, 13
Expand All @@ -356,7 +357,7 @@ def test_dtype(self, dtype):
)
test_outputs = self._jit_forward(layer, inputs, paddings)
test_outputs, test_paddings = test_outputs["outputs"], test_outputs["paddings"]
self.assertEqual(test_outputs.dtype, inputs.dtype)
self.assertEqual(test_outputs.dtype, _fft_dtype(inputs.dtype))
self.assertEqual(test_paddings.dtype, paddings.dtype)


Expand Down
4 changes: 2 additions & 2 deletions axlearn/audio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def fake_audio(
shape=[batch_size, seq_len],
minval=-scale,
maxval=scale,
dtype=dtype,
)
dtype=jnp.float32,
).astype(dtype)
lengths = jax.random.randint(length_key, shape=[batch_size, 1], minval=0, maxval=seq_len)
paddings = (jnp.arange(seq_len)[None, :] >= lengths).astype(jnp.int32)
return inputs, paddings