Skip to content

Commit 3604898

Browse files
zacharesvahanhov
andauthored
Clearer code and simpler method for within LLM message passing (#3)
* novelty debugging * running solution * message passing slightly better * simplified serialize * current code * flamingo inspired * message passing correctly implemented * positions update * removing commented code * causal message passing * edge case in case using another model besides serialize * update message passing and position embedding * Update src/transformers/models/bloom/modeling_bloom.py * removed unnecessary code * clearer message passing code * Update src/transformers/models/bloom/causal_message_passing.py * Update src/transformers/models/bloom/causal_message_passing.py * Update src/transformers/models/bloom/causal_message_passing.py Co-authored-by: vahanhov <32771381+vahanhov@users.noreply.github.com> --------- Co-authored-by: vahanhov <32771381+vahanhov@users.noreply.github.com>
1 parent 844767b commit 3604898

File tree

2 files changed

+245
-115
lines changed

2 files changed

+245
-115
lines changed
Lines changed: 229 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
""" A set of functions to perform message passing on a serialized graph in an LLM """
1+
""" A module for learning to pass information between elements on a serialized graph in an LLM
2+
without violating the causality constraint of autoregressive generation (passing information
3+
backwards in the sequence)
4+
"""
25

36
import enum
7+
from functools import partial
48
from collections import defaultdict
59
import itertools
610
from typing import Dict, List, Optional, Tuple, Union
711

812
import numpy as np
913
import torch
14+
from torch_scatter import scatter, scatter_softmax
1015
import torch_geometric
1116

1217
from .desequence_graph_ids import SequenceElement
@@ -18,92 +23,49 @@ class GNNLayerFactory(enum.Enum):
1823
gat = torch_geometric.nn.GATConv
1924

2025

21-
def build_message_passing_matrices(
22-
token_ids: torch.Tensor,
23-
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]]
24-
) -> List[Dict[str, torch.Tensor]]:
25-
""" Returns the adjacency matrices required to perform causal message passing in between
26-
language model blocks of an autoregressive language model
26+
def graph_cross_attention(
27+
values: torch.Tensor,
28+
key_representations: torch.Tensor,
29+
query_representations: torch.Tensor,
30+
edge_index: torch.Tensor
31+
) -> torch.Tensor:
32+
""" Performs graph attention on a set of prior probabilities uing the representation of each
33+
node in the graph to calculate the attention weights. The implemented attention is dot
34+
product attention as implemented in the transformer architecture
2735
"""
28-
message_passing_dicts = []
29-
for edge_sequence in edge_sequences:
30-
message_passing_dict = defaultdict(list)
31-
node2edge_idxs = defaultdict(list)
32-
prev_node_idx = defaultdict(lambda: -1)
33-
34-
def add_element(end_idx: int, element_type: str):
35-
""" Adds an element to the edge or node graphs used for message passing """
36-
assert element_type in ['nodes', 'edges']
37-
message_passing_dict[f"tokens2{element_type}"].append(end_idx - 1)
38-
message_passing_dict[f"{element_type}2tokens"].append(end_idx)
39-
40-
for edge_idx, sequenced_edge in enumerate(edge_sequence):
41-
pred_node, edge, succ_node = sequenced_edge
42-
if edge_idx == len(edge_sequence) - 1:
43-
if (
44-
not isinstance(succ_node, SequenceElement)
45-
and not isinstance(edge, SequenceElement)
46-
):
47-
continue
48-
else:
49-
add_element(pred_node.end_idx, 'nodes')
50-
num_nodes = len(message_passing_dict["tokens2nodes"])
51-
if prev_node_idx[pred_node.ids] != -1:
52-
message_passing_dict['edge_index_nodes'].append(
53-
[prev_node_idx[pred_node.ids], num_nodes - 1]
54-
)
55-
else:
56-
add_element(pred_node.end_idx, 'nodes')
57-
add_element(succ_node.end_idx, 'edges')
58-
add_element(succ_node.end_idx, 'nodes')
59-
node2edge_idxs[pred_node.ids].append(edge_idx)
60-
node2edge_idxs[succ_node.ids].append(edge_idx)
61-
num_nodes = len(message_passing_dict["tokens2nodes"])
62-
message_passing_dict['edge_index_nodes'].append([num_nodes - 2, num_nodes - 1])
63-
if prev_node_idx[pred_node.ids] != -1:
64-
message_passing_dict['edge_index_nodes'].append(
65-
[prev_node_idx[pred_node.ids], num_nodes - 2]
66-
)
67-
if prev_node_idx[succ_node.ids] != -1:
68-
message_passing_dict['edge_index_nodes'].append(
69-
[prev_node_idx[succ_node.ids], num_nodes - 1]
70-
)
71-
prev_node_idx[pred_node.ids] = num_nodes - 2
72-
prev_node_idx[succ_node.ids] = num_nodes - 1
36+
scaling_constant = torch.Tensor(
37+
np.sqrt([key_representations.size(1)])
38+
).to(key_representations.device)
39+
dot_products = (
40+
query_representations[edge_index[1]]
41+
* key_representations[edge_index[0]]
42+
).sum(1) / scaling_constant
43+
weights = scatter_softmax(src=dot_products, index=edge_index[1], dim=0)
44+
weighted_probs = weights.unsqueeze(1) * values[edge_index[0]]
45+
return scatter(src=weighted_probs, index=edge_index[1], dim=0)
7346

74-
for edge_idxs in node2edge_idxs.values():
75-
if len(edge_idxs) < 2:
76-
continue
77-
for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
78-
message_passing_dict['edge_index_edges'].append(sorted([idx0, idx1]))
7947

80-
def to_torch(array: Union[List[int], List[List[int]]]) -> torch.Tensor:
81-
""" Converts an array to a torch Tensor and returns it"""
82-
if len(array) == 0 or isinstance(array[0], int):
83-
return torch.from_numpy(np.array(array)).long().to(token_ids.device)
84-
else:
85-
return torch.from_numpy(np.array(array).transpose(1, 0)).long().to(token_ids.device)
48+
class GatedGraphCrossAttentionLayer(torch.nn.Module):
49+
""" A module for performing gated cross attention between elements in a graph that
50+
have been serialized in a sequence of tokens and the token sequence
8651
87-
message_passing_dict['tokens2edges'] = to_torch(message_passing_dict['tokens2edges'])
88-
message_passing_dict['edges2tokens'] = to_torch(message_passing_dict['edges2tokens'])
89-
message_passing_dict['tokens2nodes'] = to_torch(message_passing_dict['tokens2nodes'])
90-
message_passing_dict['nodes2tokens'] = to_torch(message_passing_dict['nodes2tokens'])
91-
message_passing_dict['edge_index_nodes'] = to_torch(message_passing_dict['edge_index_nodes'])
92-
message_passing_dict['edge_index_edges'] = to_torch(message_passing_dict['edge_index_edges'])
93-
message_passing_dicts.append(dict(message_passing_dict))
94-
return message_passing_dicts
52+
a key element of this layer is that it enforces that information about elements in the graph
53+
can only be passed to tokens describing later elements in the sequence
9554
55+
This layer contains methods to pass information either between nodes or edges within
56+
the serialized graph
9657
97-
class CausalMessagePassingLayer(torch.nn.Module):
98-
""" A torch.nn.Module for performing causal message passing within an autoregressive
99-
language model
58+
This layer is heavily inspired by Flamingo a paper on incorporating image information
59+
into LLM inference - https://arxiv.org/pdf/2204.14198
10060
"""
10161
def __init__(self, gnn_type: str, embedding_size: int):
10262
super().__init__()
103-
self.nodes_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
104-
self.edges_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
105-
self.gating_parameter_a = torch.nn.Parameter(torch.zeros(1))
106-
self.gating_parameter_b = torch.nn.Parameter(torch.zeros(1))
63+
self.gnn_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
64+
self.gating_message_passing = torch.nn.Parameter(torch.zeros(1))
65+
self.gating_linear = torch.nn.Parameter(torch.zeros(1))
66+
self.key_embedder = torch.nn.Linear(embedding_size, 64)
67+
self.query_embedder = torch.nn.Linear(embedding_size, 64)
68+
self.linear_layer = torch.nn.Linear(embedding_size, embedding_size)
10769

10870
def forward(
10971
self,
@@ -112,28 +74,195 @@ def forward(
11274
) -> torch.Tensor:
11375
new_token_embeddings = []
11476
for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
115-
token_edges_embeddings = torch.zeros_like(t_embeddings)
116-
token_nodes_embeddings = torch.zeros_like(t_embeddings)
117-
if message_passing_dict['tokens2edges'].numel() > 0:
118-
edges_embeddings = t_embeddings[message_passing_dict['tokens2edges']]
119-
if message_passing_dict['edge_index_edges'].numel() > 0:
120-
edges_embeddings = self.edges_layer(
121-
edges_embeddings,
122-
message_passing_dict['edge_index_edges']
123-
)
124-
token_edges_embeddings[message_passing_dict['edges2tokens']] = edges_embeddings
125-
if message_passing_dict['tokens2nodes'].numel() > 0:
126-
nodes_embeddings = t_embeddings[message_passing_dict['tokens2nodes']]
127-
if message_passing_dict['edge_index_nodes'].numel() > 0:
128-
nodes_embeddings = self.nodes_layer(
129-
nodes_embeddings,
130-
message_passing_dict['edge_index_nodes']
77+
new_t_embeddings = torch.zeros_like(t_embeddings)
78+
if message_passing_dict['tokens2elements'].numel() > 0:
79+
element_embeddings = t_embeddings[message_passing_dict['tokens2elements']]
80+
if message_passing_dict['edge_index'].numel() > 0:
81+
element_embeddings = self.gnn_layer(
82+
element_embeddings,
83+
message_passing_dict['edge_index']
13184
)
132-
token_nodes_embeddings[message_passing_dict['nodes2tokens']] = nodes_embeddings
133-
new_t_embeddings = (
134-
t_embeddings
135-
+ torch.tanh(self.gating_parameter_a) * token_edges_embeddings
136-
+ torch.tanh(self.gating_parameter_b) * token_nodes_embeddings
137-
)
85+
start_idx, end_idx = message_passing_dict['slice_idxs']
86+
new_t_embeddings[start_idx:end_idx] = graph_cross_attention(
87+
values=element_embeddings,
88+
key_representations=self.key_embedder(element_embeddings),
89+
query_representations=self.query_embedder(t_embeddings),
90+
edge_index=message_passing_dict['elements2tokens']
91+
)[start_idx:]
92+
new_t_embeddings = t_embeddings + torch.tanh(self.gating_message_passing) * new_t_embeddings
93+
new_t_embeddings = new_t_embeddings + torch.tanh(self.gating_linear) * self.linear_layer(new_t_embeddings)
13894
new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
13995
return torch.cat(new_token_embeddings, dim=0)
96+
97+
@classmethod
98+
def build_node_information_passing(
99+
cls,
100+
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]],
101+
device: torch.device
102+
) -> List[Dict[str, torch.Tensor]]:
103+
""" Returns the indice mappings required to perform pass node information in between
104+
language model blocks of an autoregressive language model for nodes in a serialized
105+
graph
106+
"""
107+
message_passing_dicts = []
108+
for edge_sequence in edge_sequences:
109+
message_passing_dict = {'tokens2elements': [], 'elements2tokens': [], 'edge_index': []}
110+
add_node = partial(
111+
cls.add_node,
112+
end_idx=cls.get_sequence_end(edge_sequence),
113+
last_occurence_idx=defaultdict(lambda: -1),
114+
message_passing_dict=message_passing_dict
115+
)
116+
for edge_idx, sequenced_edge in enumerate(edge_sequence):
117+
pred_node, edge, succ_node = sequenced_edge
118+
if edge_idx == len(edge_sequence) - 1:
119+
if (
120+
not isinstance(succ_node, SequenceElement)
121+
and not isinstance(edge, SequenceElement)
122+
):
123+
continue
124+
else:
125+
add_node(pred_node)
126+
else:
127+
add_node(pred_node)
128+
add_node(succ_node)
129+
message_passing_dicts.append(cls.to_torch(dict(message_passing_dict), device))
130+
return message_passing_dicts
131+
132+
@classmethod
133+
def build_edge_information_passing(
134+
cls,
135+
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]],
136+
device: torch.device
137+
) -> List[Dict[str, torch.Tensor]]:
138+
""" Returns the indice mappings required to perform pass edge information in between
139+
language model blocks of an autoregressive language model for nodes in a serialized
140+
graph
141+
"""
142+
message_passing_dicts = []
143+
for edge_sequence in edge_sequences:
144+
message_passing_dict = {'tokens2elements': [], 'elements2tokens': [], 'edge_index': []}
145+
node2edge_idxs = defaultdict(list)
146+
add_edge = partial(
147+
cls.add_edge,
148+
end_idx=cls.get_sequence_end(edge_sequence),
149+
node2edge_idxs=node2edge_idxs,
150+
message_passing_dict=message_passing_dict
151+
)
152+
for sequenced_edge in edge_sequence[:-1]:
153+
add_edge(sequenced_edge)
154+
155+
# calculating adjacency matrix between edges (edges in this adjacency matrix always
156+
# point from edges earlier in the serialized version of the graph to edges later in
157+
# the graph)
158+
for edge_idxs in node2edge_idxs.values():
159+
if len(edge_idxs) < 2:
160+
continue
161+
for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
162+
message_passing_dict['edge_index'].append(sorted([idx0, idx1]))
163+
message_passing_dicts.append(cls.to_torch(dict(message_passing_dict), device))
164+
return message_passing_dicts
165+
166+
@staticmethod
167+
def get_sequence_end(
168+
edge_sequence: List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]],
169+
) -> Tuple[int, int]:
170+
""" Returns last index + 1 of elements in the serialized graph sequence """
171+
pred_node, edge, succ_node = edge_sequence[-1]
172+
if isinstance(succ_node, SequenceElement):
173+
end_idx = succ_node.end_idx
174+
elif isinstance(edge, SequenceElement):
175+
end_idx = edge.end_idx
176+
else:
177+
end_idx = pred_node.end_idx
178+
return end_idx
179+
180+
@classmethod
181+
def add_node(
182+
cls,
183+
current_occurence: SequenceElement,
184+
end_idx: int,
185+
last_occurence_idx: Dict[Tuple[int], int],
186+
message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
187+
):
188+
""" Each time a node is listed in a serialized version of its corresponding graph, it is
189+
added as a node in a new artificial graph. This means in the new artificial graph, a
190+
node in the original graph may appear more than once. For every node added to the
191+
artificial graph, this function adds an edge which maps between occurences of
192+
the same node in the original graph if the node has been printed previously in the
193+
serialized graph. The edge points from the previous occurence to the current occurence.
194+
i.e. H_1 - O, O - H_2, would create an edge from O -> O since it occurs more than
195+
once in the graph
196+
"""
197+
prev_length = len(message_passing_dict[f"tokens2elements"])
198+
cls.add_element_for_information_passing(
199+
start_idx=current_occurence.end_idx,
200+
end_idx=end_idx,
201+
message_passing_dict=message_passing_dict
202+
)
203+
curr_length = len(message_passing_dict[f"tokens2elements"])
204+
if last_occurence_idx[current_occurence.ids] != -1 and curr_length > prev_length:
205+
current_idx = len(message_passing_dict["tokens2elements"]) - 1
206+
message_passing_dict['edge_index'].append(
207+
[last_occurence_idx[current_occurence.ids], current_idx]
208+
)
209+
last_occurence_idx[current_occurence.ids] = current_idx
210+
211+
@classmethod
212+
def add_edge(
213+
cls,
214+
sequenced_edge: Tuple[SequenceElement, SequenceElement, SequenceElement],
215+
end_idx: int,
216+
node2edge_idxs: Dict[Tuple[int], List[int]],
217+
message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
218+
):
219+
""" Adds an edge as element to pass information between in a serialized graph """
220+
pred_node, _, succ_node = sequenced_edge
221+
prev_length = len(message_passing_dict[f"tokens2elements"])
222+
cls.add_element_for_information_passing(
223+
start_idx=succ_node.end_idx,
224+
end_idx=end_idx,
225+
message_passing_dict=message_passing_dict
226+
)
227+
curr_length = len(message_passing_dict[f"tokens2elements"])
228+
if curr_length > prev_length:
229+
current_idx = len(message_passing_dict["tokens2elements"]) - 1
230+
node2edge_idxs[pred_node.ids].append(current_idx)
231+
node2edge_idxs[succ_node.ids].append(current_idx)
232+
233+
@staticmethod
234+
def add_element_for_information_passing(
235+
start_idx: int,
236+
end_idx: int,
237+
message_passing_dict: Dict[str, Union[List[int], List[List[int]]]]
238+
):
239+
""" Adds an element to the message passing dictionary, the element is either a node
240+
or an edge. Adding the element means adding the necessary indices to the mapping
241+
tokens2elements and elements2tokens, so that it is possible to map to elements
242+
and back
243+
"""
244+
if start_idx != end_idx:
245+
message_passing_dict[f"tokens2elements"].append(start_idx - 1)
246+
for sequence_idx in range(start_idx, end_idx):
247+
message_passing_dict[f"elements2tokens"].append(
248+
[len(message_passing_dict[f"tokens2elements"]) - 1, sequence_idx]
249+
)
250+
251+
@staticmethod
252+
def to_torch(
253+
array_dict: Dict[str, Union[List[int], List[List[int]]]],
254+
device: torch.device
255+
) -> Dict[str, torch.Tensor]:
256+
""" Converts a dictionary of lists of integers to a dictionary of torch Tensor and returns it
257+
"""
258+
for key, array in array_dict.items():
259+
if len(array) == 0 or isinstance(array[0], int):
260+
array_dict[key] = torch.from_numpy(np.array(array)).long().to(device)
261+
else:
262+
array_dict[key] = torch.from_numpy(np.array(array).transpose(1, 0)).long().to(device)
263+
if array_dict['elements2tokens'].numel() > 0:
264+
array_dict['slice_idxs'] = torch.from_numpy(np.array([
265+
array_dict['elements2tokens'][1].min().item(),
266+
array_dict['elements2tokens'][1].max().item() + 1
267+
])).long().to(device)
268+
return array_dict

0 commit comments

Comments
 (0)