TensorFlow training/inference optimization #7605

Closed · wants to merge 18 commits
94 changes: 58 additions & 36 deletions src/transformers/modeling_tf_bert.py
@@ -212,40 +212,48 @@ def __init__(self, config, **kwargs):
)

self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
self.query = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
self.key = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
self.value = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
query_tensor = self.query(hidden_states)

return tf.transpose(x, perm=[0, 2, 1, 3])
# `key_tensor` = [B, S, N, H]
key_tensor = self.key(hidden_states)

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# `value_tensor` = [B, S, N, H]
value_tensor = self.value(hidden_states)

# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype) # scale attention_scores
attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -262,12 +270,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

return outputs

@@ -276,8 +280,18 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -314,8 +328,12 @@ class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -334,8 +352,12 @@ class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
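Not part of the diff above, but a minimal sketch (assuming TF >= 2.3, where `tf.keras.layers.experimental.EinsumDense` is available) of why the new projection layers can drop `transpose_for_scores`: with equation `abc,cde->abde`, `EinsumDense` maps `[batch, seq, hidden]` straight to the per-head layout `[batch, seq, num_heads, head_size]`, while the old path needed a `Dense` plus an explicit reshape and transpose. Sizes below are illustrative and the two layers have independent random weights, so only the shapes are compared:

```python
import tensorflow as tf

batch, seq_len, hidden = 2, 8, 64
num_heads, head_size = 4, 16  # hidden == num_heads * head_size

hidden_states = tf.random.normal((batch, seq_len, hidden))

# Old path: project to all_head_size, then reshape and transpose per head.
dense = tf.keras.layers.Dense(num_heads * head_size)
old = dense(hidden_states)                                # [B, S, N*H]
old = tf.reshape(old, (batch, -1, num_heads, head_size))  # [B, S, N, H]
old = tf.transpose(old, perm=[0, 2, 1, 3])                # [B, N, S, H]

# New path: EinsumDense emits the per-head layout in a single op.
einsum_dense = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cde->abde",
    output_shape=(None, num_heads, head_size),
    bias_axes="de",
)
new = einsum_dense(hidden_states)                         # [B, S, N, H]

print(old.shape)  # (2, 4, 8, 16) -- heads on axis 1, as the old matmul-based attention expects
print(new.shape)  # (2, 8, 4, 16) -- heads stay on axis 2; the einsum attention consumes this directly
```
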
94 changes: 58 additions & 36 deletions src/transformers/modeling_tf_electra.py
@@ -66,40 +66,48 @@ def __init__(self, config, **kwargs):
)

self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
self.query = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
self.key = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
self.value = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
query_tensor = self.query(hidden_states)

return tf.transpose(x, perm=[0, 2, 1, 3])
# `key_tensor` = [B, S, N, H]
key_tensor = self.key(hidden_states)

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# `value_tensor` = [B, S, N, H]
value_tensor = self.value(hidden_states)

# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype) # scale attention_scores
attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -116,12 +124,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

return outputs

@@ -131,8 +135,18 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -171,8 +185,12 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -192,8 +210,12 @@ class TFElectraOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
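Also illustrative only (random tensors, no trained weights): the einsum formulation of the scores and context in the attention hunks above matches the old `matmul`/`transpose` path. `BSNH,BTNH->BNTS` contracts over the head dimension, `tf.math.rsqrt` replaces the division by `sqrt(dk)`, and `BNTS,BSNH->BTNH` produces the context already in `[batch, seq, heads, head_size]` order, so the final transpose and reshape disappear:

```python
import numpy as np
import tensorflow as tf

B, S, N, H = 2, 8, 4, 16  # batch, sequence, heads, head size
query = tf.random.normal((B, S, N, H))
key = tf.random.normal((B, S, N, H))
value = tf.random.normal((B, S, N, H))

# Old layout: move heads to axis 1, matmul, scale by 1/sqrt(dk).
q_t = tf.transpose(query, perm=[0, 2, 1, 3])  # [B, N, S, H]
k_t = tf.transpose(key, perm=[0, 2, 1, 3])
v_t = tf.transpose(value, perm=[0, 2, 1, 3])
scores_old = tf.matmul(q_t, k_t, transpose_b=True) / tf.math.sqrt(float(H))  # [B, N, S_q, S_k]
context_old = tf.matmul(tf.nn.softmax(scores_old, axis=-1), v_t)             # [B, N, S_q, H]
context_old = tf.transpose(context_old, perm=[0, 2, 1, 3])                   # [B, S_q, N, H]

# New layout: einsum directly on [B, S, N, H] tensors, rsqrt scaling.
scores_new = tf.einsum("BSNH,BTNH->BNTS", key, query) * tf.math.rsqrt(float(H))
context_new = tf.einsum("BNTS,BSNH->BTNH", tf.nn.softmax(scores_new, axis=-1), value)

# Both formulations agree to float32 precision.
np.testing.assert_allclose(context_old.numpy(), context_new.numpy(), atol=1e-4)
```
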
17 changes: 12 additions & 5 deletions src/transformers/modeling_tf_longformer.py
@@ -270,8 +270,12 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -291,8 +295,12 @@ class TFLongformerOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -326,7 +334,6 @@ def call(self, hidden_states):
return pooled_output


# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFLongformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
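For the `Intermediate`/`Output` layers the change is a weight-layout swap only: `EinsumDense` with `abc,cd->abd` performs the same last-axis projection as a plain `Dense`. A quick sketch with illustrative sizes; the `Dense` weights are copied into the `EinsumDense` layer so the outputs match to float precision:

```python
import numpy as np
import tensorflow as tf

batch, seq_len, hidden, intermediate = 2, 8, 64, 256
x = tf.random.normal((batch, seq_len, hidden))

dense = tf.keras.layers.Dense(intermediate)
einsum_dense = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cd->abd",
    output_shape=(None, intermediate),  # (seq_len, intermediate); seq_len left dynamic
    bias_axes="d",
)

# Call both once so the layers build, then reuse the Dense kernel/bias in the
# EinsumDense layer (both store a [hidden, intermediate] kernel and an [intermediate] bias).
dense(x)
einsum_dense(x)
einsum_dense.set_weights(dense.get_weights())

np.testing.assert_allclose(dense(x).numpy(), einsum_dense(x).numpy(), atol=1e-5)
```
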