Fix XLA and AMP (huggingface#10262)
jplu authored Feb 19, 2021
1 parent 3d72d47 commit 86caeb7
Showing 2 changed files with 28 additions and 34 deletions.
46 changes: 28 additions & 18 deletions src/transformers/models/t5/modeling_tf_t5.py
@@ -169,14 +169,18 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
         self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
-        if self.has_relative_attention_bias:
-            self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                name="relative_attention_bias",
-            )
         self.pruned_heads = set()
 
+    def build(self, input_shape):
+        if self.has_relative_attention_bias:
+            with tf.name_scope("relative_attention_bias"):
+                self.relative_attention_bias = self.add_weight(
+                    name="embeddings",
+                    shape=[self.relative_attention_num_buckets, self.n_heads],
+                )
+
+        return super().build(input_shape)
+
     def prune_heads(self, heads):
         raise NotImplementedError
 
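The hunk above moves the relative-attention bias out of a tf.keras.layers.Embedding sub-layer and into a raw weight created in build(); the compute_bias hunk further down then reads it with tf.gather. A minimal standalone sketch of that pattern (the class name BucketBias and its default sizes are invented for illustration, not taken from the repo):

import tensorflow as tf


class BucketBias(tf.keras.layers.Layer):
    """Illustrative layer that defers weight creation to build(), like the hunk above."""

    def __init__(self, num_buckets=32, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        # Creating the table under an explicit name scope keeps the checkpoint
        # name "relative_attention_bias/embeddings" that the old Embedding layer produced.
        with tf.name_scope("relative_attention_bias"):
            self.bias_table = self.add_weight(
                name="embeddings",
                shape=[self.num_buckets, self.n_heads],
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # A plain gather replaces the Embedding lookup (see the compute_bias hunk below).
        return tf.gather(self.bias_table, bucket_ids)


layer = BucketBias()
print(layer(tf.constant([[0, 1], [2, 3]])).shape)  # (2, 2, 8)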
@@ -206,18 +210,20 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_buckets += (
+                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
+            )
             relative_position = tf.math.abs(relative_position)
         else:
             relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
+        relative_position_if_large = max_exact + tf.cast(
+            tf.math.log(relative_position / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
-            tf.int32,
+            dtype=relative_position.dtype,
         )
         relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
         relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
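For reference, the bucketing that this hunk now keeps in the input's dtype can be restated in plain Python. The helper below is an illustrative re-statement of the formula with T5's defaults (num_buckets=32, max_distance=128), not code from the file:

import math


def bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Plain-Python restatement of _relative_position_bucket for a single offset.
    ret = 0
    n = relative_position
    if bidirectional:
        num_buckets //= 2
        if n > 0:
            ret += num_buckets  # positive (future) offsets use the upper half of the buckets
        n = abs(n)
    else:
        n = max(-n, 0)
    max_exact = num_buckets // 2
    if n < max_exact:
        return ret + n  # small offsets get one bucket each
    # larger offsets share logarithmically spaced buckets, capped at num_buckets - 1
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return ret + min(val, num_buckets - 1)


print(bucket(-3), bucket(3), bucket(-40), bucket(40))  # 3 19 12 28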
@@ -233,7 +239,9 @@ def compute_bias(self, query_length, key_length):
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.gather(
+            self.relative_attention_bias, relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
         values = tf.expand_dims(
             tf.transpose(values, [2, 0, 1]), axis=0
         )  # shape (1, num_heads, query_length, key_length)
@@ -326,7 +334,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
 
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
 
@@ -336,6 +344,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias = position_bias[:, :, -seq_length:, :]
 
             if mask is not None:
+                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                 position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)
 
         scores += position_bias
@@ -662,7 +671,7 @@ def call(
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
+        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
         num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
         if num_dims_attention_mask == 3:
             extended_attention_mask = inputs["attention_mask"][:, None, :, :]
@@ -676,7 +685,7 @@ def call(
                     tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                     seq_ids[None, :, None],
                 )
-                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
+                causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
                 extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
                 if inputs["past_key_values"][0] is not None:
                     extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
@@ -700,7 +709,9 @@ def call(
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
             num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
             if num_dims_encoder_attention_mask == 3:
                 encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
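These mask hunks derive the dtype from the embeddings (or from an already-cast mask) instead of hard-coding tf.float32, which is what lets the additive masks survive a mixed_float16 policy. A small illustration of the idea outside any model code (the -1e4 fill value is just a float16-safe stand-in for the usual large negative constant):

import tensorflow as tf

# Under mixed precision, layer outputs are float16 even though variables stay float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

inputs_embeds = tf.keras.layers.Embedding(100, 16)(tf.constant([[1, 2, 3, 4]]))
print(inputs_embeds.dtype)  # float16

attention_mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])  # float32 by default
# A float32 additive mask cannot be added to float16 attention scores, so the
# commit casts it to the inputs' dtype rather than to tf.float32:
attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype)
extended_mask = (1.0 - attention_mask[:, None, None, :]) * -1e4
print(extended_mask.dtype)  # float16

tf.keras.mixed_precision.set_global_policy("float32")  # reset the global policy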
@@ -868,8 +879,7 @@ def _shift_right(self, input_ids):
             decoder_start_token_id is not None
         ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
 
-        shifted_input_ids = tf.cast(input_ids, tf.int32)
-        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
         start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
         shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
 
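The roll/concat shift no longer forces int32, so it simply follows whatever integer dtype the labels arrive in. A tiny standalone illustration with made-up ids (T5's decoder_start_token_id is the pad id, 0):

import tensorflow as tf

input_ids = tf.constant([[42, 7, 13, 1]])  # made-up target ids ending in </s> (= 1)
decoder_start_token_id = 0

shifted = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((tf.shape(shifted)[0], 1), decoder_start_token_id)
shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)
print(shifted.numpy())  # [[ 0 42  7 13]]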
@@ -880,7 +890,7 @@ def _shift_right(self, input_ids):
         )
 
         # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
 
         # Make sure the assertion op is called by wrapping the result in an identity no-op
         with tf.control_dependencies([assert_gte0]):
16 changes: 0 additions & 16 deletions tests/test_modeling_tf_t5.py
@@ -305,14 +305,6 @@ def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
@@ -442,14 +434,6 @@ def test_model(self):
     def test_train_pipeline_custom_model(self):
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
 
 @require_tf
 @require_sentencepiece
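Dropping these two overrides means TF T5 now runs the shared test_mixed_precision and test_xla_mode tests. A rough, illustrative smoke test of both paths (tiny made-up config sizes; recent TensorFlow spells the XLA flag jit_compile, the releases around this commit used experimental_compile):

import tensorflow as tf
from transformers import T5Config, TFT5Model

tf.keras.mixed_precision.set_global_policy("mixed_float16")  # AMP

config = T5Config(vocab_size=99, d_model=32, d_ff=37, num_layers=2, num_heads=4)
model = TFT5Model(config)
ids = tf.constant([[5, 6, 7, 8]])


@tf.function(jit_compile=True)  # XLA-compiled forward pass
def forward(input_ids):
    return model(input_ids, decoder_input_ids=input_ids).last_hidden_state


out = forward(ids)
print(out.shape, out.dtype)  # smoke test only: it should run under both AMP and XLA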