|
20 | 20 | "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
|
21 | 21 | "6b74f3901214cb2c2934e98196829835",
|
22 | 22 | }
|
| 23 | +VALIDATION_LIST = "validation_list.txt" |
| 24 | +TESTING_LIST = "testing_list.txt" |
23 | 25 |
|
24 | 26 |
|
25 | 27 | def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
|
@@ -90,29 +92,22 @@ def __init__(self,
|
90 | 92 | download_url(url, root, hash_value=checksum, hash_type="md5")
|
91 | 93 | extract_archive(archive, self._path)
|
92 | 94 |
|
93 |
| - walker = walk_files(self._path, suffix=".wav", prefix=True) |
94 |
| - walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker) |
95 |
| - |
96 |
| - if subset in ["training", "validation"]: |
97 |
| - filepath = os.path.join(self._path, "validation_list.txt") |
98 |
| - with open(filepath) as f: |
99 |
| - validation_list = [os.path.join(self._path, l.strip()) for l in f.readlines()] |
100 |
| - |
101 |
| - if subset in ["training", "testing"]: |
102 |
| - filepath = os.path.join(self._path, "testing_list.txt") |
103 |
| - with open(filepath) as f: |
104 |
| - testing_list = [os.path.join(self._path, l.strip()) for l in f.readlines()] |
| 95 | + def load_list(filename): |
| 96 | + filepath = os.path.join(self._path, filename) |
| 97 | + with open(filepath) as fileobj: |
| 98 | + return [os.path.join(self._path, line.strip()) for line in fileobj] |
105 | 99 |
|
106 | 100 | if subset == "validation":
|
107 |
| - walker = validation_list |
| 101 | + self._walker = load_list(VALIDATION_LIST) |
108 | 102 | elif subset == "testing":
|
109 |
| - walker = testing_list |
| 103 | + self._walker = load_list(TESTING_LIST) |
110 | 104 | elif subset == "training":
|
111 |
| - walker = filter( |
112 |
| - lambda w: not (w in validation_list or w in testing_list), walker |
113 |
| - ) |
114 |
| - |
115 |
| - self._walker = list(walker) |
| 105 | + excludes = load_list(VALIDATION_LIST) + load_list(TESTING_LIST) |
| 106 | + walker = walk_files(self._path, suffix=".wav", prefix=True) |
| 107 | + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and w not in excludes] |
| 108 | + else: |
| 109 | + walker = walk_files(self._path, suffix=".wav", prefix=True) |
| 110 | + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] |
116 | 111 |
|
117 | 112 | def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
|
118 | 113 | """Load the n-th sample from the dataset.
|
|
0 commit comments