Rework casts (huggingface#10274)
jplu authored Feb 24, 2021
1 parent 2d458b2 commit cdcdd5f
Showing 1 changed file with 24 additions and 30 deletions.
src/transformers/models/xlnet/modeling_tf_xlnet.py (24 additions, 30 deletions)
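
The pattern throughout the diff: instead of computing a module-wide `dtype_float` from `use_bfloat16` and hard-coding `dtype=tf.float32` / `dtype=tf.int32`, the reworked code takes each dtype from a tensor that is already in scope (`attn_mask.dtype`, `inputs["token_type_ids"].dtype`, `one_cst.dtype`, `inputs.dtype`) or simply relies on TensorFlow's defaults. A minimal standalone sketch of that pattern, with made-up tensor names, is shown here for orientation:

    import tensorflow as tf

    # Old pattern (removed below): one dtype chosen up front and threaded everywhere.
    #   dtype_float = tf.bfloat16 if use_bfloat16 else tf.float32
    #   attn_mask = tf.cast(raw_mask > 0, dtype=dtype_float)

    # New pattern: derive the dtype from a tensor that already exists.
    raw_mask = tf.constant([[0, 1], [1, 0]])    # illustrative padding mask
    attn_score = tf.random.normal([2, 2])       # illustrative attention scores

    attn_mask = tf.cast(raw_mask > 0, dtype=attn_score.dtype)

    # The masking constant is picked per dtype, as in rel_attn_core in the first hunk.
    if attn_score.dtype in (tf.float16, tf.bfloat16):
        attn_score = attn_score - 65500 * attn_mask
    else:
        attn_score = attn_score - 1e30 * attn_mask
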
@@ -150,7 +150,7 @@ def rel_attn_core(
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            if attn_mask.dtype == tf.float16:
+            if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
                 attn_score = attn_score - 65500 * attn_mask
             else:
                 attn_score = attn_score - 1e30 * attn_mask
@@ -476,7 +476,7 @@ def build(self, input_shape):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def create_mask(self, qlen, mlen, dtype=tf.float32):
+    def create_mask(self, qlen, mlen):
         """
         Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
@@ -495,10 +495,10 @@ def create_mask(self, qlen, mlen, dtype=tf.float32):
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
         """
-        attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+        attn_mask = tf.ones([qlen, qlen])
         mask_u = tf.matrix_band_part(attn_mask, 0, -1)
         mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
-        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+        attn_mask_pad = tf.zeros([qlen, mlen])
         ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
         if self.same_length:
             mask_l = tf.matrix_band_part(attn_mask, -1, 0)
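
As an aside, the causal mask the docstring above describes can be reproduced outside the model; this is a standalone illustration with made-up sizes, using the TF2 name `tf.linalg.band_part` for the band-part op rather than the method shown in the hunk:

    import tensorflow as tf

    qlen, mlen = 4, 2  # illustrative sizes

    attn_mask = tf.ones([qlen, qlen])                       # float32 by default, no dtype argument
    mask_u = tf.linalg.band_part(attn_mask, 0, -1)          # upper triangle incl. diagonal
    mask_dia = tf.linalg.band_part(attn_mask, 0, 0)         # diagonal only
    attn_mask_pad = tf.zeros([qlen, mlen])                  # memory positions stay visible
    ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)  # 1.0 = masked, 0.0 = not masked

    print(ret.numpy())
    # [[0. 0. 0. 1. 1. 1.]
    #  [0. 0. 0. 0. 1. 1.]
    #  [0. 0. 0. 0. 0. 1.]
    #  [0. 0. 0. 0. 0. 0.]]
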
@@ -537,11 +537,9 @@ def positional_embedding(pos_seq, inv_freq, bsz=None):

         return pos_emb

-    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
+    def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = tf.range(0, self.d_model, 2.0)
-        if dtype is not None and dtype != tf.float32:
-            freq_seq = tf.cast(freq_seq, dtype=dtype)
         inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

         if self.attn_type == "bi":
@@ -557,10 +555,6 @@ def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
             fwd_pos_seq = tf.range(beg, end, -1.0)
             bwd_pos_seq = tf.range(-beg, -end, 1.0)

-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
-                bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
-
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                 bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
@@ -576,8 +570,6 @@ def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
             pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
         else:
             fwd_pos_seq = tf.range(beg, end, -1.0)
-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
             pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
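
For context, `positional_embedding` (only its `return pos_emb` is visible above) follows the usual Transformer-XL/XLNet sinusoidal scheme driven by `inv_freq`; the helper body below is a reconstruction for illustration, not a verbatim copy of the file, and the sizes are made up:

    import tensorflow as tf

    def positional_embedding(pos_seq, inv_freq, bsz=None):
        # Outer product [seq_len] x [d_model // 2] gives the sinusoid arguments.
        sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
        pos_emb = pos_emb[:, None, :]                # add a batch axis
        if bsz is not None:
            pos_emb = tf.tile(pos_emb, [1, bsz, 1])  # repeat over the batch
        return pos_emb

    d_model, qlen, klen, bsz = 8, 3, 5, 2            # illustrative sizes
    freq_seq = tf.range(0, d_model, 2.0)             # float32 without any cast
    inv_freq = 1 / (10000 ** (freq_seq / d_model))

    # "bi" attention: positions run from klen down to -qlen, exclusive.
    fwd_pos_seq = tf.range(float(klen), float(-qlen), -1.0)
    print(positional_embedding(fwd_pos_seq, inv_freq, bsz).shape)  # (8, 2, 8)
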
@@ -661,8 +653,6 @@ def call(
         mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
         klen = mlen + qlen

-        dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
-
         # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
@@ -679,7 +669,8 @@ def call(
                 "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
             )
         if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
-            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
+            one_cst = tf.constant(1.0)
+            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
         if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
             data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
         elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
@@ -692,21 +683,21 @@ def call(
         if data_mask is not None:
             # all mems can be attended to
             if mlen > 0:
-                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
                 data_mask = tf.concat([mems_mask, data_mask], axis=1)
             if attn_mask is None:
                 attn_mask = data_mask[:, :, :, None]
             else:
                 attn_mask += data_mask[:, :, :, None]

         if attn_mask is not None:
-            attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
+            attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)

         if attn_mask is not None:
-            non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
+            non_tgt_mask = -tf.eye(qlen)
             if mlen > 0:
-                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
-            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
+                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
+            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
         else:
             non_tgt_mask = None

@@ -729,19 +720,22 @@ def call(
         if inputs["token_type_ids"] is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
-                mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+                mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
                 cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
             else:
                 cat_ids = inputs["token_type_ids"]

             # `1` indicates not in the same segment [qlen x klen x bsz]
-            seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
-            seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
+            seg_mat = tf.cast(
+                tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
+                dtype=inputs["token_type_ids"].dtype,
+            )
+            seg_mat = tf.one_hot(seg_mat, 2)
         else:
             seg_mat = None

         # Positional encoding
-        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
+        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb, training=inputs["training"])

         # Prepare head mask if needed
@@ -1258,7 +1252,7 @@ def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
         offset = 2

         effective_batch_size = inputs.shape[0]
-        dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
+        dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)

         if past:
             inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
@@ -1267,13 +1261,13 @@ def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):

         # Build permutation mask so that previous tokens don't see last token
         sequence_length = inputs.shape[1]
-        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
-        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
+        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
+        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
         perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

         # We'll only predict the last token
-        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
-        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
+        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
+        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
         target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {
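
The final hunk drops the explicit `dtype=tf.float32` arguments from the generation-time masks; the values and the default float32 dtype are unchanged. A small standalone check with made-up sizes:

    import tensorflow as tf

    effective_batch_size, sequence_length = 1, 3  # illustrative sizes

    # Previous tokens must not attend to the final dummy token.
    perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
    perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
    perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

    # Only the last position is predicted.
    target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
    target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
    target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

    print(perm_mask.dtype.name, target_mapping.dtype.name)  # float32 float32 (TF default)
    print(perm_mask.numpy()[0])
    # [[0. 0. 1.]
    #  [0. 0. 1.]
    #  [0. 0. 1.]]
    print(target_mapping.numpy()[0])
    # [[0. 0. 1.]]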
