Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support different image #93

Merged
merged 11 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
upvenly committed Jun 1, 2023
commit f00658dcd722aa832e185cd892c520309b1e76c8
19 changes: 19 additions & 0 deletions training/benchmarks/wav2vec2/pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
### 模型信息

This repository provides an optimized implementation of the wav2vec 2.0 model, as described in the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://proceedings.neurips.cc/paper/2020/file/92d1e1eb1cd6f9fba3227870bb6d7f07-Paper.pdf). It is based on the [Fairseq codebase](https://github.com/facebookresearch/fairseq) published by the authors of the paper. The wav2vec 2.0 model is pre-trained unsupervised on large corpora of speech recordings. Afterward, it can be quickly fine-tuned in a supervised way for speech recognition or serve as an extractor of high-level features and pseudo-phonemes for other applications.

### 代码来源

https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/wav2vec2/


### 数据集下载地址(global proxy)
http://www.openslr.org/resources/12


### 框架与芯片支持情况
| | Pytorch |
| ---------- | ------- |
| Nvidia GPU | ✅ |
| 昆仑芯 XPU | N/A |
| 天数智芯 | N/A |
Empty file.
156 changes: 156 additions & 0 deletions training/benchmarks/wav2vec2/pytorch/common/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import numpy as np
from torch.utils.data import DataLoader

from common.fairseq.data import data_utils
from common.helpers import print_once
from common.sampler import DistributedIndicesSampler


def adjust_max_tokens(train_dataset, world_size, args):

def get_steps_per_epoch(world_size, max_tokens, update_freq):
train_loader, sampler = get_batch_iterator(
train_dataset,
True,
max_tokens=max_tokens,
max_sentences=args.batch_size,
max_positions=(max_tokens, max_tokens),
ignore_invalid_inputs=True,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=world_size,
shard_id=0,
num_workers=args.num_workers)

steps_per_epoch = len(train_loader) // update_freq
return steps_per_epoch

steps_ref = get_steps_per_epoch(args.ref_world_size, args.ref_max_tokens,
1)

min_ = args.ref_max_tokens // 20
max_ = args.ref_max_tokens * 20

prev_max_tokens = 0
align_to = 1000
while min_ < max_:
max_tokens = (max_ + min_) // 2 // align_to * align_to # try to round
if max_tokens == prev_max_tokens:
break
prev_max_tokens = max_tokens
steps = get_steps_per_epoch(world_size, max_tokens, args.update_freq)
print_once(f"max_tokens={max_tokens} yields {steps} steps "
f"(adjusting for {steps_ref}).")
if steps == steps_ref:
break
elif steps > steps_ref:
min_ = max_tokens
else:
max_ = max_tokens

args.max_tokens = max_tokens
args.max_tokens_valid = max_tokens


def filter_indices_by_size(indices,
dataset,
max_positions=None,
ignore_invalid_inputs=False):
"""
Filter examples that are too large

Args:
indices (np.array): original array of sample indices
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
Returns:
np.array: array of filtered sample indices
"""
indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
# TODO: consider removing this function. If `len(ignored) > 0`,
# an error is raised in fairseq dataset code, both in sup and unsup case
if len(ignored) > 0:
if not ignore_invalid_inputs:
raise Exception(
("Size of sample #{} is invalid (={}) since max_positions={}, "
"skip this example with --skip-invalid-size-inputs-valid-test"
).format(ignored[0], dataset.size(ignored[0]), max_positions))
print(("WARNING: {:,} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}").format(
len(ignored), max_positions, ignored[:10]))
return indices


def get_batch_iterator(
dataset,
training,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
num_concat_batches=1,
):
# get indices ordered by example size
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
# filter examples that are too large
if max_positions is not None:
indices = filter_indices_by_size(indices, dataset, max_positions,
ignore_invalid_inputs)

# create mini-batches with given size constraints
batch_inds, non_grouped_batch_inds = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
num_concat_batches=num_concat_batches,
)

batch_ids = copy.deepcopy(non_grouped_batch_inds)
[bi.fill(i) for i, bi in enumerate(batch_ids)]
inds_ids = zip(np.concatenate(batch_inds), np.concatenate(batch_ids))
dataset.batch_ids = {idx: batch_idx for idx, batch_idx in inds_ids}

# Batches are already specified, now we just need to shuffle them
batch_ind_sampler = DistributedIndicesSampler(batch_inds,
shuffle=training,
num_replicas=num_shards,
rank=shard_id,
seed=seed,
drop_last=training,
fillvalue=[])
print("DataLoaderDataLoaderDataLoaderDataLoader")

loader = DataLoader(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_ind_sampler,
num_workers=num_workers,
pin_memory=True,
persistent_workers=num_workers > 0,
)
return loader, batch_ind_sampler
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""isort:skip_file"""

from .add_target_dataset import AddTargetDataset, BaseWrapperDataset
from .audio.raw_audio_dataset import FileAudioDataset

__all__ = [
"AddTargetDataset",
"FileAudioDataset",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import data_utils


class BaseWrapperDataset(torch.utils.data.Dataset):

def __init__(self, dataset):
super().__init__()
self.dataset = dataset

def __getitem__(self, index):
return self.dataset[index]

def __len__(self):
return len(self.dataset)

@property
def sizes(self):
return self.dataset.sizes

def num_tokens(self, index):
return self.dataset.num_tokens(index)

def size(self, index):
return self.dataset.size(index)

def ordered_indices(self):
return self.dataset.ordered_indices()

def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
num_concat_batches=1,
):
return self.dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
num_concat_batches=num_concat_batches,
)

def filter_indices_by_size(self, indices, max_sizes):
return self.dataset.filter_indices_by_size(indices, max_sizes)


class AddTargetDataset(BaseWrapperDataset):

def __init__(
self,
dataset,
labels,
pad,
eos,
batch_targets,
process_label=None,
add_to_input=False,
):
super().__init__(dataset)
self.labels = labels
self.batch_targets = batch_targets
self.pad = pad
self.eos = eos
self.process_label = process_label
self.add_to_input = add_to_input

def get_label(self, index):
return (self.labels[index] if self.process_label is None else
self.process_label(self.labels[index]))

def __getitem__(self, index):
item = self.dataset[index]
item["label"] = self.get_label(index)
return item

def size(self, index):
sz = self.dataset.size(index)
own_sz = len(self.get_label(index))
return (sz, own_sz)

def collater(self, samples):
collated = self.dataset.collater(samples)
if len(collated) == 0:
return collated
indices = set(collated["id"].tolist())
target = [s["label"] for s in samples if s["id"] in indices]

if self.batch_targets:
collated["target_lengths"] = torch.LongTensor(
[len(t) for t in target])
target = data_utils.collate_tokens(target,
pad_idx=self.pad,
left_pad=False)
collated["ntokens"] = collated["target_lengths"].sum().item()
else:
collated["ntokens"] = sum([len(t) for t in target])

collated["target"] = target

if self.add_to_input:
eos = target.new_full((target.size(0), 1), self.eos)
collated["target"] = torch.cat([target, eos], dim=-1).long()
collated["net_input"]["prev_output_tokens"] = torch.cat(
[eos, target], dim=-1).long()
collated["ntokens"] += target.size(0)
return collated

def __setattr__(self, attr, val):
if attr == "batch_ids":
self.dataset.batch_ids = val
else:
super().__setattr__(attr, val)
Loading