Skip to content

Commit ac642ea

Browse files
authored
Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)
After updating JAX, this existing hidden bug started causing CI failures. When the sample dtype is int32 (which is valid), `jnp.finfo` returns None, even though `jnp.iinfo` is available. The previous JAX version seemed to handle this case more forgivingly. ``` ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny)) ```
1 parent 7c64b55 commit ac642ea

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

axlearn/audio/frontend.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def _pre_emphasis(coeff: float) -> StageFn:
118118
return functools.partial(pre_emphasis, coeff=jnp.array(coeff))
119119

120120

121+
def _fft_dtype(input_dtype: jnp.dtype) -> jnp.dtype:
122+
if input_dtype in (jnp.bfloat16, jnp.float32, jnp.float64):
123+
return input_dtype
124+
elif input_dtype == jnp.int16:
125+
return jnp.bfloat16
126+
elif input_dtype == jnp.int32:
127+
return jnp.float32
128+
elif input_dtype == jnp.int64:
129+
return jnp.float64
130+
else:
131+
raise ValueError(f"{input_dtype=} is not supported.")
132+
133+
121134
class LogMelFrontend(BaseFrontend):
122135
"""Computes Log Mel spectrogram features.
123136
@@ -224,7 +237,7 @@ def _to_logmel(self, frames: Tensor, *, frames_paddings: Tensor) -> dict[str, Te
224237
frames = windowing(frames, window_type=WindowType.HANN)
225238
# FFT and construct spectrogram.
226239
# [batch_size, num_frames, fft_size] -> [batch, num_frames, num_filters].
227-
outputs = self._spectrogram(self._fft(frames), dtype=frames.dtype)
240+
outputs = self._spectrogram(self._fft(frames), dtype=_fft_dtype(frames.dtype))
228241
if self._output_transformation is not None:
229242
outputs = self._output_transformation(outputs)
230243
outputs = outputs * (1 - einops.rearrange(frames_paddings, "b t -> b t 1"))

axlearn/audio/frontend_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from jax.experimental import mesh_utils
1818
from jax.sharding import Mesh, NamedSharding, PartitionSpec
1919

20-
from axlearn.audio.frontend import LogMelFrontend, normalize_by_mean_std
20+
from axlearn.audio.frontend import LogMelFrontend, _fft_dtype, normalize_by_mean_std
2121
from axlearn.audio.frontend_utils import (
2222
linear_to_log_spectrogram,
2323
magnitude_spectrogram,
@@ -180,7 +180,8 @@ def test_output_dim(self):
180180
with self.assertRaisesRegex(ValueError, "output_dim"):
181181
cfg.set(name="test").instantiate(parent=None)
182182

183-
def test_small_input(self):
183+
@parameterized.product(input_dtype=[jnp.bfloat16, jnp.float32, jnp.float64, jnp.int32])
184+
def test_small_input(self, input_dtype):
184185
sample_rate, batch_size, max_seconds = 16_000, 4, 13
185186
num_filters = 80
186187

@@ -189,7 +190,7 @@ def test_small_input(self):
189190
prng_key=jax.random.PRNGKey(123),
190191
batch_size=batch_size,
191192
seq_len=max_seconds * sample_rate,
192-
dtype=jnp.float64,
193+
dtype=input_dtype,
193194
scale=1.0,
194195
)
195196

@@ -334,7 +335,7 @@ def _log_spectogram(x: Tensor, *, dtype: jnp.dtype) -> Tensor:
334335
output_shape = layer.output_shape(input_shape=inputs.shape)
335336
self.assertSequenceEqual(test_outputs.shape, output_shape)
336337

337-
@parameterized.parameters([(jnp.float32,), (jnp.bfloat16,)])
338+
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int32])
338339
def test_dtype(self, dtype):
339340
# Test that the frontend outputs follow the same dtype as inputs.
340341
sample_rate, batch_size, max_seconds = 16_000, 4, 13
@@ -356,7 +357,7 @@ def test_dtype(self, dtype):
356357
)
357358
test_outputs = self._jit_forward(layer, inputs, paddings)
358359
test_outputs, test_paddings = test_outputs["outputs"], test_outputs["paddings"]
359-
self.assertEqual(test_outputs.dtype, inputs.dtype)
360+
self.assertEqual(test_outputs.dtype, _fft_dtype(inputs.dtype))
360361
self.assertEqual(test_paddings.dtype, paddings.dtype)
361362

362363

axlearn/audio/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def fake_audio(
2323
shape=[batch_size, seq_len],
2424
minval=-scale,
2525
maxval=scale,
26-
dtype=dtype,
27-
)
26+
dtype=jnp.float32,
27+
).astype(dtype)
2828
lengths = jax.random.randint(length_key, shape=[batch_size, 1], minval=0, maxval=seq_len)
2929
paddings = (jnp.arange(seq_len)[None, :] >= lengths).astype(jnp.int32)
3030
return inputs, paddings

0 commit comments

Comments
 (0)