Skip to content

Commit 695b6fc

Browse files
committed
Add batch support for resample (tensorflow#1390)
This PR adds batch support for resample, so that a 3-D input of shape `[batch, samples, channels]` can be passed to `tfio.audio.resample`. The implementation is done through `tf.vectorized_map`. This PR fixes #1366. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent f163ee1 commit 695b6fc

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

tensorflow_io/core/python/ops/audio_ops.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,13 @@ def fade(input, fade_in, fade_out, mode, name=None):
372372
return factor_in * factor_out * input
373373

374374

375-
def resample(input, rate_in, rate_out, name=None): # pylint: disable=redefined-builtin
375+
def resample(input, rate_in, rate_out, name=None):
376376
"""Resample audio.
377377
378378
Args:
379-
input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) `Tensor` of
380-
type `int16` or `float`. Audio input.
379+
input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) or 3-D
380+
(`[batch, samples, channels]`) `Tensor` of type
381+
`int16` or `float`. Audio input.
381382
rate_in: The rate of the audio input.
382383
rate_out: The rate of the audio output.
383384
name: A name for the operation (optional).
@@ -387,14 +388,37 @@ def resample(input, rate_in, rate_out, name=None): # pylint: disable=redefined-
387388
"""
388389
rank = tf.rank(input)
389390

390-
input = tf.cond(
391-
tf.math.equal(rank, 1), lambda: tf.expand_dims(input, -1), lambda: input
392-
)
393-
value = core_ops.io_audio_resample(
394-
input, rate_in=rate_in, rate_out=rate_out, name=name
391+
def f1():
392+
return tf.expand_dims(tf.expand_dims(input, -1), 0)
393+
394+
def f2():
395+
return tf.expand_dims(input, 0)
396+
397+
def f3():
398+
return input
399+
400+
input = tf.case(
401+
[(tf.math.equal(rank, 1), f1), (tf.math.equal(rank, 2), f2)], default=f3
395402
)
396-
return tf.cond(
397-
tf.math.equal(rank, 1), lambda: tf.squeeze(value, [-1]), lambda: value
403+
404+
def f(i):
405+
return core_ops.io_audio_resample(
406+
i, rate_in=rate_in, rate_out=rate_out, name=name
407+
)
408+
409+
value = tf.vectorized_map(f, input)
410+
411+
def g1():
412+
return tf.squeeze(value, [0, -1])
413+
414+
def g2():
415+
return tf.squeeze(value, [0])
416+
417+
def g3():
418+
return value
419+
420+
return tf.case(
421+
[(tf.math.equal(rank, 1), g1), (tf.math.equal(rank, 2), g2)], default=g3
398422
)
399423

400424

tests/test_audio_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def fixture_resample_1d():
7272
return args[:, 0], func, expected[:, 0]
7373

7474

75+
@pytest.fixture(name="resample_batch", scope="module")
76+
def fixture_resample_batch():
77+
"""fixture_resample_batch"""
78+
args, func, expected = fixture_resample_base()
79+
return tf.stack([args, args]), func, tf.stack([expected, expected])
80+
81+
7582
@pytest.fixture(name="decode_wav", scope="module")
7683
def fixture_decode_wav():
7784
"""fixture_decode_wav"""
@@ -633,6 +640,7 @@ def func(e):
633640
[
634641
pytest.param("resample"),
635642
pytest.param("resample_1d"),
643+
pytest.param("resample_batch"),
636644
pytest.param("decode_wav"),
637645
pytest.param("encode_wav"),
638646
pytest.param("decode_wav_u8"),
@@ -683,6 +691,7 @@ def func(e):
683691
ids=[
684692
"resample",
685693
"resample[1d]",
694+
"resample[batch]",
686695
"decode_wav",
687696
"encode_wav",
688697
"decode_wav|u8",

0 commit comments

Comments
 (0)