Skip to content

Commit 3702055

Browse files
yangarbiterfacebook-github-bot
authored andcommitted
Import torchaudio #1639 37dbf29
Summary: Import torchaudio #1639 37dbf29 Reviewed By: carolineechen, mthrok Differential Revision: D29920658 fbshipit-source-id: 94ba8c04edcfb50e355b1ca8e937f612917ecf38
1 parent 1f1bd18 commit 3702055

File tree

9 files changed

+211
-64
lines changed

9 files changed

+211
-64
lines changed

examples/pipeline_tacotron2/text/__init__.py

Whitespace-only changes.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2017 Keith Ito
3+
#
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so, subject to the following conditions:
10+
#
11+
# The above copyright notice and this permission notice shall be included in
12+
# all copies or substantial portions of the Software.
13+
#
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
# THE SOFTWARE.
21+
#
22+
# *****************************************************************************
23+
"""
24+
Modified from https://github.com/keithito/tacotron
25+
"""
26+
27+
import inflect
28+
import re
29+
30+
31+
_inflect = inflect.engine()
32+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
33+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
34+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
35+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
36+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
37+
_number_re = re.compile(r'[0-9]+')
38+
39+
40+
def _remove_commas(m: re.Match) -> str:
41+
return m.group(1).replace(',', '')
42+
43+
44+
def _expand_decimal_point(m: re.Match) -> str:
45+
return m.group(1).replace('.', ' point ')
46+
47+
48+
def _expand_dollars(m: re.Match) -> str:
49+
match = m.group(1)
50+
parts = match.split('.')
51+
if len(parts) > 2:
52+
return match + ' dollars' # Unexpected format
53+
dollars = int(parts[0]) if parts[0] else 0
54+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
55+
if dollars and cents:
56+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
57+
cent_unit = 'cent' if cents == 1 else 'cents'
58+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
59+
elif dollars:
60+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
61+
return '%s %s' % (dollars, dollar_unit)
62+
elif cents:
63+
cent_unit = 'cent' if cents == 1 else 'cents'
64+
return '%s %s' % (cents, cent_unit)
65+
else:
66+
return 'zero dollars'
67+
68+
69+
def _expand_ordinal(m: re.Match) -> str:
70+
return _inflect.number_to_words(m.group(0))
71+
72+
73+
def _expand_number(m: re.Match) -> str:
74+
num = int(m.group(0))
75+
if num > 1000 and num < 3000:
76+
if num == 2000:
77+
return 'two thousand'
78+
elif num > 2000 and num < 2010:
79+
return 'two thousand ' + _inflect.number_to_words(num % 100)
80+
elif num % 100 == 0:
81+
return _inflect.number_to_words(num // 100) + ' hundred'
82+
else:
83+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
84+
else:
85+
return _inflect.number_to_words(num, andword='')
86+
87+
88+
def normalize_numbers(text: str) -> str:
89+
text = re.sub(_comma_number_re, _remove_commas, text)
90+
text = re.sub(_pounds_re, r'\1 pounds', text)
91+
text = re.sub(_dollars_re, _expand_dollars, text)
92+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
93+
text = re.sub(_ordinal_re, _expand_ordinal, text)
94+
text = re.sub(_number_re, _expand_number, text)
95+
return text
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import unittest
2+
3+
from parameterized import parameterized
4+
5+
from .text_preprocessing import text_to_sequence
6+
7+
8+
class TestTextPreprocessor(unittest.TestCase):
9+
10+
@parameterized.expand(
11+
[
12+
["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]],
13+
["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]],
14+
["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]],
15+
# 'one thousand dollars, twenty cents'
16+
["$1,000.20", [26, 25, 16, 11, 31, 19, 26, 32, 30, 12, 25, 15, 11, 15, 26, 23, 23,
17+
12, 29, 30, 6, 11, 31, 34, 16, 25, 31, 36, 11, 14, 16, 25, 31, 30]],
18+
]
19+
)
20+
def test_text_to_sequence(self, sent, seq):
21+
22+
assert (text_to_sequence(sent) == seq)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2017 Keith Ito
3+
#
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so, subject to the following conditions:
10+
#
11+
# The above copyright notice and this permission notice shall be included in
12+
# all copies or substantial portions of the Software.
13+
#
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
# THE SOFTWARE.
21+
#
22+
# *****************************************************************************
23+
"""
24+
Modified from https://github.com/keithito/tacotron
25+
"""
26+
27+
from typing import List
28+
import re
29+
30+
from unidecode import unidecode
31+
32+
from .numbers import normalize_numbers
33+
34+
35+
# Regular expression matching whitespace:
36+
_whitespace_re = re.compile(r'\s+')
37+
38+
# List of (regular expression, replacement) pairs for abbreviations:
39+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
40+
('mrs', 'misess'),
41+
('mr', 'mister'),
42+
('dr', 'doctor'),
43+
('st', 'saint'),
44+
('co', 'company'),
45+
('jr', 'junior'),
46+
('maj', 'major'),
47+
('gen', 'general'),
48+
('drs', 'doctors'),
49+
('rev', 'reverend'),
50+
('lt', 'lieutenant'),
51+
('hon', 'honorable'),
52+
('sgt', 'sergeant'),
53+
('capt', 'captain'),
54+
('esq', 'esquire'),
55+
('ltd', 'limited'),
56+
('col', 'colonel'),
57+
('ft', 'fort'),
58+
]]
59+
60+
_pad = '_'
61+
_punctuation = '!\'(),.:;? '
62+
_special = '-'
63+
_letters = 'abcdefghijklmnopqrstuvwxyz'
64+
65+
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
66+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
67+
68+
69+
def text_to_sequence(sent: str) -> List[int]:
70+
r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
71+
72+
Args:
73+
sent (str): The input sentence to convert to a sequence.
74+
75+
Returns:
76+
List of integers corresponding to the symbols in the sentence.
77+
'''
78+
sent = unidecode(sent) # convert to ascii
79+
sent = sent.lower() # lower case
80+
sent = normalize_numbers(sent) # expand numbers
81+
for regex, replacement in _abbreviations: # expand abbreviations
82+
sent = re.sub(regex, replacement, sent)
83+
sent = re.sub(_whitespace_re, ' ', sent) # collapse whitespace
84+
85+
return [_symbol_to_id[s] for s in sent if s in _symbol_to_id]

test/torchaudio_unittest/transforms/batch_consistency_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ def test_batch_Resample(self):
3333
self.assertEqual(computed, expected)
3434

3535
def test_batch_MelScale(self):
36-
specgram = torch.randn(2, 31, 2786)
36+
specgram = torch.randn(2, 201, 256)
3737

3838
# Single then transform then batch
3939
expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
4040

4141
# Batch then transform
4242
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
4343

44-
# shape = (3, 2, 201, 1394)
44+
# shape = (3, 2, 128, 256)
4545
self.assertEqual(computed, expected)
4646

4747
def test_batch_InverseMelScale(self):

test/torchaudio_unittest/transforms/torchscript_consistency_impl.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ def test_AmplitudeToDB(self):
5959
spec = torch.rand((6, 201))
6060
self._assert_consistency(T.AmplitudeToDB(), spec)
6161

62-
def test_MelScale_invalid(self):
63-
with self.assertRaises(ValueError):
64-
torch.jit.script(T.MelScale())
65-
6662
def test_MelScale(self):
6763
spec_f = torch.rand((1, 201, 6))
6864
self._assert_consistency(T.MelScale(n_stft=201), spec_f)

test/torchaudio_unittest/transforms/transforms_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,17 @@ def test_AmplitudeToDB(self):
5555
self.assertEqual(mag_to_db_torch, power_to_db_torch)
5656

5757
def test_melscale_load_save(self):
58-
specgram = torch.ones(1, 1000, 100)
58+
specgram = torch.ones(1, 201, 100)
5959
melscale_transform = transforms.MelScale()
6060
melscale_transform(specgram)
6161

62-
melscale_transform_copy = transforms.MelScale(n_stft=1000)
62+
melscale_transform_copy = transforms.MelScale()
6363
melscale_transform_copy.load_state_dict(melscale_transform.state_dict())
6464

6565
fb = melscale_transform.fb
6666
fb_copy = melscale_transform_copy.fb
6767

68-
self.assertEqual(fb_copy.size(), (1000, 128))
68+
self.assertEqual(fb_copy.size(), (201, 128))
6969
self.assertEqual(fb, fb_copy)
7070

7171
def test_melspectrogram_load_save(self):

test/torchaudio_unittest/transforms/transforms_test_impl.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import warnings
2-
31
import torch
42
import torchaudio.transforms as T
53

@@ -63,22 +61,6 @@ def test_InverseMelScale(self):
6361
assert _get_ratio(relative_diff < 1e-3) > 5e-3
6462
assert _get_ratio(relative_diff < 1e-5) > 1e-5
6563

66-
def test_melscale_unset_weight_warning(self):
67-
"""Issue a warning if MelScale initialized without a weight
68-
69-
As part of the deprecation of lazy intialization behavior (#1510),
70-
issue a warning if `n_stft` is not set.
71-
"""
72-
with warnings.catch_warnings(record=True) as caught_warnings:
73-
warnings.simplefilter("always")
74-
T.MelScale(n_mels=64, sample_rate=8000)
75-
assert len(caught_warnings) == 1
76-
77-
with warnings.catch_warnings(record=True) as caught_warnings:
78-
warnings.simplefilter("always")
79-
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
80-
assert len(caught_warnings) == 0
81-
8264
@nested_params(
8365
["sinc_interpolation", "kaiser_window"],
8466
[16000, 44100],

torchaudio/transforms.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,8 @@ class MelScale(torch.nn.Module):
244244
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
245245
f_min (float, optional): Minimum frequency. (Default: ``0.``)
246246
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
247-
n_stft (int, optional): Number of bins in STFT. Calculated from first input
248-
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
249-
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
247+
n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
248+
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
250249
(area normalization). (Default: ``None``)
251250
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
252251
"""
@@ -257,7 +256,7 @@ def __init__(self,
257256
sample_rate: int = 16000,
258257
f_min: float = 0.,
259258
f_max: Optional[float] = None,
260-
n_stft: Optional[int] = None,
259+
n_stft: int = 201,
261260
norm: Optional[str] = None,
262261
mel_scale: str = "htk") -> None:
263262
super(MelScale, self).__init__()
@@ -269,35 +268,11 @@ def __init__(self,
269268
self.mel_scale = mel_scale
270269

271270
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
272-
273-
if n_stft is None or n_stft == 0:
274-
warnings.warn(
275-
'Initialization of torchaudio.transforms.MelScale with an unset weight '
276-
'`n_stft=None` is deprecated and will be removed in release 0.10. '
277-
'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. '
278-
'Refer to https://github.com/pytorch/audio/issues/1510 '
279-
'for more details.'
280-
)
281-
282-
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
271+
fb = F.create_fb_matrix(
283272
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
284273
self.mel_scale)
285274
self.register_buffer('fb', fb)
286275

287-
def __prepare_scriptable__(self):
288-
r"""If `self.fb` is empty, the `forward` method will try to resize the parameter,
289-
which does not work once the transform is scripted. However, this error does not happen
290-
until the transform is executed. This is inconvenient especially if the resulting
291-
TorchScript object is executed in other environments. Therefore, we check the
292-
validity of `self.fb` here and fail if the resulting TS does not work.
293-
294-
Returns:
295-
MelScale: self
296-
"""
297-
if self.fb.numel() == 0:
298-
raise ValueError("n_stft must be provided at construction")
299-
return self
300-
301276
def forward(self, specgram: Tensor) -> Tensor:
302277
r"""
303278
Args:
@@ -311,14 +286,6 @@ def forward(self, specgram: Tensor) -> Tensor:
311286
shape = specgram.size()
312287
specgram = specgram.reshape(-1, shape[-2], shape[-1])
313288

314-
if self.fb.numel() == 0:
315-
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
316-
self.n_mels, self.sample_rate, self.norm,
317-
self.mel_scale)
318-
# Attributes cannot be reassigned outside __init__ so workaround
319-
self.fb.resize_(tmp_fb.size())
320-
self.fb.copy_(tmp_fb)
321-
322289
# (channel, frequency, time).transpose(...) dot (frequency, n_mels)
323290
# -> (channel, time, n_mels).transpose(...)
324291
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

0 commit comments

Comments
 (0)