Skip to content

Commit 9f156f0

Browse files
zacharesvahanhov
authored andcommitted
feat(causal message passing) (#2)
* novelty debugging * running solution * message passing slightly better * simplified serialize * current code * flamingo inspired * message passing correctly implemented * positions update * removing commented code * causal message passing * edge case in case using another model besides serialize * update message passing and position embedding * Update src/transformers/models/bloom/modeling_bloom.py * removed unnecessary code
1 parent 591cf9c commit 9f156f0

File tree

4 files changed

+232
-213
lines changed

4 files changed

+232
-213
lines changed
Lines changed: 106 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
""" A set of functions to perform message passing on a serialized graph in an LLM """
22

3+
import enum
34
from collections import defaultdict
45
import itertools
5-
from typing import Callable, Dict, List, Optional, Tuple
6+
from typing import Dict, List, Optional, Tuple, Union
67

78
import numpy as np
89
import torch
9-
from torch_scatter import scatter
10+
import torch_geometric
1011

1112
from .desequence_graph_ids import SequenceElement
1213

1314

15+
class GNNLayerFactory(enum.Enum):
16+
gcn = torch_geometric.nn.GCNConv
17+
sage = torch_geometric.nn.SAGEConv
18+
gat = torch_geometric.nn.GATConv
19+
20+
1421
def build_message_passing_matrices(
1522
token_ids: torch.Tensor,
1623
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]]
@@ -19,114 +26,114 @@ def build_message_passing_matrices(
1926
language model blocks of an autoregressive language model
2027
"""
2128
message_passing_dicts = []
22-
for t_ids, edge_sequence in zip(token_ids, edge_sequences):
23-
message_passing_dict = {'tokens2edges': [], 'edges2tokens': [], 'inverse_edge_index': []}
29+
for edge_sequence in edge_sequences:
30+
message_passing_dict = defaultdict(list)
2431
node2edge_idxs = defaultdict(list)
32+
prev_node_idx = defaultdict(lambda: -1)
33+
34+
def add_element(end_idx: int, element_type: str):
35+
""" Adds an element to the edge or node graphs used for message passing """
36+
assert element_type in ['nodes', 'edges']
37+
message_passing_dict[f"tokens2{element_type}"].append(end_idx - 1)
38+
message_passing_dict[f"{element_type}2tokens"].append(end_idx)
39+
2540
for edge_idx, sequenced_edge in enumerate(edge_sequence):
2641
pred_node, edge, succ_node = sequenced_edge
27-
node2edge_idxs[pred_node.ids].append(edge_idx)
28-
if isinstance(succ_node, SequenceElement):
29-
end_idx = succ_node.end_idx
30-
node2edge_idxs[succ_node.ids].append(edge_idx)
31-
elif isinstance(edge, SequenceElement):
32-
end_idx = edge.end_idx
42+
if edge_idx == len(edge_sequence) - 1:
43+
if (
44+
not isinstance(succ_node, SequenceElement)
45+
and not isinstance(edge, SequenceElement)
46+
):
47+
continue
48+
else:
49+
add_element(pred_node.end_idx, 'nodes')
50+
num_nodes = len(message_passing_dict["tokens2nodes"])
51+
if prev_node_idx[pred_node.ids] != -1:
52+
message_passing_dict['edge_index_nodes'].append(
53+
[prev_node_idx[pred_node.ids], num_nodes - 1]
54+
)
3355
else:
34-
end_idx = pred_node.end_idx
35-
for token_idx in range(pred_node.start_idx, end_idx):
36-
message_passing_dict['tokens2edges'].append([token_idx, edge_idx])
37-
message_passing_dict['edges2tokens'].append([edge_idx, token_idx])
38-
if len(message_passing_dict['edges2tokens']) > 0:
39-
message_passing_dict['edges2tokens'] = add_missing_idxs(
40-
message_passing_dict['edges2tokens'],
41-
num_incoming_nodes=len(edge_sequence),
42-
num_outgoing_nodes=len(t_ids)
43-
)
44-
message_passing_dict['inverse_edge_index'] = []
56+
add_element(pred_node.end_idx, 'nodes')
57+
add_element(succ_node.end_idx, 'edges')
58+
add_element(succ_node.end_idx, 'nodes')
59+
node2edge_idxs[pred_node.ids].append(edge_idx)
60+
node2edge_idxs[succ_node.ids].append(edge_idx)
61+
num_nodes = len(message_passing_dict["tokens2nodes"])
62+
message_passing_dict['edge_index_nodes'].append([num_nodes - 2, num_nodes - 1])
63+
if prev_node_idx[pred_node.ids] != -1:
64+
message_passing_dict['edge_index_nodes'].append(
65+
[prev_node_idx[pred_node.ids], num_nodes - 2]
66+
)
67+
if prev_node_idx[succ_node.ids] != -1:
68+
message_passing_dict['edge_index_nodes'].append(
69+
[prev_node_idx[succ_node.ids], num_nodes - 1]
70+
)
71+
prev_node_idx[pred_node.ids] = num_nodes - 2
72+
prev_node_idx[succ_node.ids] = num_nodes - 1
73+
4574
for edge_idxs in node2edge_idxs.values():
4675
if len(edge_idxs) < 2:
4776
continue
4877
for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
49-
message_passing_dict['inverse_edge_index'].append(
50-
[idx0, idx1] if idx0 < idx1 else [idx1, idx0]
51-
)
52-
if len(message_passing_dict['inverse_edge_index']) > 0:
53-
message_passing_dict['inverse_edge_index'] = add_missing_idxs(
54-
message_passing_dict['inverse_edge_index'],
55-
num_incoming_nodes=len(edge_sequence),
56-
num_outgoing_nodes=len(edge_sequence)
57-
)
58-
message_passing_dicts.append({
59-
key: torch.from_numpy(np.array(value).transpose(1, 0)).long().to(token_ids.device)
60-
if len(value) > 0 else torch.from_numpy(np.array(value)).long().to(token_ids.device)
61-
for key, value in message_passing_dict.items()
62-
})
63-
return message_passing_dicts
78+
message_passing_dict['edge_index_edges'].append(sorted([idx0, idx1]))
6479

80+
def to_torch(array: Union[List[int], List[List[int]]]) -> torch.Tensor:
81+
""" Converts an array to a torch Tensor and returns it"""
82+
if len(array) == 0 or isinstance(array[0], int):
83+
return torch.from_numpy(np.array(array)).long().to(token_ids.device)
84+
else:
85+
return torch.from_numpy(np.array(array).transpose(1, 0)).long().to(token_ids.device)
6586

66-
def add_missing_idxs(
67-
edge_index: List[List[int]],
68-
*,
69-
num_incoming_nodes: int,
70-
num_outgoing_nodes: int
71-
) -> List[List[int]]:
72-
""" Adds edges from a dummy node to all outgoing nodes which do not have an edge pointing to
73-
them. This is to facilitate causal message passing where a node should have access to its
74-
own embedding using torch.scatter to perform message passing.
75-
"""
76-
existing_idxs = set([node_idxs[-1] for node_idxs in edge_index])
77-
missing_idxs = set(range(num_outgoing_nodes)) - existing_idxs
78-
for missing_idx in missing_idxs:
79-
edge_index.append([num_incoming_nodes, missing_idx])
80-
return edge_index
87+
message_passing_dict['tokens2edges'] = to_torch(message_passing_dict['tokens2edges'])
88+
message_passing_dict['edges2tokens'] = to_torch(message_passing_dict['edges2tokens'])
89+
message_passing_dict['tokens2nodes'] = to_torch(message_passing_dict['tokens2nodes'])
90+
message_passing_dict['nodes2tokens'] = to_torch(message_passing_dict['nodes2tokens'])
91+
message_passing_dict['edge_index_nodes'] = to_torch(message_passing_dict['edge_index_nodes'])
92+
message_passing_dict['edge_index_edges'] = to_torch(message_passing_dict['edge_index_edges'])
93+
message_passing_dicts.append(dict(message_passing_dict))
94+
return message_passing_dicts
8195

8296

83-
def perform_causal_message_passing(
84-
token_embeddings: torch.Tensor,
85-
message_passing_dicts: List[Dict[str, torch.Tensor]],
86-
linear_layer: Optional[Callable] = None,
87-
reduce: str = 'mean'
88-
) -> torch.Tensor:
89-
""" Returns token embeddings in a sequence where causal message passing has been performed on
90-
the token ids based on the serialized graph described in the sequence
97+
class CausalMessagePassingLayer(torch.nn.Module):
98+
""" A torch.nn.Module for performing causal message passing within an autoregressive
99+
language model
91100
"""
92-
new_token_embeddings = []
93-
for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
94-
if message_passing_dict['inverse_edge_index'].numel() == 0:
95-
new_t_embeddings = t_embeddings
96-
else:
97-
edge_embeddings = scatter(
98-
src=t_embeddings[message_passing_dict['tokens2edges'][0]],
99-
dim=0,
100-
index=message_passing_dict['tokens2edges'][1],
101-
reduce=reduce
102-
)
103-
# adding dummy tensor to make sure that the output tensor of message passing is the
104-
# correct size because causal message passing does not allow self loops
105-
edge_embeddings = torch.cat([
106-
edge_embeddings,
107-
torch.zeros_like(edge_embeddings[0].unsqueeze(0))
108-
], dim=0)
109-
# adding dummy tensor to make sure that the output tensor of message passing is the
110-
# correct size because causal message passing does not allow self loops
111-
edge_embeddings = scatter(
112-
src=edge_embeddings[message_passing_dict['inverse_edge_index'][0]],
113-
dim=0,
114-
index=message_passing_dict['inverse_edge_index'][1],
115-
reduce=reduce
116-
)
117-
if linear_layer is not None:
118-
edge_embeddings = linear_layer(edge_embeddings)
119-
edge_embeddings = torch.relu(edge_embeddings)
120-
edge_embeddings = torch.cat([
121-
edge_embeddings,
122-
torch.zeros_like(edge_embeddings[0].unsqueeze(0))
123-
], dim=0)
124-
new_t_embeddings = scatter(
125-
src=edge_embeddings[message_passing_dict['edges2tokens'][0]],
126-
dim=0,
127-
index=message_passing_dict['edges2tokens'][1],
128-
reduce=reduce
101+
def __init__(self, gnn_type: str, embedding_size: int):
102+
super().__init__()
103+
self.nodes_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
104+
self.edges_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
105+
self.gating_parameter_a = torch.nn.Parameter(torch.zeros(1))
106+
self.gating_parameter_b = torch.nn.Parameter(torch.zeros(1))
107+
108+
def forward(
109+
self,
110+
token_embeddings: torch.Tensor,
111+
message_passing_dicts: List[Dict[str, torch.Tensor]]
112+
) -> torch.Tensor:
113+
new_token_embeddings = []
114+
for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
115+
token_edges_embeddings = torch.zeros_like(t_embeddings)
116+
token_nodes_embeddings = torch.zeros_like(t_embeddings)
117+
if message_passing_dict['tokens2edges'].numel() > 0:
118+
edges_embeddings = t_embeddings[message_passing_dict['tokens2edges']]
119+
if message_passing_dict['edge_index_edges'].numel() > 0:
120+
edges_embeddings = self.edges_layer(
121+
edges_embeddings,
122+
message_passing_dict['edge_index_edges']
123+
)
124+
token_edges_embeddings[message_passing_dict['edges2tokens']] = edges_embeddings
125+
if message_passing_dict['tokens2nodes'].numel() > 0:
126+
nodes_embeddings = t_embeddings[message_passing_dict['tokens2nodes']]
127+
if message_passing_dict['edge_index_nodes'].numel() > 0:
128+
nodes_embeddings = self.nodes_layer(
129+
nodes_embeddings,
130+
message_passing_dict['edge_index_nodes']
131+
)
132+
token_nodes_embeddings[message_passing_dict['nodes2tokens']] = nodes_embeddings
133+
new_t_embeddings = (
134+
t_embeddings
135+
+ torch.tanh(self.gating_parameter_a) * token_edges_embeddings
136+
+ torch.tanh(self.gating_parameter_b) * token_nodes_embeddings
129137
)
130-
assert new_t_embeddings.shape == t_embeddings.shape
131-
new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
132-
return torch.cat(new_token_embeddings, dim=0) + token_embeddings
138+
new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
139+
return torch.cat(new_token_embeddings, dim=0)

src/transformers/models/bloom/desequence_graph_ids.py

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
""" A set of functions to identify a serialized graph within a list of token ids """
22

3-
from collections import defaultdict
43
from typing import Dict, List, Optional, Tuple
54
from dataclasses import dataclass
65

@@ -25,56 +24,21 @@ def extract_edge_sequence(
2524
sequence = _extract_graph_elements(token_ids, graph_tokens)
2625
edges = []
2726
if len(sequence) > 2:
28-
node_explored = defaultdict(lambda: False)
2927
for elem0, elem1, elem2 in zip(sequence[:-2], sequence[1:-1], sequence[2:]):
3028
if (
31-
elem0.token in graph_tokens['nodes']
32-
and elem1.token in graph_tokens['edge']
33-
and elem2.token in graph_tokens['nodes']
29+
elem0.token == graph_tokens['gen_edge']
30+
and elem1.token == graph_tokens['edge']
31+
and elem2.token == graph_tokens['node']
3432
): # edge syntax
35-
if (
36-
# test to see if there is an ungenerated node contained within the identified edge
37-
# essentially another syntax error that can occur during graph generation
38-
(elem0.token in graph_tokens['node'] and not node_explored[elem0.ids])
39-
or (
40-
elem2.token in graph_tokens['node']
41-
and not node_explored[elem2.ids]
42-
and elem2.ids != elem0.ids # to account for self loops
43-
)
44-
):
45-
continue
46-
if elem0.token in graph_tokens['gen_node'] and not node_explored[elem0.ids]:
47-
node_explored[elem0.ids] = True
48-
if elem2.token in graph_tokens['gen_node'] and not node_explored[elem2.ids]:
49-
node_explored[elem2.ids] = True
5033
edges.append((elem0, elem1, elem2))
5134
if (
52-
(
53-
len(edges) > 0
54-
and edges[-1][0] != sequence[-3]
55-
and edges[-1][1] != sequence[-2]
56-
and edges[-1][2] != sequence[-1]
57-
)
58-
or len(edges) == 0
35+
len(sequence) > 1
36+
and sequence[-2].token == graph_tokens['gen_edge']
37+
and sequence[-1].token == graph_tokens['edge']
5938
):
60-
if (
61-
len(sequence) > 1
62-
and sequence[-2].token in graph_tokens['nodes']
63-
and sequence[-1].token in graph_tokens['edge']
64-
and not (
65-
sequence[-2].token in graph_tokens['node']
66-
and not node_explored[sequence[-2].ids]
67-
)
68-
):
69-
edges.append((sequence[-2], sequence[-1], None))
70-
elif (
71-
len(sequence) > 0 and sequence[-1].token in graph_tokens['nodes']
72-
and not (
73-
sequence[-1].token in graph_tokens['node']
74-
and not node_explored[sequence[-1].ids]
75-
)
76-
):
77-
edges.append((sequence[-1], None, None))
39+
edges.append((sequence[-2], sequence[-1], None))
40+
elif len(sequence) > 0 and sequence[-1].token == graph_tokens['gen_edge']:
41+
edges.append((sequence[-1], None, None))
7842
return edges
7943

8044

@@ -85,12 +49,17 @@ def _extract_graph_elements(
8549
""" Returns a parsable representation of the serialized graph in a sequence of token ids,
8650
if none is found, returns an empty list
8751
"""
52+
if len(graph_tokens) == 0:
53+
return []
8854
sequence = []
8955
prev_token_id, prev_idx, final_idx = None, -1, len(token_ids)
9056
for token_idx, token_id in enumerate(token_ids):
91-
if token_id in graph_tokens['gen_node'] and prev_token_id is None:
57+
if token_id == graph_tokens['gen_edge'] and prev_token_id is None:
9258
prev_token_id, prev_idx = token_id, token_idx
93-
elif token_id in graph_tokens['graph'] and prev_token_id is not None:
59+
elif (
60+
token_id in [graph_tokens['gen_edge'], graph_tokens['edge'], graph_tokens['node']]
61+
and prev_token_id is not None
62+
):
9463
sequence.append(SequenceElement(
9564
token=prev_token_id,
9665
start_idx=prev_idx,
@@ -99,7 +68,7 @@ def _extract_graph_elements(
9968
length=token_idx - prev_idx
10069
))
10170
prev_token_id, prev_idx = token_id, token_idx
102-
elif token_id in graph_tokens['eos'] and prev_token_id is not None:
71+
elif token_id in [graph_tokens['eos'], graph_tokens['pad']] and prev_token_id is not None:
10372
final_idx = token_idx
10473
break
10574
if prev_token_id is not None:

0 commit comments

Comments
 (0)