Making TF GPT2 compliant with XLA and AMP (huggingface#10230)
* Fix XLA and AMP

* Fix AMP and XLA

* Apply style

* Apply Patrick's comment
jplu authored Feb 18, 2021
1 parent 5da7c78 commit bdf1669
Showing 3 changed files with 25 additions and 138 deletions.
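For context, AMP and XLA compliance here means the TF GPT-2 forward pass runs both under Keras mixed precision and when compiled with XLA. A minimal sketch of such a setup, assuming a standard "gpt2" checkpoint and an illustrative prompt (jit_compile is the TF >= 2.5 spelling; older releases used experimental_compile):

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# AMP: compute in float16 while keeping the variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is", return_tensors="tf")

# XLA: compile the forward pass and run it on the tokenized ids.
xla_forward = tf.function(model, jit_compile=True)
outputs = xla_forward(inputs["input_ids"])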
113 changes: 0 additions & 113 deletions src/transformers/modeling_tf_utils.py
@@ -1331,119 +1331,6 @@ def call(self, x):
return x


class WordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.word_embeddings = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape=input_shape)

def get_config(self):
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, input_ids):
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])

return embeddings


class TokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape=input_shape)

def get_config(self):
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, token_type_ids):
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])

return embeddings


class PositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape)

def get_config(self):
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]

return tf.broadcast_to(input=position_embeddings, shape=input_shape)


class TFSharedEmbeddings(tf.keras.layers.Layer):
r"""
Construct shared token embeddings.
42 changes: 25 additions & 17 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -112,6 +112,7 @@ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=

if attention_mask is not None:
# Apply the attention mask
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask

w = tf.nn.softmax(w, axis=-1)
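Under mixed precision the attention scores w are computed in float16 while the additive mask may still be float32, and TF refuses to add tensors of different dtypes; casting the mask to w.dtype right before the addition makes this step dtype-agnostic. A small illustration of the failure mode the added cast avoids (shapes and values are made up):

import tensorflow as tf

w = tf.zeros((1, 1, 4, 4), dtype=tf.float16)      # attention scores under AMP
mask = tf.constant(-10000.0, shape=(1, 1, 1, 4))  # additive mask, float32 by default

# w + mask                              # fails: cannot add float16 and float32
w = w + tf.cast(mask, dtype=w.dtype)    # works for float16 and float32 alike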
@@ -224,20 +225,26 @@ def __init__(self, config, *inputs, **kwargs):
self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.n_positions = config.n_positions
self.initializer_range = config.initializer_range

self.wte = TFSharedEmbeddings(
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
)
self.wpe = tf.keras.layers.Embedding(
config.n_positions,
config.n_embd,
embeddings_initializer=get_initializer(config.initializer_range),
name="wpe",
)
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")

def build(self, input_shape):
with tf.name_scope("wpe"):
self.wpe = self.add_weight(
name="embeddings",
shape=[self.n_positions, self.n_embd],
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape)

def get_input_embeddings(self):
return self.wte
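The position embedding table moves from a tf.keras.layers.Embedding layer to a plain weight created in build() and read with tf.gather, so the lookup is an explicit gather on a single variable whose dtype the model can cast later. A standalone sketch of the same pattern; the class name is illustrative and TruncatedNormal stands in for get_initializer:

import tensorflow as tf

class PositionEmbedding(tf.keras.layers.Layer):
    def __init__(self, n_positions, n_embd, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.initializer_range = initializer_range

    def build(self, input_shape):
        with tf.name_scope("wpe"):
            self.wpe = self.add_weight(
                name="embeddings",
                shape=[self.n_positions, self.n_embd],
                initializer=tf.keras.initializers.TruncatedNormal(stddev=self.initializer_range),
            )
        super().build(input_shape)

    def call(self, position_ids):
        # Same lookup as in the diff: gather rows of the weight matrix.
        return tf.gather(self.wpe, position_ids)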

@@ -302,9 +309,7 @@ def call(
past_length = shape_list(inputs["past"][0][0])[-2]

if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(
tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
)
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
@@ -322,11 +327,11 @@
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.

inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
inputs["attention_mask"] = None
one_cst = tf.constant(1.0)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = tf.multiply(
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
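The extended attention mask is now built from tensor constants (one_cst and -10000.0) with explicit tf.subtract/tf.multiply instead of Python floats, so the whole computation carries a concrete tensor dtype that can later be cast to the embedding dtype under AMP, and the redundant else branch that reassigned None is dropped. The same transformation in isolation (batch and sequence sizes are made up):

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0]])                   # 1 = attend, 0 = masked
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]  # (batch, 1, 1, seq_len)

one_cst = tf.constant(1.0)
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
# attended positions are now 0.0, masked positions -10000.0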
@@ -344,15 +349,15 @@
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")

position_embeds = self.wpe(inputs["position_ids"])
position_embeds = tf.gather(self.wpe, inputs["position_ids"])

if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
else:
token_type_embeds = 0
token_type_embeds = tf.constant(0.0)

position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
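Making token_type_embeds a tf.constant(0.0) rather than the Python integer 0 keeps it a tensor in both branches, so the following tf.cast to the inputs_embeds dtype is valid either way and the final embedding sum broadcasts cleanly in float16. In isolation (shapes and dtypes are illustrative):

import tensorflow as tf

inputs_embeds = tf.random.normal((1, 4, 8), dtype=tf.float16)  # token embeddings under AMP
position_embeds = tf.random.normal((1, 4, 8))                  # float32 lookup
token_type_embeds = tf.constant(0.0)                           # tensor stand-in when no token types are given

position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds + position_embeds + token_type_embeds  # everything float16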
@@ -1024,7 +1029,10 @@ def call(
if inputs["input_ids"] is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
),
-1,
keepdims=False,
)
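Casting the not_equal mask to the dtype of input_ids instead of a hard-coded tf.int32 keeps the computed sequence lengths in the same integer dtype as the ids they later index into, avoiding mixed-dtype ops under XLA. The computation in isolation; the pad id is a stand-in and the trailing - 1 mirrors how the model turns the count into the index of the last non-pad token:

import tensorflow as tf

input_ids = tf.constant([[15496, 11, 616, 50256, 50256]])  # 50256 standing in for the pad id
pad_token_id = 50256

sequence_lengths = (
    tf.reduce_sum(
        tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype),
        -1,
        keepdims=False,
    )
    - 1
)
# -> [2]: the last non-pad position in each row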
8 changes: 0 additions & 8 deletions tests/test_modeling_tf_gpt2.py
@@ -389,14 +389,6 @@ def test_gpt2_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)

def test_mixed_precision(self):
# TODO JP: Make GPT2 float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make GPT2 XLA compliant
pass

@slow
def test_model_from_pretrained(self):
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
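With these two overrides gone, GPT-2 falls back to the shared test_mixed_precision and test_xla_mode checks from the common TF model tester instead of skipping them. A hedged sketch of what such a smoke check roughly amounts to (not the actual test code; the config values are arbitrary and jit_compile needs TF >= 2.5):

import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

def smoke_check_amp_and_xla():
    config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2)
    input_ids = tf.constant([[1, 2, 3, 4]])

    # AMP: the forward pass should not raise dtype errors under mixed precision.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    TFGPT2Model(config)(input_ids)
    tf.keras.mixed_precision.set_global_policy("float32")

    # XLA: the compiled forward pass should trace and run.
    model = TFGPT2Model(config)
    tf.function(model, jit_compile=True)(input_ids)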
