Merge remote-tracking branch 'origin/master' into ds-init-config
stas00 committed Sep 16, 2021
2 parents 357fdcb + 4d5b4c7 commit c2ec02b
Showing 14 changed files with 165 additions and 65 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/self-nightly-scheduled.yml
@@ -159,9 +159,9 @@ jobs:
run: |
apt -y update && apt install -y libaio-dev
pip install --upgrade pip
pip install .[testing,deepspeed]
pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U
pip install .[testing,deepspeed]
pip install git+https://github.com/microsoft/DeepSpeed
- name: Are GPUs recognized by our DL frameworks
run: |
@@ -203,7 +203,9 @@ jobs:
run: |
apt -y update && apt install -y libaio-dev
pip install --upgrade pip
pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U
pip install .[testing,deepspeed,fairscale]
pip install git+https://github.com/microsoft/DeepSpeed
- name: Are GPUs recognized by our DL frameworks
run: |
@@ -215,8 +217,7 @@ jobs:
- name: Run all tests on GPU
run: |
python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_cuda_extensions_multi_gpu tests/deepspeed tests/extended
pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U
- name: Failure short reports
if: ${{ always() }}
run: cat reports/tests_torch_cuda_extensions_multi_gpu_failures_short.txt
33 changes: 24 additions & 9 deletions src/transformers/data/data_collator.py
@@ -291,11 +291,11 @@ def torch_call(self, features):
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch[label_name] = [
label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch[label_name] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
]

batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
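As a side note on the list() casts above, here is a minimal sketch (plain Python, hypothetical values) of the failure they guard against — datasets can hand labels back as tuples or numpy arrays rather than lists:

label_pad_token_id = -100
sequence_length = 6

for label in [[7, 8, 9], (7, 8, 9)]:
    # list(label) normalizes the container so list concatenation is safe
    padded = list(label) + [label_pad_token_id] * (sequence_length - len(label))
    print(padded)  # [7, 8, 9, -100, -100, -100] either way

# without the cast, the tuple case raises:
# (7, 8, 9) + [-100] * 3  ->  TypeError: can only concatenate tuple (not "list") to tuple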
@@ -321,9 +321,13 @@ def tf_call(self, features):
sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
batch["labels"] = [
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
batch["labels"] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
]

batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
return batch
@@ -348,9 +352,13 @@ def numpy_call(self, features):
sequence_length = np.array(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
batch["labels"] = [
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
batch["labels"] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
]

batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
return batch
@@ -517,6 +525,8 @@ class DataCollatorForSeq2Seq:
return_tensors: str = "pt"

def __call__(self, features, return_tensors=None):
import numpy as np

if return_tensors is None:
return_tensors = self.return_tensors
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
@@ -527,9 +537,14 @@ def __call__(self, features, return_tensors=None):
padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
if isinstance(feature["labels"], list):
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
elif padding_side == "right":
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)

features = self.tokenizer.pad(
features,
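A minimal sketch (assuming numpy-array labels, hypothetical values) of the new ndarray branch above: list concatenation cannot pad an array, so the collator switches to np.concatenate and pins the dtype to int64:

import numpy as np

label_pad_token_id = -100
max_label_length = 6
labels = np.array([7, 8, 9])

remainder = [label_pad_token_id] * (max_label_length - len(labels))
padded = np.concatenate([labels, remainder]).astype(np.int64)  # right padding
print(padded)  # [   7    8    9 -100 -100 -100]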
11 changes: 5 additions & 6 deletions src/transformers/feature_extraction_sequence_utils.py
@@ -211,16 +211,17 @@ def pad(
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
# truncation
inputs = self._truncate(
inputs_slice = self._truncate(
inputs,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
truncation=truncation,
)
truncated_inputs.append(inputs)
truncated_inputs.append(inputs_slice)

if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
# make sure that `max_length` cannot be longer than the longest truncated length
max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
padding_strategy = PaddingStrategy.MAX_LENGTH

batch_outputs = {}
@@ -322,9 +323,7 @@ def _truncate(
if not truncation:
return processed_features
elif truncation and max_length is None:
raise ValueError(
"When setting ``truncation=True``, make sure that ``max_length`` is defined and ``padding='max_length'``"
)
raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")

required_input = processed_features[self.model_input_names[0]]

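A toy sketch (hypothetical feature dicts) of the pad() fix above: with padding="longest", the target length must come from the already-truncated sequences, otherwise truncation=True with a short max_length would be undone by padding back out to the raw longest length:

truncated_inputs = [
    {"input_values": [0.1, 0.2, 0.3]},       # truncated to 3 frames
    {"input_values": [0.4, 0.5, 0.6, 0.7]},  # truncated to 4 frames
]
model_input_name = "input_values"  # stand-in for self.model_input_names[0]

# mirrors the corrected max_length computation
max_length = max(len(s[model_input_name]) for s in truncated_inputs)
print(max_length)  # 4 -- the longest *truncated* length, not the original one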
49 changes: 19 additions & 30 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -73,6 +73,17 @@


def create_sinusoidal_embeddings(n_pos, dim, out):
if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
if torch.distributed.get_rank() == 0:
_create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
else:
_create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)


def _create_sinusoidal_embeddings(n_pos, dim, out):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
out.requires_grad = False
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
@@ -86,19 +97,9 @@ def __init__(self, config):
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
if config.sinusoidal_pos_embds:

if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
create_sinusoidal_embeddings(
n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
)
else:
create_sinusoidal_embeddings(
n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
)
create_sinusoidal_embeddings(
n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
)

self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
self.dropout = nn.Dropout(config.dropout)
@@ -475,23 +476,9 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

if self.config.sinusoidal_pos_embds:

if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings,
dim=self.config.dim,
out=self.embeddings.position_embeddings.weight,
)
else:
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings,
dim=self.config.dim,
out=self.embeddings.position_embeddings.weight,
)
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.embeddings.position_embeddings.weight
)
else:
with torch.no_grad():
if num_position_embeds_diff > 0:
@@ -502,6 +489,8 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
self.embeddings.position_embeddings.weight = nn.Parameter(
old_position_embeddings_weight[:num_position_embeds_diff]
)
# move position_embeddings to correct device
self.embeddings.position_embeddings.to(self.device)

def get_input_embeddings(self):
return self.embeddings.word_embeddings
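For reference, a standalone numpy sketch (hypothetical sizes) of the table that _create_sinusoidal_embeddings builds — even columns take sin, odd columns take cos, with the standard 10000 ** (2 * (j // 2) / dim) scaling:

import numpy as np

n_pos, dim = 4, 6
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
table = np.zeros((n_pos, dim))
table[:, 0::2] = np.sin(position_enc[:, 0::2])  # even columns
table[:, 1::2] = np.cos(position_enc[:, 1::2])  # odd columns
print(table.shape)  # (4, 6)

Under DeepSpeed ZeRO-3 the weight is sharded across ranks, which is why the refactored wrapper gathers it first and lets only rank 0 write the values.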
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -109,6 +109,10 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)

if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)

return spec_aug_mask


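A minimal PyTorch sketch (hypothetical shapes) of the new guard above: wherever attention_mask is 0, the SpecAugment mask is forced off so padded positions can never be selected for masking:

import torch

spec_aug_mask = torch.tensor([[True, True, False, True]])
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

masked = torch.where(attention_mask.bool(), spec_aug_mask, torch.tensor(False))
print(masked)  # tensor([[ True,  True, False, False]])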
2 changes: 1 addition & 1 deletion src/transformers/models/hubert/modeling_tf_hubert.py
@@ -258,7 +258,7 @@ def _compute_mask_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
)

return tf.cast(spec_aug_mask, tf.float32)
return spec_aug_mask


# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
2 changes: 2 additions & 0 deletions src/transformers/models/pegasus/modeling_pegasus.py
@@ -668,6 +668,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
self.config.d_model,
self.padding_idx,
)
self.embed_positions.to(self.device)

def get_position_embeddings(self) -> nn.Embedding:
"""
@@ -886,6 +887,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
self.config.d_model,
self.padding_idx,
)
self.embed_positions.to(self.device)

def get_position_embeddings(self) -> nn.Embedding:
"""
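A minimal sketch of why the explicit .to(self.device) calls matter: a freshly constructed nn.Embedding is allocated on CPU, so when the surrounding model lives on a GPU the resized position embeddings must be moved before the next forward pass:

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed_positions = nn.Embedding(128, 16)       # new table, created on CPU
embed_positions = embed_positions.to(device)  # align with the model's device
print(embed_positions.weight.device)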
4 changes: 2 additions & 2 deletions src/transformers/models/roberta/modeling_tf_roberta.py
@@ -508,9 +508,9 @@ def call(
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(tensor=inputs["input_ids"])
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(tensor=inputs["inputs_embeds"])[:-1]
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

@@ -110,7 +110,7 @@ def utterance_cmvn(
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)

if x.shape[0] > input_length:
if input_length < x.shape[0]:
x[input_length:] = padding_value

# make sure array is in float32
@@ -91,7 +91,7 @@ def zero_mean_unit_var_norm(

for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length > normed_slice.shape[0]:
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value

normed_input_values.append(normed_slice)
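A minimal numpy sketch (hypothetical values) of the corrected comparison above: padding positions exist only when the valid length is *shorter* than the slice, so the check must be length < normed_slice.shape[0], not the reverse:

import numpy as np

padding_value = 0.0
vector = np.array([1.0, 3.0, 5.0, 9.9, 9.9])  # last two entries are padding
length = 3                                    # e.g. from attention_mask.sum(-1)

normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length < normed_slice.shape[0]:
    normed_slice[length:] = padding_value
print(normed_slice)  # valid part is zero-mean/unit-variance, padding reset to 0.0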
70 changes: 63 additions & 7 deletions tests/test_feature_extraction_speech_to_text.py
@@ -137,11 +137,15 @@ def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 16, None]
var_tolerances = [1e-3, 1e-3, 5e-1]
# TODO(Patrick, Suraj, Anton) - It's surprising that "non-padded/non-numpified" padding
# results in quite inaccurate variance computation afterwards (see the 5e-1 tolerance)
# Issue is filed and PR is underway: https://github.com/huggingface/transformers/issues/13539
# paddings = ["longest", "max_length", "do_not_pad"]
# max_lengths = [None, 16, None]
# var_tolerances = [1e-3, 1e-3, 5e-1]
paddings = ["longest", "max_length"]
max_lengths = [None, 16]
var_tolerances = [1e-3, 1e-3]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):

inputs = feature_extractor(
@@ -163,11 +167,15 @@ def test_cepstral_mean_and_variance_normalization_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 16, None]
var_tolerances = [1e-3, 1e-3, 5e-1]
# TODO(Patrick, Suraj, Anton) - It's surprising that "non-padded/non-numpified" padding
# results in quite inaccurate variance computation afterwards (see the 5e-1 tolerance)
# Issue is filed and PR is underway: https://github.com/huggingface/transformers/issues/13539
# paddings = ["longest", "max_length", "do_not_pad"]
# max_lengths = [None, 16, None]
# var_tolerances = [1e-3, 1e-3, 5e-1]
paddings = ["longest", "max_length"]
max_lengths = [None, 16]
var_tolerances = [1e-3, 1e-3]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
inputs = feature_extractor(
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
@@ -181,10 +189,12 @@ def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))

_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
self.assertTrue(input_features[0][fbank_feat_lengths[0] :].sum() < 1e-6)
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
self.assertTrue(input_features[0][fbank_feat_lengths[1] :].sum() < 1e-6)
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)

def test_cepstral_mean_and_variance_normalization_trunc(self):
def test_cepstral_mean_and_variance_normalization_trunc_max_length(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
@@ -206,3 +216,49 @@ def _check_zero_mean_unit_variance(input_vector):
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1])
_check_zero_mean_unit_variance(input_features[2])

def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
speech_inputs,
padding="longest",
max_length=4,
truncation=True,
return_tensors="np",
return_attention_mask=True,
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)

def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))

_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
_check_zero_mean_unit_variance(input_features[2])

# make sure that if max_length < longest -> then pad to max_length
self.assertEqual(input_features.shape, (3, 4, 24))

speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
speech_inputs,
padding="longest",
max_length=16,
truncation=True,
return_tensors="np",
return_attention_mask=True,
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)

_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
_check_zero_mean_unit_variance(input_features[2])

# make sure that if max_length < longest -> then pad to max_length
self.assertEqual(input_features.shape, (3, 6, 24))