support finetuning from pretrained checkpoints (#108)
* Support loading pretrained models for finetuning and add documentation

* Update docs for finetuning

---------

Co-authored-by: autumn <2>
Co-authored-by: yqzhishen <yangqian_1015@icloud.com>
autumn-2-net and yqzhishen committed Jul 17, 2023
1 parent ba7de62 commit 7847af2
Showing 7 changed files with 210 additions and 8 deletions.
78 changes: 72 additions & 6 deletions basics/base_task.py
@@ -87,11 +87,70 @@ def __init__(self, *args, **kwargs):
    def setup(self, stage):
        self.phone_encoder = self.build_phone_encoder()
        self.model = self.build_model()
        # Load weights from the pretrained checkpoint only when finetuning is enabled
        # and no training checkpoint exists in the work dir yet.
        if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
            self.load_finetune_ckpt(self.load_pre_train_model())
        self.print_arch()
        self.build_losses()
        self.train_dataset = self.dataset_cls(hparams['train_set_name'])
        self.valid_dataset = self.dataset_cls(hparams['valid_set_name'])

    def load_finetune_ckpt(self, state_dict):
        adapt_shapes = hparams['finetune_strict_shapes']
        if not adapt_shapes:
            # Non-strict mode: drop pretrained parameters whose shapes do not
            # match the current model instead of raising an error on load.
            cur_model_state_dict = self.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print('| Unmatched keys: ', key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        self.load_state_dict(state_dict, strict=False)

    def load_pre_train_model(self):
        pre_train_ckpt_path = hparams.get('finetune_ckpt_path')
        blacklist = hparams.get('finetune_ignored_params')
        if blacklist is None:
            blacklist = []

        if pre_train_ckpt_path is not None:
            ckpt = torch.load(pre_train_ckpt_path)
            if isinstance(self.model, CategorizedModule):
                self.model.check_category(ckpt.get('category'))

            state_dict = {}
            for i in ckpt['state_dict']:
                # Skip parameters whose key names start with any ignored prefix.
                skip = False
                for b in blacklist:
                    if i.startswith(b):
                        skip = True
                        break

                if skip:
                    continue

                state_dict[i] = ckpt['state_dict'][i]
                print(i)  # log the parameter keys carried over from the pretrained checkpoint
            return state_dict
        else:
            raise RuntimeError('finetune_ckpt_path is not specified; cannot load the pretrained model for finetuning.')

    @staticmethod
    def build_phone_encoder():
        phone_list = build_phoneme_list()
@@ -292,6 +351,11 @@ def on_test_end(self):
def start(cls):
pl.seed_everything(hparams['seed'], workers=True)
task = cls()

work_dir = pathlib.Path(hparams['work_dir'])
trainer = pl.Trainer(
accelerator=hparams['pl_trainer_accelerator'],
@@ -379,16 +443,16 @@ def on_load_checkpoint(self, checkpoint):
from utils import simulate_lr_scheduler
if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
self.skip_immediate_validation = True

optimizer_args = hparams['optimizer_args']
scheduler_args = hparams['lr_scheduler_args']

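# Combine separate 'beta1'/'beta2' entries into the 'betas' tuple expected by Adam-style optimizers.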
if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

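# Override optimizer hyperparameters stored in the checkpoint with the values currently set in hparams.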
if checkpoint.get('optimizer_states', None):
opt_states = checkpoint['optimizer_states']
assert len(opt_states) == 1 # only support one optimizer
assert len(opt_states) == 1 # only support one optimizer
opt_state = opt_states[0]
for param_group in opt_state['param_groups']:
for k, v in optimizer_args.items():
@@ -398,13 +462,14 @@ def on_load_checkpoint(self, checkpoint):
rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
param_group[k] = v
if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
rank_zero_info(f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
rank_zero_info(
f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
param_group['initial_lr'] = optimizer_args['lr']

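# Likewise override LR scheduler parameters and re-simulate the schedule so the restored learning rates match the new settings.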
if checkpoint.get('lr_schedulers', None):
assert checkpoint.get('optimizer_states', False)
schedulers = checkpoint['lr_schedulers']
assert len(schedulers) == 1 # only support one scheduler
assert len(schedulers) == 1 # only support one scheduler
scheduler = schedulers[0]
for k, v in scheduler_args.items():
if k in scheduler and scheduler[k] != v:
@@ -419,5 +484,6 @@ def on_load_checkpoint(self, checkpoint):
scheduler['_last_lr'] = new_lrs
for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs):
if param_group['lr'] != new_lr:
rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
rank_zero_info(
f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
param_group['lr'] = new_lr
10 changes: 10 additions & 0 deletions configs/acoustic.yaml
@@ -92,3 +92,13 @@ max_updates: 320000
num_ckpt_keep: 5
permanent_ckpt_start: 200000
permanent_ckpt_interval: 40000


finetune_enabled: false
finetune_ckpt_path: null

finetune_ignored_params:
- model.fs2.encoder.embed_tokens
- model.fs2.txt_embed
- model.fs2.spk_embed
finetune_strict_shapes: true
11 changes: 11 additions & 0 deletions configs/base.yaml
@@ -89,3 +89,14 @@ pl_trainer_precision: '32-true'
pl_trainer_num_nodes: 1
pl_trainer_strategy: 'auto'
ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p'

###########
# finetune
###########

finetune_enabled: false
finetune_ckpt_path: null
finetune_ignored_params: []
finetune_strict_shapes: true
8 changes: 8 additions & 0 deletions configs/variance.yaml
@@ -103,3 +103,11 @@ max_updates: 288000
num_ckpt_keep: 5
permanent_ckpt_start: 180000
permanent_ckpt_interval: 10000

finetune_enabled: false
finetune_ckpt_path: null
finetune_ignored_params:
- model.spk_embed
- model.fs2.txt_embed
- model.fs2.encoder.embed_tokens
finetune_strict_shapes: true
93 changes: 93 additions & 0 deletions docs/ConfigurationSchemas.md
@@ -1306,6 +1306,98 @@ int

2048

### finetune_enabled

Whether to finetune from a pretrained model.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

bool

#### default

False

### finetune_ckpt_path

Path to the pretrained model for finetuning.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

str

#### default

null

### finetune_ignored_params

Prefixes of parameter key names in the state dict of the pretrained model that need to be dropped before finetuning.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

list
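
For illustration, a minimal sketch (not part of the codebase) of how the prefixes are matched against state dict keys — the encoder key below is hypothetical, and the actual filtering happens in `load_pre_train_model` in `basics/base_task.py`:

```python
ignored = ['model.fs2.txt_embed', 'model.fs2.spk_embed']
pretrained_state = {
    'model.fs2.txt_embed.weight': ...,          # dropped: key starts with an ignored prefix
    'model.fs2.spk_embed.weight': ...,          # dropped
    'model.fs2.encoder.layers.0.weight': ...,   # kept (hypothetical key name)
}
kept = {
    k: v for k, v in pretrained_state.items()
    if not any(k.startswith(prefix) for prefix in ignored)
}
print(sorted(kept))  # ['model.fs2.encoder.layers.0.weight']
```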

### finetune_strict_shapes

Whether to raise an error if the tensor shape of any parameter mismatches between the pretrained model and the target model. If set to `False`, parameters with mismatching shapes are skipped.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

bool

#### default

True
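
For a rough picture of the difference, here is a minimal sketch (a standalone helper with toy tensor shapes, not the actual implementation) mirroring the shape check performed in `load_finetune_ckpt` in `basics/base_task.py`:

```python
import torch

def filter_by_shape(pretrained: dict, target: dict, strict: bool) -> dict:
    # In strict mode, return the state dict unchanged; any shape mismatch will
    # then surface as an error when the state dict is loaded into the model.
    if strict:
        return pretrained
    # Otherwise drop pretrained tensors whose shapes differ from the target model's.
    return {
        k: v for k, v in pretrained.items()
        if k not in target or target[k].shape == v.shape
    }

# Hypothetical example: an embedding resized after the phoneme set changed.
pretrained = {'embed.weight': torch.zeros(60, 256), 'proj.weight': torch.zeros(256, 256)}
target = {'embed.weight': torch.zeros(62, 256), 'proj.weight': torch.zeros(256, 256)}
print(list(filter_by_shape(pretrained, target, strict=False)))  # ['proj.weight']
```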

### fmax

Maximum frequency of mel extraction.
@@ -3324,3 +3416,4 @@ int

2048


2 changes: 2 additions & 0 deletions scripts/train.py
@@ -1,5 +1,6 @@
import importlib
import os

import sys
from pathlib import Path

@@ -22,6 +23,7 @@ def run_task():
pkg = ".".join(hparams["task_cls"].split(".")[:-1])
cls_name = hparams["task_cls"].split(".")[-1]
task_cls = getattr(importlib.import_module(pkg), cls_name)

task_cls.start()


16 changes: 14 additions & 2 deletions utils/__init__.py
@@ -11,6 +11,8 @@
import torch.nn.functional as F

from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import get_latest_checkpoint_path


def tensors_to_scalars(metrics):
@@ -149,7 +151,8 @@ def filter_kwargs(dict_to_filter, kwarg_obj):

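# Keep only the keyword arguments that kwarg_obj's signature actually accepts.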
sig = inspect.signature(kwarg_obj)
filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if filter_key in dict_to_filter}
filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if
filter_key in dict_to_filter}
return filtered_dict


@@ -208,6 +211,14 @@ def load_ckpt(
print(f'| load {shown_model_name} from \'{checkpoint_path}\'.')


def remove_padding(x, padding_idx=0):
if x is None:
return None
@@ -265,7 +276,8 @@ def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_par
[{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
**optimizer_args
)
scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, **scheduler_args)
scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch,
**scheduler_args)

if hasattr(scheduler, '_get_closed_form_lr'):
return scheduler._get_closed_form_lr()
