Skip to content
Merged
205 changes: 106 additions & 99 deletions src/transformers/models/bloom/causal_message_passing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
""" A set of functions to perform message passing on a serialized graph in an LLM """

import enum
from collections import defaultdict
import itertools
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch_scatter import scatter
import torch_geometric

from .desequence_graph_ids import SequenceElement


class GNNLayerFactory(enum.Enum):
    """ Maps a GNN type name to the corresponding torch_geometric convolution class.

    Members hold the layer *class* (not an instance); construct a layer with
    ``GNNLayerFactory[name].value(in_channels, out_channels)``.
    """
    gcn = torch_geometric.nn.GCNConv
    sage = torch_geometric.nn.SAGEConv
    gat = torch_geometric.nn.GATConv

def build_message_passing_matrices(
token_ids: torch.Tensor,
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]]
Expand All @@ -19,114 +26,114 @@ def build_message_passing_matrices(
language model blocks of an autoregressive language model
"""
message_passing_dicts = []
for t_ids, edge_sequence in zip(token_ids, edge_sequences):
message_passing_dict = {'tokens2edges': [], 'edges2tokens': [], 'inverse_edge_index': []}
for edge_sequence in edge_sequences:
message_passing_dict = defaultdict(list)
node2edge_idxs = defaultdict(list)
prev_node_idx = defaultdict(lambda: -1)

def add_element(end_idx: int, element_type: str):
    """ Registers a completed node or edge element in the message passing mappings.

    The element is summarized by its final token (index ``end_idx - 1``) and its
    aggregated message is delivered back to the token at ``end_idx``.
    """
    assert element_type in ['nodes', 'edges']
    message_passing_dict['tokens2' + element_type].append(end_idx - 1)
    message_passing_dict[element_type + '2tokens'].append(end_idx)

for edge_idx, sequenced_edge in enumerate(edge_sequence):
pred_node, edge, succ_node = sequenced_edge
node2edge_idxs[pred_node.ids].append(edge_idx)
if isinstance(succ_node, SequenceElement):
end_idx = succ_node.end_idx
node2edge_idxs[succ_node.ids].append(edge_idx)
elif isinstance(edge, SequenceElement):
end_idx = edge.end_idx
if edge_idx == len(edge_sequence) - 1:
if (
not isinstance(succ_node, SequenceElement)
and not isinstance(edge, SequenceElement)
):
continue
else:
add_element(pred_node.end_idx, 'nodes')
num_nodes = len(message_passing_dict["tokens2nodes"])
if prev_node_idx[pred_node.ids] != -1:
message_passing_dict['edge_index_nodes'].append(
[prev_node_idx[pred_node.ids], num_nodes - 1]
)
else:
end_idx = pred_node.end_idx
for token_idx in range(pred_node.start_idx, end_idx):
message_passing_dict['tokens2edges'].append([token_idx, edge_idx])
message_passing_dict['edges2tokens'].append([edge_idx, token_idx])
if len(message_passing_dict['edges2tokens']) > 0:
message_passing_dict['edges2tokens'] = add_missing_idxs(
message_passing_dict['edges2tokens'],
num_incoming_nodes=len(edge_sequence),
num_outgoing_nodes=len(t_ids)
)
message_passing_dict['inverse_edge_index'] = []
add_element(pred_node.end_idx, 'nodes')
add_element(succ_node.end_idx, 'edges')
add_element(succ_node.end_idx, 'nodes')
node2edge_idxs[pred_node.ids].append(edge_idx)
node2edge_idxs[succ_node.ids].append(edge_idx)
num_nodes = len(message_passing_dict["tokens2nodes"])
message_passing_dict['edge_index_nodes'].append([num_nodes - 2, num_nodes - 1])
if prev_node_idx[pred_node.ids] != -1:
message_passing_dict['edge_index_nodes'].append(
[prev_node_idx[pred_node.ids], num_nodes - 2]
)
if prev_node_idx[succ_node.ids] != -1:
message_passing_dict['edge_index_nodes'].append(
[prev_node_idx[succ_node.ids], num_nodes - 1]
)
prev_node_idx[pred_node.ids] = num_nodes - 2
prev_node_idx[succ_node.ids] = num_nodes - 1

for edge_idxs in node2edge_idxs.values():
if len(edge_idxs) < 2:
continue
for (idx0, idx1) in itertools.combinations(list(set(edge_idxs)), 2):
message_passing_dict['inverse_edge_index'].append(
[idx0, idx1] if idx0 < idx1 else [idx1, idx0]
)
if len(message_passing_dict['inverse_edge_index']) > 0:
message_passing_dict['inverse_edge_index'] = add_missing_idxs(
message_passing_dict['inverse_edge_index'],
num_incoming_nodes=len(edge_sequence),
num_outgoing_nodes=len(edge_sequence)
)
message_passing_dicts.append({
key: torch.from_numpy(np.array(value).transpose(1, 0)).long().to(token_ids.device)
if len(value) > 0 else torch.from_numpy(np.array(value)).long().to(token_ids.device)
for key, value in message_passing_dict.items()
})
return message_passing_dicts
message_passing_dict['edge_index_edges'].append(sorted([idx0, idx1]))

def to_torch(array: Union[List[int], List[List[int]]]) -> torch.Tensor:
    """ Returns the given index list as a long tensor on the same device as token_ids. """
    as_numpy = np.array(array)
    if len(array) > 0 and not isinstance(array[0], int):
        # pair lists arrive as rows of [incoming, outgoing]; transpose them into
        # the (2, num_pairs) edge-index layout expected downstream
        as_numpy = as_numpy.transpose(1, 0)
    return torch.from_numpy(as_numpy).long().to(token_ids.device)

def add_missing_idxs(
    edge_index: List[List[int]],
    *,
    num_incoming_nodes: int,
    num_outgoing_nodes: int
) -> List[List[int]]:
    """ Appends dummy edges so that every outgoing node receives at least one edge.

    Causal message passing with torch.scatter only produces an output row for
    nodes that appear as an edge target, and self loops are not allowed; any
    outgoing node without an incoming edge therefore gets an edge from a dummy
    node at index ``num_incoming_nodes``. Mutates and returns ``edge_index``.
    """
    covered_targets = {pair[-1] for pair in edge_index}
    for target in set(range(num_outgoing_nodes)) - covered_targets:
        edge_index.append([num_incoming_nodes, target])
    return edge_index
message_passing_dict['tokens2edges'] = to_torch(message_passing_dict['tokens2edges'])
message_passing_dict['edges2tokens'] = to_torch(message_passing_dict['edges2tokens'])
message_passing_dict['tokens2nodes'] = to_torch(message_passing_dict['tokens2nodes'])
message_passing_dict['nodes2tokens'] = to_torch(message_passing_dict['nodes2tokens'])
message_passing_dict['edge_index_nodes'] = to_torch(message_passing_dict['edge_index_nodes'])
message_passing_dict['edge_index_edges'] = to_torch(message_passing_dict['edge_index_edges'])
message_passing_dicts.append(dict(message_passing_dict))
return message_passing_dicts


def perform_causal_message_passing(
token_embeddings: torch.Tensor,
message_passing_dicts: List[Dict[str, torch.Tensor]],
linear_layer: Optional[Callable] = None,
reduce: str = 'mean'
) -> torch.Tensor:
""" Returns token embeddings in a sequence where causal message passing has been performed on
the token ids based on the serialized graph described in the sequence
class CausalMessagePassingLayer(torch.nn.Module):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually we should add a link to the paper here, because it's not at all trivial what's happening here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

""" A torch.nn.Module for performing causal message passing within an autoregressive
language model
"""
new_token_embeddings = []
for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
if message_passing_dict['inverse_edge_index'].numel() == 0:
new_t_embeddings = t_embeddings
else:
edge_embeddings = scatter(
src=t_embeddings[message_passing_dict['tokens2edges'][0]],
dim=0,
index=message_passing_dict['tokens2edges'][1],
reduce=reduce
)
# adding dummy tensor to make sure that the output tensor of message passing is the
# correct size because causal message passing does not allow self loops
edge_embeddings = torch.cat([
edge_embeddings,
torch.zeros_like(edge_embeddings[0].unsqueeze(0))
], dim=0)
# adding dummy tensor to make sure that the output tensor of message passing is the
# correct size because causal message passing does not allow self loops
edge_embeddings = scatter(
src=edge_embeddings[message_passing_dict['inverse_edge_index'][0]],
dim=0,
index=message_passing_dict['inverse_edge_index'][1],
reduce=reduce
)
if linear_layer is not None:
edge_embeddings = linear_layer(edge_embeddings)
edge_embeddings = torch.relu(edge_embeddings)
edge_embeddings = torch.cat([
edge_embeddings,
torch.zeros_like(edge_embeddings[0].unsqueeze(0))
], dim=0)
new_t_embeddings = scatter(
src=edge_embeddings[message_passing_dict['edges2tokens'][0]],
dim=0,
index=message_passing_dict['edges2tokens'][1],
reduce=reduce
def __init__(self, gnn_type: str, embedding_size: int):
    """ Builds the node- and edge-graph GNN layers plus learnable residual gates.

    Args:
        gnn_type: key into GNNLayerFactory ('gcn', 'sage' or 'gat')
        embedding_size: token embedding dimensionality (GNN input size == output size)
    """
    super().__init__()
    self.nodes_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
    self.edges_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
    # gates start at zero, so tanh(gate) == 0 and the module is initially an
    # identity map over the token embeddings
    self.gating_parameter_a = torch.nn.Parameter(torch.zeros(1))
    self.gating_parameter_b = torch.nn.Parameter(torch.zeros(1))

def forward(
    self,
    token_embeddings: torch.Tensor,
    message_passing_dicts: List[Dict[str, torch.Tensor]]
) -> torch.Tensor:
    """ Performs gated causal message passing on each sequence in the batch.

    For every sequence, edge-graph and node-graph messages are computed from the
    token embeddings (via the index maps built in build_message_passing_matrices),
    scattered back onto token positions, and added to the original embeddings
    through tanh-gated residual connections.

    Args:
        token_embeddings: batch of token embeddings; iterated per sequence
            (assumes shape (batch, seq_len, embedding_size) -- TODO confirm)
        message_passing_dicts: one dict of index tensors per sequence

    Returns:
        Tensor with the same shape as ``token_embeddings``.
    """
    # NOTE(review): removed stale pre-refactor residue that duplicated the
    # `new_token_embeddings.append(...)` call and added `+ token_embeddings` a
    # second time on return, which would double-apply the residual connection.
    new_token_embeddings = []
    for t_embeddings, message_passing_dict in zip(token_embeddings, message_passing_dicts):
        token_edges_embeddings = torch.zeros_like(t_embeddings)
        token_nodes_embeddings = torch.zeros_like(t_embeddings)
        if message_passing_dict['tokens2edges'].numel() > 0:
            edges_embeddings = t_embeddings[message_passing_dict['tokens2edges']]
            if message_passing_dict['edge_index_edges'].numel() > 0:
                edges_embeddings = self.edges_layer(
                    edges_embeddings,
                    message_passing_dict['edge_index_edges']
                )
            token_edges_embeddings[message_passing_dict['edges2tokens']] = edges_embeddings
        if message_passing_dict['tokens2nodes'].numel() > 0:
            nodes_embeddings = t_embeddings[message_passing_dict['tokens2nodes']]
            if message_passing_dict['edge_index_nodes'].numel() > 0:
                nodes_embeddings = self.nodes_layer(
                    nodes_embeddings,
                    message_passing_dict['edge_index_nodes']
                )
            token_nodes_embeddings[message_passing_dict['nodes2tokens']] = nodes_embeddings
        # zero-initialized gates make this an identity map at the start of training
        new_t_embeddings = (
            t_embeddings
            + torch.tanh(self.gating_parameter_a) * token_edges_embeddings
            + torch.tanh(self.gating_parameter_b) * token_nodes_embeddings
        )
        assert new_t_embeddings.shape == t_embeddings.shape
        new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
    return torch.cat(new_token_embeddings, dim=0)
65 changes: 17 additions & 48 deletions src/transformers/models/bloom/desequence_graph_ids.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" A set of functions to identify a serialized graph within a list of token ids """

from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

Expand All @@ -25,56 +24,21 @@ def extract_edge_sequence(
sequence = _extract_graph_elements(token_ids, graph_tokens)
edges = []
if len(sequence) > 2:
node_explored = defaultdict(lambda: False)
for elem0, elem1, elem2 in zip(sequence[:-2], sequence[1:-1], sequence[2:]):
if (
elem0.token in graph_tokens['nodes']
and elem1.token in graph_tokens['edge']
and elem2.token in graph_tokens['nodes']
elem0.token == graph_tokens['gen_edge']
and elem1.token == graph_tokens['edge']
and elem2.token == graph_tokens['node']
): # edge syntax
if (
# test to see if there is an ungenerated node contained within the identified edge
# essentially another syntax error that can occur during graph generation
(elem0.token in graph_tokens['node'] and not node_explored[elem0.ids])
or (
elem2.token in graph_tokens['node']
and not node_explored[elem2.ids]
and elem2.ids != elem0.ids # to account for self loops
)
):
continue
if elem0.token in graph_tokens['gen_node'] and not node_explored[elem0.ids]:
node_explored[elem0.ids] = True
if elem2.token in graph_tokens['gen_node'] and not node_explored[elem2.ids]:
node_explored[elem2.ids] = True
edges.append((elem0, elem1, elem2))
if (
(
len(edges) > 0
and edges[-1][0] != sequence[-3]
and edges[-1][1] != sequence[-2]
and edges[-1][2] != sequence[-1]
)
or len(edges) == 0
len(sequence) > 1
and sequence[-2].token == graph_tokens['gen_edge']
and sequence[-1].token == graph_tokens['edge']
):
if (
len(sequence) > 1
and sequence[-2].token in graph_tokens['nodes']
and sequence[-1].token in graph_tokens['edge']
and not (
sequence[-2].token in graph_tokens['node']
and not node_explored[sequence[-2].ids]
)
):
edges.append((sequence[-2], sequence[-1], None))
elif (
len(sequence) > 0 and sequence[-1].token in graph_tokens['nodes']
and not (
sequence[-1].token in graph_tokens['node']
and not node_explored[sequence[-1].ids]
)
):
edges.append((sequence[-1], None, None))
edges.append((sequence[-2], sequence[-1], None))
elif len(sequence) > 0 and sequence[-1].token == graph_tokens['gen_edge']:
edges.append((sequence[-1], None, None))
return edges


Expand All @@ -85,12 +49,17 @@ def _extract_graph_elements(
""" Returns a parsable representation of the serialized graph in a sequence of token ids,
if none is found, returns an empty list
"""
if len(graph_tokens) == 0:
return []
sequence = []
prev_token_id, prev_idx, final_idx = None, -1, len(token_ids)
for token_idx, token_id in enumerate(token_ids):
if token_id in graph_tokens['gen_node'] and prev_token_id is None:
if token_id == graph_tokens['gen_edge'] and prev_token_id is None:
prev_token_id, prev_idx = token_id, token_idx
elif token_id in graph_tokens['graph'] and prev_token_id is not None:
elif (
token_id in [graph_tokens['gen_edge'], graph_tokens['edge'], graph_tokens['node']]
and prev_token_id is not None
):
sequence.append(SequenceElement(
token=prev_token_id,
start_idx=prev_idx,
Expand All @@ -99,7 +68,7 @@ def _extract_graph_elements(
length=token_idx - prev_idx
))
prev_token_id, prev_idx = token_id, token_idx
elif token_id in graph_tokens['eos'] and prev_token_id is not None:
elif token_id in [graph_tokens['eos'], graph_tokens['pad']] and prev_token_id is not None:
final_idx = token_idx
break
if prev_token_id is not None:
Expand Down
Loading