This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8f99d47

Ashish Vaswani authored and Ryan Sepassi committed
The current attention computes compatibilities between single query, key, and value positions. This CL extends it to compute them between windows of queries, keys, and values. It is like a combination of convolution and attention. Does not change defaults.
PiperOrigin-RevId: 165405437
1 parent c4526ed commit 8f99d47
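
For intuition, here is a minimal sketch of the idea in the commit message, not the commit's code: queries, keys, and values are produced by width-3 1-D convolutions instead of position-wise (width-1) projections, so each one summarizes a window of input positions, and ordinary dot-product attention is then applied on top. The sketch uses plain tf.layers.conv1d in place of common_layers.conv1d, and the shapes and depths are made up.

import tensorflow as tf

# Made-up input: [batch, length, channels].
x = tf.random_normal([8, 100, 512])
depth = 64

# Width-3 convolutions: each query/key/value position mixes a window of
# 3 neighboring input positions (the "combination of convolution and attention").
q = tf.layers.conv1d(x, depth, 3, padding="same", name="q")
k = tf.layers.conv1d(x, depth, 3, padding="same", name="k")
v = tf.layers.conv1d(x, depth, 3, padding="same", name="v")

# Standard scaled dot-product attention over the windowed q, k, v.
logits = tf.matmul(q, k, transpose_b=True) / (depth ** 0.5)
weights = tf.nn.softmax(logits)
output = tf.matmul(weights, v)  # [8, 100, 64]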

1 file changed: 78 additions, 20 deletions

1 file changed

+78
-20
lines changed

tensor2tensor/layers/common_attention.py

Lines changed: 78 additions & 20 deletions
@@ -681,6 +681,70 @@ def gather_blocks(x, indices):
   return tf.reshape(output, v_shape)
 
 
+def compute_qkv(query_antecedent, memory_antecedent, total_key_depth,
+                total_value_depth, q_filter_width=1, kv_filter_width=1,
+                q_padding="VALID", kv_padding="VALID"):
+  """Computes query, key and value.
+
+  Args:
+    query_antecedent: a Tensor with shape [batch, length_q, channels]
+    memory_antecedent: a Tensor with shape [batch, length_m, channels]
+    total_key_depth: an integer
+    total_value_depth: an integer
+    q_filter_width: An integer specifying how wide you want the query to be.
+    kv_filter_width: An integer specifying how wide you want the keys and values
+    to be.
+    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
+    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
+
+  Returns:
+    q, k, v : [batch, length, depth] tensors
+  """
+  if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
+    # self attention with single position q, k, and v
+    combined = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth * 2 + total_value_depth,
+        1,
+        name="qkv_transform")
+    q, k, v = tf.split(
+        combined, [total_key_depth, total_key_depth, total_value_depth],
+        axis=2)
+    return q, k, v
+
+  if memory_antecedent is None:
+    # self attention
+    q = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth,
+        q_filter_width,
+        padding=q_padding,
+        name="q_transform")
+    kv_combined = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth + total_value_depth,
+        kv_filter_width,
+        padding=kv_padding,
+        name="kv_transform")
+    k, v = tf.split(kv_combined, [total_key_depth, total_value_depth],
+                    axis=2)
+    return q, k, v
+
+  # encoder-decoder attention
+  q = common_layers.conv1d(
+      query_antecedent, total_key_depth, q_filter_width, padding=q_padding,
+      name="q_transform")
+  combined = common_layers.conv1d(
+      memory_antecedent,
+      total_key_depth + total_value_depth,
+      1,
+      padding=kv_padding,
+      name="kv_transform")
+  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
+
+  return q, k, v
+
+
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
@@ -693,6 +757,10 @@ def multihead_attention(query_antecedent,
                         attention_type="dot_product",
                         block_length=128,
                         block_width=128,
+                        q_filter_width=1,
+                        kv_filter_width=1,
+                        q_padding="VALID",
+                        kv_padding="VALID",
                         name=None):
   """Multihead scaled-dot-product attention with input/output transformations.
 
@@ -711,6 +779,12 @@ def multihead_attention(query_antecedent,
                     "local_unmasked"
     block_length: an integer - relevant for "local_mask_right"
     block_width: an integer - relevant for "local_unmasked"
+    q_filter_width: An integer specifying how wide you want the query to be.
+    kv_filter_width: An integer specifying how wide you want the keys and values
+    to be.
+    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
+    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
+
     name: an optional string
 
   Returns:
@@ -726,30 +800,14 @@ def multihead_attention(query_antecedent,
   if total_value_depth % num_heads != 0:
     raise ValueError("Value depth (%d) must be divisible by the number of "
                      "attention heads (%d)." % (total_value_depth, num_heads))
-
   with tf.variable_scope(
       name,
       default_name="multihead_attention",
       values=[query_antecedent, memory_antecedent]):
-    if memory_antecedent is None:
-      # self attention
-      combined = common_layers.conv1d(
-          query_antecedent,
-          total_key_depth * 2 + total_value_depth,
-          1,
-          name="qkv_transform")
-      q, k, v = tf.split(
-          combined, [total_key_depth, total_key_depth, total_value_depth],
-          axis=2)
-    else:
-      q = common_layers.conv1d(
-          query_antecedent, total_key_depth, 1, name="q_transform")
-      combined = common_layers.conv1d(
-          memory_antecedent,
-          total_key_depth + total_value_depth,
-          1,
-          name="kv_transform")
-      k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
+    q, k, v = compute_qkv(query_antecedent, memory_antecedent, total_key_depth,
+                          total_value_depth, q_filter_width, kv_filter_width,
+                          q_padding, kv_padding)
+
     q = split_heads(q, num_heads)
     k = split_heads(k, num_heads)
     v = split_heads(v, num_heads)
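
A hedged usage sketch of the new compute_qkv entry point (not part of the commit; the batch size, length, and depths below are made up). With memory_antecedent=None it takes the self-attention path, and filter widths greater than 1 make each query/key/value a function of a window of positions; the docstring's "LEFT" padding option is used here on the assumption that it keeps those windows causal for decoder-style self-attention.

import tensorflow as tf
from tensor2tensor.layers import common_attention

# Made-up self-attention input: [batch, length, channels].
x = tf.random_normal([8, 100, 512])

q, k, v = common_attention.compute_qkv(
    query_antecedent=x,
    memory_antecedent=None,   # None selects the self-attention branches
    total_key_depth=64,
    total_value_depth=64,
    q_filter_width=3,         # each query summarizes a window of 3 positions
    kv_filter_width=3,        # likewise for keys and values
    q_padding="LEFT",
    kv_padding="LEFT")
# With the defaults (filter widths of 1, "VALID" padding) the first branch
# reproduces the old single-position qkv_transform, matching "Does not change
# defaults" in the commit message.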
