11""" A set of functions to perform message passing on a serialized graph in an LLM """
22
3+ import enum
34from collections import defaultdict
45import itertools
5- from typing import Callable , Dict , List , Optional , Tuple
6+ from typing import Dict , List , Optional , Tuple , Union
67
78import numpy as np
89import torch
9- from torch_scatter import scatter
10+ import torch_geometric
1011
1112from .desequence_graph_ids import SequenceElement
1213
1314
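+# Maps a config-friendly name to its torch_geometric graph convolution class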
+class GNNLayerFactory(enum.Enum):
+    gcn = torch_geometric.nn.GCNConv
+    sage = torch_geometric.nn.SAGEConv
+    gat = torch_geometric.nn.GATConv
+
+
 def build_message_passing_matrices(
     token_ids: torch.Tensor,
     edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]]
@@ -19,114 +26,114 @@ def build_message_passing_matrices(
     language model blocks of an autoregressive language model
     """
     message_passing_dicts = []
-    for t_ids, edge_sequence in zip(token_ids, edge_sequences):
-        message_passing_dict = {'tokens2edges': [], 'edges2tokens': [], 'inverse_edge_index': []}
+    for edge_sequence in edge_sequences:
+        message_passing_dict = defaultdict(list)
         node2edge_idxs = defaultdict(list)
+        prev_node_idx = defaultdict(lambda: -1)
+
+        def add_element(end_idx: int, element_type: str):
+            """ Adds an element to the edge or node graphs used for message passing """
+            assert element_type in ['nodes', 'edges']
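+            # gather from the element's final token (end_idx - 1) and scatter the
+            # message to the token that follows it (end_idx), keeping the flow causal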
37+ message_passing_dict [f"tokens2{ element_type } " ].append (end_idx - 1 )
38+ message_passing_dict [f"{ element_type } 2tokens" ].append (end_idx )
39+
         for edge_idx, sequenced_edge in enumerate(edge_sequence):
             pred_node, edge, succ_node = sequenced_edge
-            node2edge_idxs[pred_node.ids].append(edge_idx)
-            if isinstance(succ_node, SequenceElement):
-                end_idx = succ_node.end_idx
-                node2edge_idxs[succ_node.ids].append(edge_idx)
-            elif isinstance(edge, SequenceElement):
-                end_idx = edge.end_idx
+            if edge_idx == len(edge_sequence) - 1:
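+                # the final edge may still be incomplete, so only its predecessor node
+                # is registered; a successor's scatter target could fall past the last token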
+                if (
+                    not isinstance(succ_node, SequenceElement)
+                    and not isinstance(edge, SequenceElement)
+                ):
+                    continue
+                else:
+                    add_element(pred_node.end_idx, 'nodes')
+                    num_nodes = len(message_passing_dict['tokens2nodes'])
+                    if prev_node_idx[pred_node.ids] != -1:
+                        message_passing_dict['edge_index_nodes'].append(
+                            [prev_node_idx[pred_node.ids], num_nodes - 1]
+                        )
             else:
-                end_idx = pred_node.end_idx
-            for token_idx in range(pred_node.start_idx, end_idx):
-                message_passing_dict['tokens2edges'].append([token_idx, edge_idx])
-                message_passing_dict['edges2tokens'].append([edge_idx, token_idx])
-        if len(message_passing_dict['edges2tokens']) > 0:
-            message_passing_dict['edges2tokens'] = add_missing_idxs(
-                message_passing_dict['edges2tokens'],
-                num_incoming_nodes=len(edge_sequence),
-                num_outgoing_nodes=len(t_ids)
-            )
-        message_passing_dict['inverse_edge_index'] = []
+                add_element(pred_node.end_idx, 'nodes')
+                add_element(succ_node.end_idx, 'edges')
+                add_element(succ_node.end_idx, 'nodes')
+                node2edge_idxs[pred_node.ids].append(edge_idx)
+                node2edge_idxs[succ_node.ids].append(edge_idx)
+                num_nodes = len(message_passing_dict['tokens2nodes'])
+                message_passing_dict['edge_index_nodes'].append([num_nodes - 2, num_nodes - 1])
+                if prev_node_idx[pred_node.ids] != -1:
+                    message_passing_dict['edge_index_nodes'].append(
+                        [prev_node_idx[pred_node.ids], num_nodes - 2]
+                    )
+                if prev_node_idx[succ_node.ids] != -1:
+                    message_passing_dict['edge_index_nodes'].append(
+                        [prev_node_idx[succ_node.ids], num_nodes - 1]
+                    )
+                prev_node_idx[pred_node.ids] = num_nodes - 2
+                prev_node_idx[succ_node.ids] = num_nodes - 1
+
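+        # connect every pair of edges that share a graph node so their embeddings
+        # can exchange messages during GNN propagation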
         for edge_idxs in node2edge_idxs.values():
             if len(edge_idxs) < 2:
                 continue
             for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
-                message_passing_dict['inverse_edge_index'].append(
-                    [idx0, idx1] if idx0 < idx1 else [idx1, idx0]
-                )
-        if len(message_passing_dict['inverse_edge_index']) > 0:
-            message_passing_dict['inverse_edge_index'] = add_missing_idxs(
-                message_passing_dict['inverse_edge_index'],
-                num_incoming_nodes=len(edge_sequence),
-                num_outgoing_nodes=len(edge_sequence)
-            )
-        message_passing_dicts.append({
-            key: torch.from_numpy(np.array(value).transpose(1, 0)).long().to(token_ids.device)
-            if len(value) > 0 else torch.from_numpy(np.array(value)).long().to(token_ids.device)
-            for key, value in message_passing_dict.items()
-        })
-    return message_passing_dicts
+                message_passing_dict['edge_index_edges'].append(sorted([idx0, idx1]))
 
+        def to_torch(array: Union[List[int], List[List[int]]]) -> torch.Tensor:
+            """ Converts index lists to a long tensor on token_ids' device; lists of
+            [src, dst] pairs are transposed into a (2, num_pairs) edge_index """
+            if len(array) == 0 or isinstance(array[0], int):
+                return torch.from_numpy(np.array(array)).long().to(token_ids.device)
+            else:
+                return torch.from_numpy(np.array(array).transpose(1, 0)).long().to(token_ids.device)
 
-def add_missing_idxs(
-    edge_index: List[List[int]],
-    *,
-    num_incoming_nodes: int,
-    num_outgoing_nodes: int
-) -> List[List[int]]:
-    """ Adds edges from a dummy node to all outgoing nodes which do not have an edge pointing to
-    them. This is to facilitate causal message passing where a node should have access to its
-    own embedding using torch.scatter to perform message passing.
-    """
-    existing_idxs = set([node_idxs[-1] for node_idxs in edge_index])
-    missing_idxs = set(range(num_outgoing_nodes)) - existing_idxs
-    for missing_idx in missing_idxs:
-        edge_index.append([num_incoming_nodes, missing_idx])
-    return edge_index
+        message_passing_dict['tokens2edges'] = to_torch(message_passing_dict['tokens2edges'])
+        message_passing_dict['edges2tokens'] = to_torch(message_passing_dict['edges2tokens'])
+        message_passing_dict['tokens2nodes'] = to_torch(message_passing_dict['tokens2nodes'])
+        message_passing_dict['nodes2tokens'] = to_torch(message_passing_dict['nodes2tokens'])
+        message_passing_dict['edge_index_nodes'] = to_torch(message_passing_dict['edge_index_nodes'])
+        message_passing_dict['edge_index_edges'] = to_torch(message_passing_dict['edge_index_edges'])
+        message_passing_dicts.append(dict(message_passing_dict))
+    return message_passing_dicts
 
 
-def perform_causal_message_passing(
-    token_embeddings: torch.Tensor,
-    message_passing_dicts: List[Dict[str, torch.Tensor]],
-    linear_layer: Optional[Callable] = None,
-    reduce: str = 'mean'
-) -> torch.Tensor:
-    """ Returns token embeddings in a sequence where causal message passing has been performed on
-    the token ids based on the serialized graph described in the sequence
+class CausalMessagePassingLayer(torch.nn.Module):
+    """ A torch.nn.Module for performing causal message passing within an autoregressive
+    language model
     """
-    new_token_embeddings = []
-    for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
-        if message_passing_dict['inverse_edge_index'].numel() == 0:
-            new_t_embeddings = t_embeddings
-        else:
-            edge_embeddings = scatter(
-                src=t_embeddings[message_passing_dict['tokens2edges'][0]],
-                dim=0,
-                index=message_passing_dict['tokens2edges'][1],
-                reduce=reduce
-            )
-            # adding dummy tensor to make sure that the output tensor of message passing is the
-            # correct size because causal message passing does not allow self loops
-            edge_embeddings = torch.cat([
-                edge_embeddings,
-                torch.zeros_like(edge_embeddings[0].unsqueeze(0))
-            ], dim=0)
-            # adding dummy tensor to make sure that the output tensor of message passing is the
-            # correct size because causal message passing does not allow self loops
-            edge_embeddings = scatter(
-                src=edge_embeddings[message_passing_dict['inverse_edge_index'][0]],
-                dim=0,
-                index=message_passing_dict['inverse_edge_index'][1],
-                reduce=reduce
-            )
-            if linear_layer is not None:
-                edge_embeddings = linear_layer(edge_embeddings)
-            edge_embeddings = torch.relu(edge_embeddings)
-            edge_embeddings = torch.cat([
-                edge_embeddings,
-                torch.zeros_like(edge_embeddings[0].unsqueeze(0))
-            ], dim=0)
-            new_t_embeddings = scatter(
-                src=edge_embeddings[message_passing_dict['edges2tokens'][0]],
-                dim=0,
-                index=message_passing_dict['edges2tokens'][1],
-                reduce=reduce
+    def __init__(self, gnn_type: str, embedding_size: int):
+        super().__init__()
+        self.nodes_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
+        self.edges_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
+        self.gating_parameter_a = torch.nn.Parameter(torch.zeros(1))
+        self.gating_parameter_b = torch.nn.Parameter(torch.zeros(1))
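+        # both gates start at tanh(0) = 0, so the layer is initially an identity
+        # mapping and learns how much graph signal to mix into the token stream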
+
+    def forward(
+        self,
+        token_embeddings: torch.Tensor,
+        message_passing_dicts: List[Dict[str, torch.Tensor]]
+    ) -> torch.Tensor:
+        new_token_embeddings = []
+        for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
+            token_edges_embeddings = torch.zeros_like(t_embeddings)
+            token_nodes_embeddings = torch.zeros_like(t_embeddings)
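+            # these buffers stay zero for any token that receives no node or edge message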
+            if message_passing_dict['tokens2edges'].numel() > 0:
+                edges_embeddings = t_embeddings[message_passing_dict['tokens2edges']]
+                if message_passing_dict['edge_index_edges'].numel() > 0:
+                    edges_embeddings = self.edges_layer(
+                        edges_embeddings,
+                        message_passing_dict['edge_index_edges']
+                    )
+                token_edges_embeddings[message_passing_dict['edges2tokens']] = edges_embeddings
+            if message_passing_dict['tokens2nodes'].numel() > 0:
+                nodes_embeddings = t_embeddings[message_passing_dict['tokens2nodes']]
+                if message_passing_dict['edge_index_nodes'].numel() > 0:
+                    nodes_embeddings = self.nodes_layer(
+                        nodes_embeddings,
+                        message_passing_dict['edge_index_nodes']
+                    )
+                token_nodes_embeddings[message_passing_dict['nodes2tokens']] = nodes_embeddings
+            new_t_embeddings = (
+                t_embeddings
+                + torch.tanh(self.gating_parameter_a) * token_edges_embeddings
+                + torch.tanh(self.gating_parameter_b) * token_nodes_embeddings
             )
-            assert new_t_embeddings.shape == t_embeddings.shape
-            new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
-    return torch.cat(new_token_embeddings, dim=0) + token_embeddings
+            new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
+        return torch.cat(new_token_embeddings, dim=0)
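
A minimal sketch of how the two pieces are assumed to fit together inside a language
model block (token_ids, edge_sequences, and hidden_states are illustrative inputs,
not part of this commit):

    message_passing_dicts = build_message_passing_matrices(token_ids, edge_sequences)
    mp_layer = CausalMessagePassingLayer(gnn_type='gcn', embedding_size=hidden_states.size(-1))
    hidden_states = mp_layer(hidden_states, message_passing_dicts)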