@@ -681,6 +681,70 @@ def gather_blocks(x, indices):
  return tf.reshape(output, v_shape)


+def compute_qkv(query_antecedent, memory_antecedent, total_key_depth,
+                total_value_depth, q_filter_width=1, kv_filter_width=1,
+                q_padding="VALID", kv_padding="VALID"):
+  """Computes query, key and value.
+
+  Args:
+    query_antecedent: a Tensor with shape [batch, length_q, channels]
+    memory_antecedent: a Tensor with shape [batch, length_m, channels]
+    total_key_depth: an integer
+    total_value_depth: an integer
+    q_filter_width: an integer specifying how wide you want the query to be.
+    kv_filter_width: an integer specifying how wide you want the keys and
+      values to be.
+    q_padding: one of "VALID", "SAME" or "LEFT". Default is VALID: no padding.
+    kv_padding: one of "VALID", "SAME" or "LEFT". Default is VALID: no padding.
+
+  Returns:
+    q, k, v: [batch, length, depth] tensors
+  """
+  if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
+    # Self-attention with single-position q, k, and v: one fused projection
+    # followed by a single split is cheaper than three separate projections.
+    combined = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth * 2 + total_value_depth,
+        1,
+        name="qkv_transform")
+    q, k, v = tf.split(
+        combined, [total_key_depth, total_key_depth, total_value_depth],
+        axis=2)
+    return q, k, v
+
+  if memory_antecedent is None:
+    # Self-attention with wider filters: q may use a different filter width
+    # and padding than k and v, so project it separately from the fused kv
+    # projection.
+    q = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth,
+        q_filter_width,
+        padding=q_padding,
+        name="q_transform")
+    kv_combined = common_layers.conv1d(
+        query_antecedent,
+        total_key_depth + total_value_depth,
+        kv_filter_width,
+        padding=kv_padding,
+        name="kv_transform")
+    k, v = tf.split(kv_combined, [total_key_depth, total_value_depth],
+                    axis=2)
+    return q, k, v
+
+  # Encoder-decoder attention: q comes from the query antecedent, while
+  # k and v come from the memory antecedent.
+  q = common_layers.conv1d(
+      query_antecedent, total_key_depth, q_filter_width, padding=q_padding,
+      name="q_transform")
+  combined = common_layers.conv1d(
+      memory_antecedent,
+      total_key_depth + total_value_depth,
+      1,
+      padding=kv_padding,
+      name="kv_transform")
+  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
+
+  return q, k, v
+
+
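For orientation, a minimal sketch of how the new helper behaves on its default fast path (self-attention with width-1 filters). The concrete sizes are hypothetical, and the snippet assumes the surrounding tensor2tensor module context (tf and common_layers already imported, variables initialized):

    # Hypothetical sizes, chosen only to illustrate the output shapes.
    batch, length, channels = 2, 16, 64
    x = tf.zeros([batch, length, channels])
    # memory_antecedent=None with q_filter_width == kv_filter_width == 1
    # takes the fused "qkv_transform" branch: one conv1d, one tf.split.
    q, k, v = compute_qkv(x, None, total_key_depth=32, total_value_depth=48)
    # q and k have shape [2, 16, 32]; v has shape [2, 16, 48].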
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
@@ -693,6 +757,10 @@ def multihead_attention(query_antecedent,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
+                        q_filter_width=1,
+                        kv_filter_width=1,
+                        q_padding="VALID",
+                        kv_padding="VALID",
                        name=None):
  """Multihead scaled-dot-product attention with input/output transformations.

@@ -711,6 +779,12 @@ def multihead_attention(query_antecedent,
                    "local_unmasked"
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
+    q_filter_width: an integer specifying how wide you want the query to be.
+    kv_filter_width: an integer specifying how wide you want the keys and
+      values to be.
+    q_padding: one of "VALID", "SAME" or "LEFT". Default is VALID: no padding.
+    kv_padding: one of "VALID", "SAME" or "LEFT". Default is VALID: no padding.
+
    name: an optional string

  Returns:
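A note on the new padding options, based on how common_layers.conv1d handles them (worth double-checking against that implementation): "SAME" centers the filter, so a query at position i can also see later positions, while "LEFT" pads only on the left, keeping a width-k filter causal. For masked decoder self-attention with q_filter_width > 1, "LEFT" is therefore the safe choice:

    # With q_filter_width=3 and q_padding="LEFT", the query at position i
    # is computed from inputs i-2, i-1, i only -- no lookahead.
    # With q_padding="SAME", it would also see position i+1 (and beyond
    # for wider filters), leaking future information when decoding.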
@@ -726,30 +800,14 @@ def multihead_attention(query_antecedent,
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
-
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):
-    if memory_antecedent is None:
-      # self attention
-      combined = common_layers.conv1d(
-          query_antecedent,
-          total_key_depth * 2 + total_value_depth,
-          1,
-          name="qkv_transform")
-      q, k, v = tf.split(
-          combined, [total_key_depth, total_key_depth, total_value_depth],
-          axis=2)
-    else:
-      q = common_layers.conv1d(
-          query_antecedent, total_key_depth, 1, name="q_transform")
-      combined = common_layers.conv1d(
-          memory_antecedent,
-          total_key_depth + total_value_depth,
-          1,
-          name="kv_transform")
-      k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
+    q, k, v = compute_qkv(query_antecedent, memory_antecedent, total_key_depth,
+                          total_value_depth, q_filter_width, kv_filter_width,
+                          q_padding, kv_padding)
+
    q = split_heads(q, num_heads)
    k = split_heads(k, num_heads)
    v = split_heads(v, num_heads)
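To close the loop, a hedged sketch of calling the refactored entry point with the new arguments. The parameters elided from these hunks (total_key_depth, total_value_depth, output_depth, num_heads, dropout_rate) are assumed to follow the existing upstream signature, and the sizes are illustrative only:

    # Decoder-style self-attention with causal, width-3 q/k/v filters.
    x = tf.zeros([2, 16, 64])    # [batch, length, channels], hypothetical
    y = multihead_attention(
        query_antecedent=x,
        memory_antecedent=None,  # None selects the self-attention path
        bias=None,               # supply a causal bias for masked decoding
        total_key_depth=32,
        total_value_depth=48,
        output_depth=64,
        num_heads=4,
        dropout_rate=0.0,
        q_filter_width=3,
        kv_filter_width=3,
        q_padding="LEFT",
        kv_padding="LEFT")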