Skip to content

Commit 6efedf5

Browse files
committed
feedback.
1 parent a41047e commit 6efedf5

File tree

2 files changed

+21
-27
lines changed

2 files changed

+21
-27
lines changed

test/torchaudio_unittest/datasets/speechcommands_test.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,16 @@ def setUpClass(cls):
9696
utterance,
9797
)
9898
cls.samples.append(sample)
99-
label_filename = os.path.join(label, filename)
100-
if 2 <= j < 4:
99+
if j < 2:
100+
cls.train_samples.append(sample)
101+
elif j < 4:
102+
label_filename = os.path.join(label, filename)
101103
valid.write(f'{label_filename}\n')
102104
cls.valid_samples.append(sample)
103-
elif 4 <= j < 6:
105+
elif j < 6:
106+
label_filename = os.path.join(label, filename)
104107
test.write(f'{label_filename}\n')
105108
cls.test_samples.append(sample)
106-
else:
107-
cls.train_samples.append(sample)
108109

109110
def testSpeechCommands(self):
110111
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
@@ -141,7 +142,6 @@ def testSpeechCommandsSubsetTrain(self):
141142

142143
def testSpeechCommandsSubsetValid(self):
143144
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
144-
print(dataset._path)
145145

146146
num_samples = 0
147147
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
@@ -156,9 +156,8 @@ def testSpeechCommandsSubsetValid(self):
156156

157157
assert num_samples == len(self.valid_samples)
158158

159-
def testSpeechCommandsSubset(self):
159+
def testSpeechCommandsSubsetTest(self):
160160
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
161-
print(dataset._path)
162161

163162
num_samples = 0
164163
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(

torchaudio/datasets/speechcommands.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
2121
"6b74f3901214cb2c2934e98196829835",
2222
}
23+
VALIDATION_LIST = "validation_list.txt"
24+
TESTING_LIST = "testing_list.txt"
2325

2426

2527
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
@@ -90,29 +92,22 @@ def __init__(self,
9092
download_url(url, root, hash_value=checksum, hash_type="md5")
9193
extract_archive(archive, self._path)
9294

93-
walker = walk_files(self._path, suffix=".wav", prefix=True)
94-
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
95-
96-
if subset in ["training", "validation"]:
97-
filepath = os.path.join(self._path, "validation_list.txt")
98-
with open(filepath) as f:
99-
validation_list = [os.path.join(self._path, l.strip()) for l in f.readlines()]
100-
101-
if subset in ["training", "testing"]:
102-
filepath = os.path.join(self._path, "testing_list.txt")
103-
with open(filepath) as f:
104-
testing_list = [os.path.join(self._path, l.strip()) for l in f.readlines()]
95+
def load_list(filename):
96+
filepath = os.path.join(self._path, filename)
97+
with open(filepath) as fileobj:
98+
return [os.path.join(self._path, line.strip()) for line in fileobj]
10599

106100
if subset == "validation":
107-
walker = validation_list
101+
self._walker = load_list(VALIDATION_LIST)
108102
elif subset == "testing":
109-
walker = testing_list
103+
self._walker = load_list(TESTING_LIST)
110104
elif subset == "training":
111-
walker = filter(
112-
lambda w: not (w in validation_list or w in testing_list), walker
113-
)
114-
115-
self._walker = list(walker)
105+
excludes = load_list(VALIDATION_LIST) + load_list(TESTING_LIST)
106+
walker = walk_files(self._path, suffix=".wav", prefix=True)
107+
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and w not in excludes]
108+
else:
109+
walker = walk_files(self._path, suffix=".wav", prefix=True)
110+
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
116111

117112
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
118113
"""Load the n-th sample from the dataset.

0 commit comments

Comments
 (0)