1- """ A set of functions to perform message passing on a serialized graph in an LLM """
1+ """ A module for learning to pass information between elements on a serialized graph in an LLM
2+ without violating the causality constraint of autoregressive generation (passing information
3+ backwards in the sequence)
4+ """
25
36import enum
7+ from functools import partial
48from collections import defaultdict
59import itertools
610from typing import Dict , List , Optional , Tuple , Union
711
812import numpy as np
913import torch
14+ from torch_scatter import scatter , scatter_softmax
1015import torch_geometric
1116
1217from .desequence_graph_ids import SequenceElement
@@ -18,92 +23,49 @@ class GNNLayerFactory(enum.Enum):
     gat = torch_geometric.nn.GATConv


- def build_message_passing_matrices(
-     token_ids: torch.Tensor,
-     edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]]
- ) -> List[Dict[str, torch.Tensor]]:
-     """ Returns the adjacency matrices required to perform causal message passing in between
-     language model blocks of an autoregressive language model
+ def graph_cross_attention(
+     values: torch.Tensor,
+     key_representations: torch.Tensor,
+     query_representations: torch.Tensor,
+     edge_index: torch.Tensor
+ ) -> torch.Tensor:
+     """ Performs graph attention on a set of prior probabilities using the representation of each
+     node in the graph to calculate the attention weights. The implemented attention is scaled
+     dot-product attention, as in the transformer architecture
     """
-     message_passing_dicts = []
-     for edge_sequence in edge_sequences:
-         message_passing_dict = defaultdict(list)
-         node2edge_idxs = defaultdict(list)
-         prev_node_idx = defaultdict(lambda: -1)
-
-         def add_element(end_idx: int, element_type: str):
-             """ Adds an element to the edge or node graphs used for message passing """
-             assert element_type in ['nodes', 'edges']
-             message_passing_dict[f"tokens2{element_type}"].append(end_idx - 1)
-             message_passing_dict[f"{element_type}2tokens"].append(end_idx)
-
-         for edge_idx, sequenced_edge in enumerate(edge_sequence):
-             pred_node, edge, succ_node = sequenced_edge
-             if edge_idx == len(edge_sequence) - 1:
-                 if (
-                     not isinstance(succ_node, SequenceElement)
-                     and not isinstance(edge, SequenceElement)
-                 ):
-                     continue
-                 else:
-                     add_element(pred_node.end_idx, 'nodes')
-                     num_nodes = len(message_passing_dict["tokens2nodes"])
-                     if prev_node_idx[pred_node.ids] != -1:
-                         message_passing_dict['edge_index_nodes'].append(
-                             [prev_node_idx[pred_node.ids], num_nodes - 1]
-                         )
-             else:
-                 add_element(pred_node.end_idx, 'nodes')
-                 add_element(succ_node.end_idx, 'edges')
-                 add_element(succ_node.end_idx, 'nodes')
-                 node2edge_idxs[pred_node.ids].append(edge_idx)
-                 node2edge_idxs[succ_node.ids].append(edge_idx)
-                 num_nodes = len(message_passing_dict["tokens2nodes"])
-                 message_passing_dict['edge_index_nodes'].append([num_nodes - 2, num_nodes - 1])
-                 if prev_node_idx[pred_node.ids] != -1:
-                     message_passing_dict['edge_index_nodes'].append(
-                         [prev_node_idx[pred_node.ids], num_nodes - 2]
-                     )
-                 if prev_node_idx[succ_node.ids] != -1:
-                     message_passing_dict['edge_index_nodes'].append(
-                         [prev_node_idx[succ_node.ids], num_nodes - 1]
-                     )
-                 prev_node_idx[pred_node.ids] = num_nodes - 2
-                 prev_node_idx[succ_node.ids] = num_nodes - 1
+     scaling_constant = torch.Tensor(
+         np.sqrt([key_representations.size(1)])
+     ).to(key_representations.device)
+     dot_products = (
+         query_representations[edge_index[1]]
+         * key_representations[edge_index[0]]
+     ).sum(1) / scaling_constant
+     weights = scatter_softmax(src=dot_products, index=edge_index[1], dim=0)
+     weighted_probs = weights.unsqueeze(1) * values[edge_index[0]]
+     return scatter(src=weighted_probs, index=edge_index[1], dim=0)
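+ # A minimal usage sketch of graph_cross_attention with assumed toy shapes: tokens attend over
+ # graph elements via edge_index, whose first row holds the element index supplying the value
+ # and whose second row holds the token index receiving it.
+ #
+ #     elements = torch.randn(4, 32)    # element embeddings used as values
+ #     keys = torch.randn(4, 16)        # key projections of the elements
+ #     queries = torch.randn(6, 16)     # query projections of the tokens
+ #     edge_index = torch.tensor([[0, 1, 2, 3],
+ #                                [2, 2, 4, 5]])
+ #     out = graph_cross_attention(elements, keys, queries, edge_index)
+ #     # out has shape [6, 32] here; token 2 mixes elements 0 and 1, and tokens
+ #     # with no incoming element stay zero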

-         for edge_idxs in node2edge_idxs.values():
-             if len(edge_idxs) < 2:
-                 continue
-             for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
-                 message_passing_dict['edge_index_edges'].append(sorted([idx0, idx1]))

-         def to_torch(array: Union[List[int], List[List[int]]]) -> torch.Tensor:
-             """ Converts an array to a torch Tensor and returns it"""
-             if len(array) == 0 or isinstance(array[0], int):
-                 return torch.from_numpy(np.array(array)).long().to(token_ids.device)
-             else:
-                 return torch.from_numpy(np.array(array).transpose(1, 0)).long().to(token_ids.device)
+ class GatedGraphCrossAttentionLayer(torch.nn.Module):
+     """ A module for performing gated cross attention between the elements of a graph that
+     has been serialized into a sequence of tokens and the token sequence itself

-         message_passing_dict['tokens2edges'] = to_torch(message_passing_dict['tokens2edges'])
-         message_passing_dict['edges2tokens'] = to_torch(message_passing_dict['edges2tokens'])
-         message_passing_dict['tokens2nodes'] = to_torch(message_passing_dict['tokens2nodes'])
-         message_passing_dict['nodes2tokens'] = to_torch(message_passing_dict['nodes2tokens'])
-         message_passing_dict['edge_index_nodes'] = to_torch(message_passing_dict['edge_index_nodes'])
-         message_passing_dict['edge_index_edges'] = to_torch(message_passing_dict['edge_index_edges'])
-         message_passing_dicts.append(dict(message_passing_dict))
-     return message_passing_dicts
+     A key element of this layer is that it enforces that information about elements in the graph
+     can only be passed to tokens describing later elements in the sequence

+     This layer contains methods to pass information either between nodes or between edges within
+     the serialized graph

- class CausalMessagePassingLayer(torch.nn.Module):
-     """ A torch.nn.Module for performing causal message passing within an autoregressive
-     language model
+     This layer is heavily inspired by Flamingo, a paper on incorporating image information
+     into LLM inference - https://arxiv.org/pdf/2204.14198
     """
     def __init__(self, gnn_type: str, embedding_size: int):
         super().__init__()
-         self.nodes_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
-         self.edges_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
-         self.gating_parameter_a = torch.nn.Parameter(torch.zeros(1))
-         self.gating_parameter_b = torch.nn.Parameter(torch.zeros(1))
+         self.gnn_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
+         self.gating_message_passing = torch.nn.Parameter(torch.zeros(1))
+         self.gating_linear = torch.nn.Parameter(torch.zeros(1))
+         self.key_embedder = torch.nn.Linear(embedding_size, 64)
+         self.query_embedder = torch.nn.Linear(embedding_size, 64)
+         self.linear_layer = torch.nn.Linear(embedding_size, embedding_size)

     def forward(
         self,
@@ -112,28 +74,195 @@ def forward(
     ) -> torch.Tensor:
         new_token_embeddings = []
         for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
-             token_edges_embeddings = torch.zeros_like(t_embeddings)
-             token_nodes_embeddings = torch.zeros_like(t_embeddings)
-             if message_passing_dict['tokens2edges'].numel() > 0:
-                 edges_embeddings = t_embeddings[message_passing_dict['tokens2edges']]
-                 if message_passing_dict['edge_index_edges'].numel() > 0:
-                     edges_embeddings = self.edges_layer(
-                         edges_embeddings,
-                         message_passing_dict['edge_index_edges']
-                     )
-                 token_edges_embeddings[message_passing_dict['edges2tokens']] = edges_embeddings
-             if message_passing_dict['tokens2nodes'].numel() > 0:
-                 nodes_embeddings = t_embeddings[message_passing_dict['tokens2nodes']]
-                 if message_passing_dict['edge_index_nodes'].numel() > 0:
-                     nodes_embeddings = self.nodes_layer(
-                         nodes_embeddings,
-                         message_passing_dict['edge_index_nodes']
+             new_t_embeddings = torch.zeros_like(t_embeddings)
+             if message_passing_dict['tokens2elements'].numel() > 0:
+                 element_embeddings = t_embeddings[message_passing_dict['tokens2elements']]
+                 if message_passing_dict['edge_index'].numel() > 0:
+                     element_embeddings = self.gnn_layer(
+                         element_embeddings,
+                         message_passing_dict['edge_index']
                     )
-                 token_nodes_embeddings[message_passing_dict['nodes2tokens']] = nodes_embeddings
-             new_t_embeddings = (
-                 t_embeddings
-                 + torch.tanh(self.gating_parameter_a) * token_edges_embeddings
-                 + torch.tanh(self.gating_parameter_b) * token_nodes_embeddings
-             )
+                 start_idx, end_idx = message_passing_dict['slice_idxs']
+                 new_t_embeddings[start_idx:end_idx] = graph_cross_attention(
+                     values=element_embeddings,
+                     key_representations=self.key_embedder(element_embeddings),
+                     query_representations=self.query_embedder(t_embeddings),
+                     edge_index=message_passing_dict['elements2tokens']
+                 )[start_idx:]
+             new_t_embeddings = t_embeddings + torch.tanh(self.gating_message_passing) * new_t_embeddings
+             new_t_embeddings = new_t_embeddings + torch.tanh(self.gating_linear) * self.linear_layer(new_t_embeddings)
             new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
         return torch.cat(new_token_embeddings, dim=0)
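+     # A minimal end-to-end sketch (hedged: the shapes are made up and `edge_sequences` is
+     # assumed to come from the project's desequencing step, which is not shown here):
+     #
+     #     layer = GatedGraphCrossAttentionLayer(gnn_type='gat', embedding_size=768)
+     #     mp_dicts = GatedGraphCrossAttentionLayer.build_node_information_passing(
+     #         edge_sequences, device=token_embeddings.device
+     #     )
+     #     new_embeddings = layer(token_embeddings, mp_dicts)  # same shape as token_embeddings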
+
+     @classmethod
+     def build_node_information_passing(
+         cls,
+         edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]],
+         device: torch.device
+     ) -> List[Dict[str, torch.Tensor]]:
+         """ Returns the index mappings required to pass node information between
+         language model blocks of an autoregressive language model for nodes in a serialized
+         graph
+         """
+         message_passing_dicts = []
+         for edge_sequence in edge_sequences:
+             message_passing_dict = {'tokens2elements': [], 'elements2tokens': [], 'edge_index': []}
+             add_node = partial(
+                 cls.add_node,
+                 end_idx=cls.get_sequence_end(edge_sequence),
+                 last_occurence_idx=defaultdict(lambda: -1),
+                 message_passing_dict=message_passing_dict
+             )
+             for edge_idx, sequenced_edge in enumerate(edge_sequence):
+                 pred_node, edge, succ_node = sequenced_edge
+                 if edge_idx == len(edge_sequence) - 1:
+                     if (
+                         not isinstance(succ_node, SequenceElement)
+                         and not isinstance(edge, SequenceElement)
+                     ):
+                         continue
+                     else:
+                         add_node(pred_node)
+                 else:
+                     add_node(pred_node)
+                     add_node(succ_node)
+             message_passing_dicts.append(cls.to_torch(dict(message_passing_dict), device))
+         return message_passing_dicts
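+     # Hedged worked example (A, B, C are hypothetical SequenceElements): for the serialized
+     # edges (A, r1, B), (B, r2, C) the occurrence list becomes [A, B, B]; the final successor
+     # C contributes nothing because no tokens follow it. 'edge_index' then holds the single
+     # same-node link [1, 2] from the first occurrence of B to the second, and
+     # 'tokens2elements' records the last token index of each occurrence.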
+
+     @classmethod
+     def build_edge_information_passing(
+         cls,
+         edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]],
+         device: torch.device
+     ) -> List[Dict[str, torch.Tensor]]:
+         """ Returns the index mappings required to pass edge information between
+         language model blocks of an autoregressive language model for edges in a serialized
+         graph
+         """
+         message_passing_dicts = []
+         for edge_sequence in edge_sequences:
+             message_passing_dict = {'tokens2elements': [], 'elements2tokens': [], 'edge_index': []}
+             node2edge_idxs = defaultdict(list)
+             add_edge = partial(
+                 cls.add_edge,
+                 end_idx=cls.get_sequence_end(edge_sequence),
+                 node2edge_idxs=node2edge_idxs,
+                 message_passing_dict=message_passing_dict
+             )
+             for sequenced_edge in edge_sequence[:-1]:
+                 add_edge(sequenced_edge)
+
+             # calculating the adjacency matrix between edges (edges in this adjacency matrix
+             # always point from edges earlier in the serialized version of the graph to edges
+             # later in the graph)
+             for edge_idxs in node2edge_idxs.values():
+                 if len(edge_idxs) < 2:
+                     continue
+                 for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
+                     message_passing_dict['edge_index'].append(sorted([idx0, idx1]))
+             message_passing_dicts.append(cls.to_torch(dict(message_passing_dict), device))
+         return message_passing_dicts
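+     # Hedged worked example (hypothetical edges): for (A, r1, B), (B, r2, C), (C, r3, D) only
+     # the first two edges become elements (the last sequenced edge is skipped); they share
+     # node B, so 'edge_index' gains the single pair [0, 1], pointing from the earlier edge
+     # to the later one.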
+
+     @staticmethod
+     def get_sequence_end(
+         edge_sequence: List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]],
+     ) -> int:
+         """ Returns the last index + 1 of the elements in the serialized graph sequence """
+         pred_node, edge, succ_node = edge_sequence[-1]
+         if isinstance(succ_node, SequenceElement):
+             end_idx = succ_node.end_idx
+         elif isinstance(edge, SequenceElement):
+             end_idx = edge.end_idx
+         else:
+             end_idx = pred_node.end_idx
+         return end_idx
+
+     @classmethod
+     def add_node(
+         cls,
+         current_occurence: SequenceElement,
+         end_idx: int,
+         last_occurence_idx: Dict[Tuple[int], int],
+         message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
+     ):
+         """ Each time a node is listed in a serialized version of its corresponding graph, it is
+         added as a node in a new artificial graph. This means that in the new artificial graph, a
+         node in the original graph may appear more than once. For every node added to the
+         artificial graph, this function adds an edge which maps between occurrences of
+         the same node in the original graph if the node has appeared previously in the
+         serialized graph. The edge points from the previous occurrence to the current occurrence,
+         e.g. H_1 - O, O - H_2 would create an edge from O -> O since O occurs more than
+         once in the graph
+         """
+         prev_length = len(message_passing_dict["tokens2elements"])
+         cls.add_element_for_information_passing(
+             start_idx=current_occurence.end_idx,
+             end_idx=end_idx,
+             message_passing_dict=message_passing_dict
+         )
+         curr_length = len(message_passing_dict["tokens2elements"])
+         if curr_length > prev_length:
+             current_idx = len(message_passing_dict["tokens2elements"]) - 1
+             if last_occurence_idx[current_occurence.ids] != -1:
+                 message_passing_dict['edge_index'].append(
+                     [last_occurence_idx[current_occurence.ids], current_idx]
+                 )
+             last_occurence_idx[current_occurence.ids] = current_idx
+
+     @classmethod
+     def add_edge(
+         cls,
+         sequenced_edge: Tuple[SequenceElement, SequenceElement, SequenceElement],
+         end_idx: int,
+         node2edge_idxs: Dict[Tuple[int], List[int]],
+         message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
+     ):
+         """ Adds an edge of the serialized graph as an element used for information passing """
+         pred_node, _, succ_node = sequenced_edge
+         prev_length = len(message_passing_dict["tokens2elements"])
+         cls.add_element_for_information_passing(
+             start_idx=succ_node.end_idx,
+             end_idx=end_idx,
+             message_passing_dict=message_passing_dict
+         )
+         curr_length = len(message_passing_dict["tokens2elements"])
+         if curr_length > prev_length:
+             current_idx = len(message_passing_dict["tokens2elements"]) - 1
+             node2edge_idxs[pred_node.ids].append(current_idx)
+             node2edge_idxs[succ_node.ids].append(current_idx)
+
+     @staticmethod
+     def add_element_for_information_passing(
+         start_idx: int,
+         end_idx: int,
+         message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
+     ):
+         """ Adds an element to the message passing dictionary; the element is either a node
+         or an edge. Adding the element means adding the necessary indices to the mappings
+         tokens2elements and elements2tokens, so that it is possible to map from tokens to
+         elements and back
+         """
+         if start_idx != end_idx:
+             message_passing_dict["tokens2elements"].append(start_idx - 1)
+             for sequence_idx in range(start_idx, end_idx):
+                 message_passing_dict["elements2tokens"].append(
+                     [len(message_passing_dict["tokens2elements"]) - 1, sequence_idx]
+                 )
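+     # Hedged numeric example: with start_idx=5 and end_idx=9 the element is represented by
+     # token 4 (appended to tokens2elements) and elements2tokens gains [k, 5], [k, 6], [k, 7],
+     # [k, 8], where k is the new element's index, so every later token can attend to it.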
+
+     @staticmethod
+     def to_torch(
+         array_dict: Dict[str, Union[List[int], List[List[int]]]],
+         device: torch.device
+     ) -> Dict[str, torch.Tensor]:
+         """ Converts a dictionary of lists of integers to a dictionary of torch Tensors and returns it
+         """
+         for key, array in array_dict.items():
+             if len(array) == 0 or isinstance(array[0], int):
+                 array_dict[key] = torch.from_numpy(np.array(array)).long().to(device)
+             else:
+                 array_dict[key] = torch.from_numpy(np.array(array).transpose(1, 0)).long().to(device)
+         if array_dict['elements2tokens'].numel() > 0:
+             array_dict['slice_idxs'] = torch.from_numpy(np.array([
+                 array_dict['elements2tokens'][1].min().item(),
+                 array_dict['elements2tokens'][1].max().item() + 1
+             ])).long().to(device)
+         return array_dict