Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-speaker preprocessing, training and inference #34

Merged
merged 24 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
77de178
Support multi-speaker preprocessing
yqzhishen Jan 13, 2023
025ccf8
Support multi-speaker preprocessing
yqzhishen Jan 13, 2023
bcd875c
Merge branch 'multi-spk' of https://github.com/openvpi/DiffSinger int…
yqzhishen Jan 14, 2023
45ffe4b
Save speaker id instead of speaker name
yqzhishen Jan 14, 2023
2f47002
Support speaker embedding
yxlllc Jan 14, 2023
6e46084
Fix yaml data type error
yqzhishen Jan 15, 2023
cdf9801
Fix item order
yqzhishen Jan 15, 2023
6a2716d
Copy spk map to work dir
yqzhishen Jan 15, 2023
1684028
Support multi-speaker preprocessing
yqzhishen Jan 13, 2023
0401704
Save speaker id instead of speaker name
yqzhishen Jan 14, 2023
b868e01
Support speaker embedding
yxlllc Jan 14, 2023
2fc5d15
Fix yaml data type error
yqzhishen Jan 15, 2023
20a6dd3
Fix item order
yqzhishen Jan 15, 2023
a6fecdb
Copy spk map to work dir
yqzhishen Jan 15, 2023
6430ab8
Merge branch 'multi-spk' of https://github.com/openvpi/DiffSinger int…
yqzhishen Jan 15, 2023
eded2bd
Support multi-speaker inference
yqzhishen Jan 16, 2023
9ee7405
Fix multi-speaker inference error
yqzhishen Jan 16, 2023
bbce763
Fix NoneType conversion to Tensor error caused by spk_id
yqzhishen Jan 16, 2023
20d152a
Support static speaker mix at inference time
yqzhishen Jan 18, 2023
7ecbb05
Support static speaker mix at inference time
yqzhishen Jan 18, 2023
0678e54
Check duplicate speaker id
yqzhishen Jan 18, 2023
d6924c4
Allow hyphen in speaker name
yqzhishen Jan 19, 2023
8929c27
Support configuring multiple datasets for combined models
yqzhishen Jan 19, 2023
fadac19
Support configuring multiple datasets for combined models
yqzhishen Jan 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions basics/base_binarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,24 @@ class BaseBinarizer:
3. load_ph_set:
the phoneme set.
'''
def __init__(self, processed_data_dir=None, item_attributes=BASE_ITEM_ATTRIBUTES):
if processed_data_dir is None:
processed_data_dir = hparams['processed_data_dir']
self.processed_data_dirs = processed_data_dir.split(",")
def __init__(self, data_dir=None, item_attributes=None):
if item_attributes is None:
item_attributes = BASE_ITEM_ATTRIBUTES
if data_dir is None:
data_dir = hparams['raw_data_dir']

if 'speakers' not in hparams:
speakers = hparams['datasets']
hparams['speakers'] = hparams['datasets']
else:
speakers = hparams['speakers']
assert isinstance(speakers, list), 'Speakers must be a list'
assert len(speakers) == len(set(speakers)), 'Speakers cannot contain duplicate names'

self.raw_data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
assert len(speakers) == len(self.raw_data_dirs), \
'Number of raw data dirs must equal number of speaker names!'

self.binarization_args = hparams['binarization_args']
self.pre_align_args = hparams['pre_align_args']

Expand All @@ -53,21 +67,21 @@ def __init__(self, processed_data_dir=None, item_attributes=BASE_ITEM_ATTRIBUTES
self.item_attributes = item_attributes

# load each dataset
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
self.load_meta_data(processed_data_dir, ds_id)
for ds_id, data_dir in enumerate(self.raw_data_dirs):
self.load_meta_data(data_dir, ds_id)
if ds_id == 0:
# check program correctness
assert all([attr in self.item_attributes for attr in list(self.items.values())[0].keys()])
self.item_names = sorted(list(self.items.keys()))

if self.binarization_args['shuffle']:
random.seed(1234)
random.seed(hparams['seed'])
random.shuffle(self.item_names)

# set default get_pitch algorithm
self.get_pitch_algorithm = get_pitch_parselmouth

def load_meta_data(self, processed_data_dir, ds_id):
def load_meta_data(self, raw_data_dir, ds_id):
raise NotImplementedError

@property
Expand All @@ -83,12 +97,8 @@ def test_item_names(self):
raise NotImplementedError

def build_spk_map(self):
spk_map = set()
for item_name in self.item_names:
spk_name = self.items[item_name]['spk_id']
spk_map.add(spk_name)
spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
spk_map = {x: i for i, x in enumerate(hparams['speakers'])}
assert len(spk_map) <= hparams['num_spk'], 'Actual number of speakers should be smaller than num_spk!'
return spk_map

def item_name2spk_id(self, item_name):
Expand Down Expand Up @@ -164,8 +174,7 @@ def process_data_split(self, prefix, multiprocess=False):
f0s.append(item['f0'])
else:
# code for single cpu processing
for i in tqdm(reversed(range(len(args))), total=len(args)):
a = args[i]
for a in tqdm(args):
item = self.process_item(*a)
if item is None:
continue
Expand Down
7 changes: 6 additions & 1 deletion basics/base_svs_infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# coding=utf8
import json
import os

import torch
Expand Down Expand Up @@ -44,7 +45,11 @@ def __init__(self, hparams, device=None, load_model=True, load_vocoder=True):
phone_list = build_phoneme_list()
self.ph_encoder = TokenTextEncoder(vocab_list=phone_list, replace_oov=',')
self.pinyin2phs = build_g2p_dictionary()
self.spk_map = {'opencpop': 0}
if hparams['use_spk_id']:
with open(os.path.join(hparams['work_dir'], 'spk_map.json'), 'r', encoding='utf8') as f:
self.spk_map = json.load(f)
assert isinstance(self.spk_map, dict) and len(self.spk_map) > 0, 'Invalid or empty speaker map!'
assert len(self.spk_map) == len(set(self.spk_map.values())), 'Duplicate speaker id in speaker map!'
self.model = self.build_model()
self.model.eval()
self.model.to(self.device)
Expand Down
12 changes: 8 additions & 4 deletions basics/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,24 @@ def start(cls):
row_log_interval=hparams['log_interval'],
max_updates=hparams['max_updates'],
num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
accumulate_grad_batches=hparams['accumulate_grad_batches'])
accumulate_grad_batches=hparams['accumulate_grad_batches']
)
if not hparams['infer']: # train
# copy_code = input(f'{hparams["save_codes"]} code backup? y/n: ') == 'y'
copy_code = True # backup code every time
if copy_code:
t = datetime.now().strftime('%Y%m%d%H%M%S')
code_dir = f'{work_dir}/codes/{t}'
# TODO: test filesystem calls
os.makedirs(code_dir, exist_ok=True)
# subprocess.check_call(f'mkdir "{code_dir}"', shell=True)
for c in hparams['save_codes']:
shutil.copytree(c, code_dir, dirs_exist_ok=True)
# subprocess.check_call(f'xcopy "{c}" "{code_dir}/" /s /e /y', shell=True)
print(f"| Copied codes to {code_dir}.")
# Copy spk_map.json to work dir
spk_map = os.path.join(work_dir, 'spk_map.json')
spk_map_orig = os.path.join(hparams['binary_data_dir'], 'spk_map.json')
if not os.path.exists(spk_map) and os.path.exists(spk_map_orig):
shutil.copy(spk_map_orig, spk_map)
print(f"| Copied spk map to {spk_map}.")
trainer.checkpoint_callback.task = task
trainer.fit(task)
else:
Expand Down
12 changes: 6 additions & 6 deletions data_gen/midisinging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@


class MidiSingingBinarizer(SingingBinarizer):
def __init__(self, processed_data_dir=None, item_attributes=MIDISINGING_ITEM_ATTRIBUTES):
super().__init__(processed_data_dir, item_attributes)
def __init__(self, raw_data_dir=None, item_attributes=MIDISINGING_ITEM_ATTRIBUTES):
super().__init__(raw_data_dir, item_attributes)

def load_meta_data(self, processed_data_dir, ds_id):
def load_meta_data(self, raw_data_dir, ds_id):
'''
NOTE: this function is *isolated* from other scripts, which means
it may not be compatible with the current version.
'''
return
meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json'), encoding='utf-8')) # [list of dict]
meta_midi = json.load(open(os.path.join(raw_data_dir, 'meta.json'), encoding='utf-8')) # [list of dict]

for song_item in meta_midi:
item_name = raw_item_name = song_item['item_name']
if len(self.processed_data_dirs) > 1:
if len(self.raw_data_dirs) > 1:
item_name = f'ds{ds_id}_{item_name}'

item = {}
Expand All @@ -37,7 +37,7 @@ def load_meta_data(self, processed_data_dir, ds_id):
item['midi_dur'] = song_item['notes_dur']
item['is_slur'] = song_item['is_slur']
item['spk_id'] = 'pop-cs'
if len(self.processed_data_dirs) > 1:
if len(self.raw_data_dirs) > 1:
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"

self.items[item_name] = item
Expand Down
31 changes: 27 additions & 4 deletions data_gen/opencpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,36 @@

class OpencpopBinarizer(MidiSingingBinarizer):
def split_train_test_set(self, item_names):
item_names = deepcopy(item_names)
test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
item_names = set(deepcopy(item_names))
prefixes = set([str(pr) for pr in hparams['test_prefixes']])
test_item_names = set()
# Add prefixes that specified speaker index and matches exactly item name to test set
for prefix in deepcopy(prefixes):
if prefix in item_names:
test_item_names.add(prefix)
prefixes.remove(prefix)
# Add prefixes that exactly matches item name without speaker id to test set
for prefix in deepcopy(prefixes):
for name in item_names:
if name.split(':')[-1] == prefix:
test_item_names.add(name)
prefixes.remove(prefix)
# Add names with one of the remaining prefixes to test set
for prefix in deepcopy(prefixes):
for name in item_names:
if name.startswith(prefix):
test_item_names.add(name)
prefixes.remove(prefix)
for prefix in prefixes:
for name in item_names:
if name.split(':')[-1].startswith(prefix):
test_item_names.add(name)
test_item_names = sorted(list(test_item_names))
train_item_names = [x for x in item_names if x not in set(test_item_names)]
logging.info("train {}".format(len(train_item_names)))
logging.info("test {}".format(len(test_item_names)))
return train_item_names, test_item_names

def load_meta_data(self, processed_data_dir, ds_id):
def load_meta_data(self, raw_data_dir, ds_id):
from preprocessing.opencpop import File2Batch
self.items = File2Batch.file2temporary_dict()
self.items.update(File2Batch.file2temporary_dict(raw_data_dir, ds_id))
13 changes: 6 additions & 7 deletions data_gen/singing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@


class SingingBinarizer(BaseBinarizer):
def __init__(self, processed_data_dir=None, item_attributes=SINGING_ITEM_ATTRIBUTES):
super().__init__(processed_data_dir, item_attributes)
def __init__(self, raw_data_dir=None, item_attributes=SINGING_ITEM_ATTRIBUTES):
super().__init__(raw_data_dir, item_attributes)

print('spkers: ', set(item['spk_id'] for item in self.items.values()))
self.item_names = sorted(list(self.items.keys()))
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

Expand All @@ -34,12 +33,12 @@ def valid_item_names(self):
def test_item_names(self):
return self._test_item_names

def load_meta_data(self, processed_data_dir, ds_id):
def load_meta_data(self, raw_data_dir, ds_id):
wav_suffix = '_wf0.wav'
txt_suffix = '.txt'
ph_suffix = '_ph.txt'
tg_suffix = '.TextGrid'
all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
all_wav_pieces = glob.glob(f'{raw_data_dir}/*/*{wav_suffix}')

for piece_path in all_wav_pieces:
item = {}
Expand All @@ -49,8 +48,8 @@ def load_meta_data(self, processed_data_dir, ds_id):
item['wav_fn'] = piece_path
item['spk_id'] = re.split('[-#]', piece_path.split('/')[-2])[0]
item['tg_fn'] = piece_path.replace(wav_suffix, tg_suffix)
item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
if len(self.processed_data_dirs) > 1:
item_name = piece_path[len(raw_data_dir) + 1:].replace('/', '-')[:-len(wav_suffix)]
if len(self.raw_data_dirs) > 1:
item_name = f'ds{ds_id}_{item_name}'
item['spk_id'] = f"ds{ds_id}_{item['spk_id']}"

Expand Down
38 changes: 28 additions & 10 deletions inference/ds_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,19 @@ def preprocess_input(self, inp, input_type='word'):
"""

item_name = inp.get('item_name', '<ITEM_NAME>')
spk_name = inp.get('spk_name', 'opencpop')

# single spk
spk_id = self.spk_map[spk_name]
if hparams['use_spk_id']:
spk_mix = inp.get('spk_mix')
if spk_mix is None:
for name in self.spk_map.keys():
spk_mix = {name: 1.0}
break
if len(spk_mix) == 1:
print(f'Using speaker \'{list(spk_mix.keys())[0]}\'')
else:
print_mix = '|'.join([f'{n}:{"%.3f" % spk_mix[n]}' for n in spk_mix])
print(f'Using speaker mix \'{print_mix}\'')
else:
spk_mix = None

# get ph seq, note lst, midi dur lst, is slur lst.
if input_type == 'word':
Expand Down Expand Up @@ -101,7 +110,7 @@ def preprocess_input(self, inp, input_type='word'):
return None

ph_token = self.ph_encoder.encode(ph_seq)
item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_mix': spk_mix,
'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
'is_slur': np.asarray(is_slur), 'ph_dur': None, 'f0_timestep': 0., 'f0_seq': None}
item['ph_len'] = len(item['ph_token'])
Expand All @@ -117,8 +126,11 @@ def input_to_batch(self, item):
ph = [item['ph']]
txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)

if hparams['use_spk_id']:
spk_mix_map = item['spk_mix']
spk_mixes = {torch.LongTensor([self.spk_map[n]]).to(self.device) : spk_mix_map[n] for n in spk_mix_map}
else:
spk_mixes = None
pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
Expand Down Expand Up @@ -151,7 +163,7 @@ def input_to_batch(self, item):
'ph': ph,
'txt_tokens': txt_tokens,
'txt_lengths': txt_lengths,
'spk_ids': spk_ids,
'spk_mixes': spk_mixes,
'pitch_midi': pitch_midi,
'midi_dur': midi_dur,
'is_slur': is_slur,
Expand All @@ -163,9 +175,15 @@ def input_to_batch(self, item):
def forward_model(self, inp, return_mel=False):
sample = self.input_to_batch(inp)
txt_tokens = sample['txt_tokens'] # [B, T_t]
spk_id = sample.get('spk_ids')
with torch.no_grad():
output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
if hparams['use_spk_id']:
spk_mixes = sample['spk_mixes']
spk_mix_embed = [self.model.fs2.spk_embed(spk_id)[:, None, :] * spk_mixes[spk_id] for spk_id in
spk_mixes]
spk_mix_embed = torch.stack(spk_mix_embed, dim=1).sum(dim=1)
else:
spk_mix_embed = None
output = self.model(txt_tokens, spk_mix_embed=spk_mix_embed, ref_mels=None, infer=True,
pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
is_slur=sample['is_slur'], mel2ph=sample['mel2ph'], f0=sample['log2f0'])
mel_out = output['mel_out'] # [B, T,80]
Expand Down
8 changes: 7 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from utils.audio import save_wav
from utils.hparams import set_hparams, hparams
from utils.slur_utils import merge_slurs
from utils.spk_utils import parse_commandline_spk_mix

sys.path.insert(0, '/')
root_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -22,6 +23,7 @@
parser = argparse.ArgumentParser(description='Run DiffSinger inference')
parser.add_argument('proj', type=str, help='Path to the input file')
parser.add_argument('--exp', type=str, required=False, help='Selection of model')
parser.add_argument('--spk', type=str, required=False, help='Speaker name or mix of speakers')
parser.add_argument('--out', type=str, required=False, help='Path of the output folder')
parser.add_argument('--title', type=str, required=False, help='Title of output file')
parser.add_argument('--num', type=int, required=False, default=1, help='Number of runs')
Expand Down Expand Up @@ -58,7 +60,6 @@
else:
print(f'| found ckpt by name: {exp}')


out = args.out
if not out:
out = os.path.dirname(os.path.abspath(args.proj))
Expand Down Expand Up @@ -106,6 +107,8 @@
warnings.filterwarnings(action='default')
infer_ins = DiffSingerE2EInfer(hparams, load_vocoder=not args.mel)

spk_mix = parse_commandline_spk_mix(args.spk) if hparams['use_spk_id'] and args.spk is not None else None


def infer_once(path: str, save_mel=False):
if save_mel:
Expand Down Expand Up @@ -145,6 +148,9 @@ def infer_once(path: str, save_mel=False):
torch.manual_seed(torch.seed() & 0xffff_ffff)
torch.cuda.manual_seed_all(torch.seed() & 0xffff_ffff)

if spk_mix is not None:
param['spk_mix'] = spk_mix

if not hparams.get('use_midi', False):
merge_slurs(param)
if save_mel:
Expand Down
12 changes: 11 additions & 1 deletion modules/naive_frontend/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, dictionary):
self.pitch_embed = Linear(1, hparams['hidden_size'])
else:
raise ValueError('f0_embed_type must be \'discrete\' or \'continuous\'.')
if hparams.get('use_spk_id', False):
self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])

def forward(self, txt_tokens, mel2ph=None, spk_embed_id=None,
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
Expand Down Expand Up @@ -73,5 +75,13 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed_id=None,
f0_mel = (1 + f0_denorm / 700).log()
pitch_embed = self.pitch_embed(f0_mel[:, :, None])

ret = {'decoder_inp': decoder_inp + pitch_embed, 'f0_denorm': f0_denorm}
if hparams['use_spk_id']:
if infer:
spk_embed = kwarg.get('spk_mix_embed')
else:
spk_embed = self.spk_embed(spk_embed_id)[:, None, :]
else:
spk_embed = 0

ret = {'decoder_inp': decoder_inp + pitch_embed + spk_embed, 'f0_denorm': f0_denorm}
return ret
Loading