
update wave2vec #110

Merged Jun 7, 2023 (38 commits)
Changes shown from 1 commit.

Commits (38):
7bc933c  add wav2vec2 (upvenly, May 16, 2023)
53c0314  add wav2vec2 (upvenly, May 16, 2023)
91e0825  support-different-image (upvenly, May 18, 2023)
4916362  update waveglow (upvenly, Jun 1, 2023)
0089955  update waveglow (upvenly, Jun 1, 2023)
887cfbb  update waveglow (upvenly, Jun 1, 2023)
dfbcb23  update waveglow (upvenly, Jun 1, 2023)
a680dc8  update waveglow (upvenly, Jun 1, 2023)
265ebac  update waveglow (upvenly, Jun 1, 2023)
44e35d1  update waveglow (upvenly, Jun 1, 2023)
885380c  update (upvenly, Jun 1, 2023)
f1f7898  update (upvenly, Jun 1, 2023)
9e8dc0c  update (upvenly, Jun 1, 2023)
f00658d  update (upvenly, Jun 1, 2023)
d53ace4  update (upvenly, Jun 1, 2023)
e80b2ce  update (upvenly, Jun 1, 2023)
d6d43c4  update (upvenly, Jun 2, 2023)
b9f10df  update (upvenly, Jun 2, 2023)
e999c7e  update according to review (upvenly, Jun 5, 2023)
569b24e  merge main (upvenly, Jun 5, 2023)
7773318  merge main (upvenly, Jun 5, 2023)
25ccffb  merge main (upvenly, Jun 5, 2023)
3c1d58b  update according to review (upvenly, Jun 5, 2023)
8f6e635  update according to review (upvenly, Jun 5, 2023)
a767546  update according to review (upvenly, Jun 5, 2023)
1eba6f0  update according to review (upvenly, Jun 5, 2023)
cfa7050  update according to review (upvenly, Jun 5, 2023)
d45f122  update according to review (upvenly, Jun 5, 2023)
e46af7b  add file (upvenly, Jun 6, 2023)
2ff46e8  Merge branch 'main' of github.com:FlagOpen/FlagPerf into wwl/support-… (upvenly, Jun 6, 2023)
2f939ad  add extern (upvenly, Jun 6, 2023)
b5e007c  add extern (upvenly, Jun 6, 2023)
4d51139  Merge branch 'main' of github.com:FlagOpen/FlagPerf into wwl/support-… (upvenly, Jun 6, 2023)
51fbcb9  update for adapter (upvenly, Jun 6, 2023)
41babcd  update wave2vec (upvenly, Jun 7, 2023)
dba2843  Merge branch 'main' of github.com:FlagOpen/FlagPerf into wwl/update_w… (upvenly, Jun 7, 2023)
dc24610  update wave2vec (upvenly, Jun 7, 2023)
8f6a942  update wave2vec (upvenly, Jun 7, 2023)
update for adapter
upvenly committed Jun 6, 2023
commit 51fbcb90eac35bcc3e4c33a71b4931b6d0d7ed31
@@ -9,6 +9,7 @@ https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition

### Dataset download URL (global proxy)
http://www.openslr.org/resources/12
Please refer to https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechRecognition/wav2vec2#quick-start-guide to download and process the dataset


### Framework and chip support
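As an illustrative aside (not part of the diff): the README hunk above points to the openslr mirror and NVIDIA's quick-start guide for dataset preparation. Below is a minimal sketch of fetching a single LibriSpeech archive from that mirror; the archive name and target directory are assumptions, and the actual download and preprocessing should follow the linked quick-start guide.

```python
# Illustrative only: fetch one LibriSpeech archive from the openslr mirror
# listed above. Archive name and output directory are assumptions; the full
# dataset preparation should follow NVIDIA's quick-start guide.
import urllib.request
from pathlib import Path

BASE_URL = "http://www.openslr.org/resources/12"
ARCHIVE = "dev-clean.tar.gz"            # assumed subset; training sets are far larger
OUT_DIR = Path("datasets/librispeech")  # assumed layout

OUT_DIR.mkdir(parents=True, exist_ok=True)
target = OUT_DIR / ARCHIVE
if not target.exists():
    print(f"downloading {BASE_URL}/{ARCHIVE} ...")
    urllib.request.urlretrieve(f"{BASE_URL}/{ARCHIVE}", str(target))
print(f"archive available at {target}")
```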
@@ -14,7 +14,7 @@
encoder_attention_heads = 12
feature_grad_mult = 0.1
ema = 0.0
optimizer = "adam"
optimizer = "fused_adam"
clip_norm = 25
weight_decay = 0.01
lr_policy = "poly"
17 changes: 0 additions & 17 deletions training/benchmarks/wav2vec2/pytorch/optimizer/__init__.py
@@ -1,17 +0,0 @@

from common.fairseq.optim.fused_adam import get_fused_adam_class
from common.utils import print_once

def create_optimizer(model, args):

kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'adam' and not (args.fp16 or args.bf16):
print_once('WARNING: Using FusedAdam instead of Adam')
kw.update({'betas': args.adam_betas, 'eps': args.adam_eps})
fused_adam_cls = get_fused_adam_class()
print(fused_adam_cls, "fused_adam_cls")
optimizer = fused_adam_cls(model.parameters(), **kw)
else:
raise ValueError(f'Invalid optimizer "{args.optimizer}"')

return optimizer
7 changes: 3 additions & 4 deletions training/benchmarks/wav2vec2/pytorch/train/trainer.py
@@ -22,7 +22,6 @@
from model import create_model
from train.evaluator import Evaluator
from train.training_state import TrainingState
from optimizer import create_optimizer
from loss.criterion import Wav2vecCriterion


@@ -50,7 +49,7 @@ def __init__(self, driver: Driver, adapter, evaluator: Evaluator,
def init(self):
self.model = create_model(self.config)
self.model = self.init_model(self.model, self.device)
self.optimizer = create_optimizer(self.model, self.config)
self.optimizer = self.adapter.create_optimizer(self.model, self.config)
self.optim = self.optimizer

Metrics = W2v2Metrics
@@ -153,7 +152,7 @@ def train_one_epoch(self, config, epoch, step, train_dataloader, sampler):
grads_mult_factor = world_size / self.metrics.partials[
'sample_size']

if self.config.optimizer == 'adam' and not (self.config.fp16
if self.config.optimizer == 'fused_adam' and not (self.config.fp16
or self.config.bf16):
# adam and non-amp optimizer - can use 'scale' kwarg for step
# and defer grad multiplication
@@ -168,7 +167,7 @@ def train_one_epoch(self, config, epoch, step, train_dataloader, sampler):
# calculate grad norm, maybe clip
grad_norm = self.optim.clip_grad_norm(self.config.clip_norm)

if self.config.optimizer == 'adam' and not (self.config.fp16
if self.config.optimizer == 'fused_adam' and not (self.config.fp16
or self.config.bf16):
self.scaler.step(self.optim,
scale=1. / grads_mult_factor)
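For context, the trainer now reaches the optimizer through its adapter instead of a hard-coded optimizer package. The following is a minimal sketch of that indirection, assuming the adapter is simply an object or module exposing create_optimizer(model, config); how FlagPerf actually discovers vendor adapters is not shown in this diff, so load_adapter below is hypothetical.

```python
# Minimal sketch of the adapter indirection, not FlagPerf's actual loader.
# Assumption: the "adapter" handed to the trainer is any object/module that
# exposes create_optimizer(model, config).
import importlib


def load_adapter(vendor: str):
    """Hypothetical convention: each vendor ships a trainer_adapter module."""
    return importlib.import_module(f"{vendor}.trainer_adapter")


class Trainer:
    def __init__(self, adapter, model, config):
        self.adapter = adapter
        self.model = model
        self.config = config
        self.optimizer = None

    def init(self):
        # Mirrors the change above: delegate optimizer construction to the
        # adapter instead of importing a fixed create_optimizer.
        self.optimizer = self.adapter.create_optimizer(self.model, self.config)
```

The design choice this PR makes is that vendor-specific details (here, FusedAdam) live in trainer_adapter.py, while the generic trainer only depends on the adapter interface.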
21 changes: 18 additions & 3 deletions training/benchmarks/wav2vec2/pytorch/train/trainer_adapter.py
@@ -1,10 +1,11 @@
import torch.distributed as dist
import config

from torch import nn, Tensor
from driver.dist_pytorch import main_proc_print
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


from common.fairseq.optim.fused_adam import get_fused_adam_class
from driver.dist_pytorch import main_proc_print


def convert_model(model: nn.Module) -> nn.Module:
@@ -27,3 +28,17 @@ def model_to_ddp(model: nn.Module) -> nn.Module:
from common.fairseq.dist import ModuleProxyWrapper
model = ModuleProxyWrapper(model)
return model


def create_optimizer(model, args):

kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'fused_adam' and not (args.fp16 or args.bf16):
kw.update({'betas': args.adam_betas, 'eps': args.adam_eps})
fused_adam_cls = get_fused_adam_class()
print(fused_adam_cls, "fused_adam_cls")
optimizer = fused_adam_cls(model.parameters(), **kw)
else:
raise ValueError(f'Invalid optimizer "{args.optimizer}"')

return optimizer
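To close the loop, a hedged usage sketch of the create_optimizer added above. types.SimpleNamespace stands in for FlagPerf's config object, the import path and most hyperparameter values are illustrative assumptions, and FusedAdam requires NVIDIA apex (get_fused_adam_class returns None without it).

```python
# Illustrative usage of the adapter's create_optimizer; not from the PR.
# Assumptions: the module is importable as train.trainer_adapter, and the
# hyperparameter values below (except weight_decay, which matches the config
# snippet earlier in the diff) are placeholders, not FlagPerf defaults.
from types import SimpleNamespace

import torch
from train import trainer_adapter

config = SimpleNamespace(
    optimizer="fused_adam",   # must match the branch in create_optimizer
    lr=5e-4,
    weight_decay=0.01,
    adam_betas=(0.9, 0.98),
    adam_eps=1e-6,
    fp16=False,
    bf16=False,
)

model = torch.nn.Linear(16, 4)
# Raises if apex's FusedAdam is unavailable (get_fused_adam_class -> None)
# or if config.optimizer is anything other than "fused_adam".
optimizer = trainer_adapter.create_optimizer(model, config)
```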