Skip to content

Commit

Permalink
Update preprocess_flist_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Stardust-minus authored Jul 22, 2023
1 parent 0f5847a commit d07d92b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions preprocess_flist_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import wave
from random import shuffle

from loguru import logger
from tqdm import tqdm

import diffusion.logger.utils as du
Expand Down Expand Up @@ -46,9 +46,9 @@ def get_wav_duration(file_path):
if not file.endswith("wav"):
continue
if not pattern.match(file):
print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
if get_wav_duration(file) < 0.3:
print("skip too short audio:", file)
logger.info("Skip too short audio:" + file)
continue
new_wavs.append(file)
wavs = new_wavs
Expand All @@ -59,13 +59,13 @@ def get_wav_duration(file_path):
shuffle(train)
shuffle(val)

print("Writing", args.train_list)
logger.info("Writing" + args.train_list)
with open(args.train_list, "w") as f:
for fname in tqdm(train):
wavpath = fname
f.write(wavpath + "\n")

print("Writing", args.val_list)
logger.info("Writing" + args.val_list)
with open(args.val_list, "w") as f:
for fname in tqdm(val):
wavpath = fname
Expand Down Expand Up @@ -97,8 +97,8 @@ def get_wav_duration(file_path):
if args.vol_aug:
config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True

print("Writing configs/config.json")
logger.info("Writing to configs/config.json")
with open("configs/config.json", "w") as f:
json.dump(config_template, f, indent=2)
print("Writing configs/diffusion.yaml")
logger.info("Writing to configs/diffusion.yaml")
du.save_config("configs/diffusion.yaml",d_config_template)

0 comments on commit d07d92b

Please sign in to comment.