diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 29aa95a9815cac..fdb03f7650bb24 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -682,7 +682,7 @@ def init_graph_information_passing( """ Initializes a set of message passing layers to perform message passing of between graph elements described in an input token id sequence """ - assert element_type in ['nodes', 'edges'], 'unsupported message passing type' + assert element_type in ['node_correspondence', 'edge'], 'unsupported message passing type' self.message_passing_type = element_type self.graph_token_ids = graph_token_ids self.num_gnn_layers = ( @@ -781,7 +781,7 @@ def forward( extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids ] if self.message_passing_type == 'nodes': - get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing + get_matrices = GatedCausalMessagePassingLayer.build_node_correspondence_information_passing else: get_matrices = GatedCausalMessagePassingLayer.build_edge_information_passing message_passing_dicts = get_matrices(edge_sequences, self.device) diff --git a/src/transformers/models/processing_graphs_within_model/causal_message_passing.py b/src/transformers/models/processing_graphs_within_model/causal_message_passing.py index 3c5b6bd1a22d13..50e6e30a58a3b5 100644 --- a/src/transformers/models/processing_graphs_within_model/causal_message_passing.py +++ b/src/transformers/models/processing_graphs_within_model/causal_message_passing.py @@ -52,8 +52,7 @@ def forward( element_embeddings = t_embeddings[message_passing_dict['tokens2elements']] if message_passing_dict['edge_index'].numel() > 0: element_embeddings = self.gnn_layer( - element_embeddings, - message_passing_dict['edge_index'] + element_embeddings, message_passing_dict['edge_index'] ) new_t_embeddings[message_passing_dict['elements2tokens']] = element_embeddings new_t_embeddings = t_embeddings + torch.tanh(self.gating_message_passing) * new_t_embeddings @@ -61,7 +60,7 @@ def forward( return torch.cat(new_token_embeddings, dim=0) @classmethod - def build_node_information_passing( + def build_node_correspondence_information_passing( cls, edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]], device: torch.device @@ -77,7 +76,7 @@ def build_node_information_passing( message_passing_dicts.append(cls.to_torch(dict(message_passing_dict), device)) continue add_node = partial( - cls.add_node, + cls.add_node_correspondence, end_idx=cls.get_sequence_end(edge_sequence), last_occurence_idx=defaultdict(lambda: -1), message_passing_dict=message_passing_dict @@ -149,7 +148,7 @@ def get_sequence_end( return end_idx @classmethod - def add_node( + def add_node_correspondence( cls, current_occurence: SequenceElement, end_idx: int, @@ -172,12 +171,13 @@ def add_node( message_passing_dict=message_passing_dict ) curr_length = len(message_passing_dict[f"tokens2elements"]) - if last_occurence_idx[current_occurence.ids] != -1 and curr_length > prev_length: + full_ids = current_occurence.graph_id + ("--node_ids--",) + current_occurence.ids + if last_occurence_idx[full_ids] != -1 and curr_length > prev_length: current_idx = len(message_passing_dict["tokens2elements"]) - 1 message_passing_dict['edge_index'].append( - [last_occurence_idx[current_occurence.ids], current_idx] + [last_occurence_idx[full_ids], current_idx] ) - last_occurence_idx[current_occurence.ids] = current_idx + last_occurence_idx[full_ids] = current_idx @classmethod def add_edge( @@ -189,17 +189,19 @@ def add_edge( ): """ Adds an edge as element to pass information between in a serialized graph """ pred_node, _, succ_node = sequenced_edge - prev_length = len(message_passing_dict[f"tokens2elements"]) + prev_length = len(message_passing_dict["tokens2elements"]) cls.add_element_for_information_passing( start_idx=succ_node.end_idx, end_idx=end_idx, message_passing_dict=message_passing_dict ) - curr_length = len(message_passing_dict[f"tokens2elements"]) + curr_length = len(message_passing_dict["tokens2elements"]) if curr_length > prev_length: current_idx = len(message_passing_dict["tokens2elements"]) - 1 - node2edge_idxs[pred_node.ids].append(current_idx) - node2edge_idxs[succ_node.ids].append(current_idx) + pred_ids = pred_node.graph_id + ("--node_ids--",) + pred_node.ids + node2edge_idxs[pred_ids].append(current_idx) + succ_ids = succ_node.graph_id + ("--node_ids--",) + succ_node.ids + node2edge_idxs[succ_ids].append(current_idx) @staticmethod def add_element_for_information_passing( diff --git a/src/transformers/models/processing_graphs_within_model/desequence_graph_ids.py b/src/transformers/models/processing_graphs_within_model/desequence_graph_ids.py index 78c57f2909cef2..376073998ad9b3 100644 --- a/src/transformers/models/processing_graphs_within_model/desequence_graph_ids.py +++ b/src/transformers/models/processing_graphs_within_model/desequence_graph_ids.py @@ -14,6 +14,7 @@ class SequenceElement: end_idx: int ids: Tuple[int] length: int + graph_id: Tuple[int] def extract_edge_sequence( @@ -50,31 +51,65 @@ def _extract_graph_elements( if none is found, returns an empty list """ sequence = [] - prev_token_id, prev_idx, final_idx = None, -1, len(token_ids) + sog_idx, graph_id = None, None + prev_token_id, prev_idx, final_idx = None, -1, None for token_idx, token_id in enumerate(token_ids): - if token_id == graph_tokens['pred_node'] and prev_token_id is None: + if ( + token_id == graph_tokens['sog'] + and prev_token_id is None + and sog_idx is None + ): + sog_idx = token_idx + elif ( + token_id == graph_tokens['pred_node'] + and prev_token_id is None + and sog_idx is not None + ): + graph_id = tuple(token_ids[sog_idx:token_idx])[1:] prev_token_id, prev_idx = token_id, token_idx + elif ( + token_id == graph_tokens['eog'] + and prev_token_id is not None + and graph_id is not None + ): + sequence.append(SequenceElement( + token=prev_token_id, + start_idx=prev_idx, + end_idx=token_idx, + ids=tuple(token_ids[prev_idx:token_idx])[1:], + length=token_idx - prev_idx, + graph_id=graph_id + )) + sog_idx, graph_id = None, None + prev_token_id, prev_idx, final_idx = None, -1, len(token_ids) elif ( token_id in [graph_tokens['pred_node'], graph_tokens['edge'], graph_tokens['succ_node']] and prev_token_id is not None + and graph_id is not None ): sequence.append(SequenceElement( token=prev_token_id, start_idx=prev_idx, end_idx=token_idx, ids=tuple(token_ids[prev_idx:token_idx])[1:], - length=token_idx - prev_idx + length=token_idx - prev_idx, + graph_id=graph_id )) prev_token_id, prev_idx = token_id, token_idx - elif token_id in [graph_tokens['eos'], graph_tokens['pad']] and prev_token_id is not None: + elif ( + token_id in [graph_tokens['eos'], graph_tokens['pad']] + and prev_token_id is not None + and graph_id is not None + ): final_idx = token_idx break - if prev_token_id is not None: + if final_idx is not None: sequence.append(SequenceElement( token=prev_token_id, start_idx=prev_idx, end_idx=final_idx, ids=tuple(token_ids[prev_idx:final_idx])[1:], - length=final_idx - prev_idx + length=final_idx - prev_idx, + graph_id=graph_id )) return sequence diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e3bc02421f64ca..aa8c66b91d1c93 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -986,7 +986,7 @@ def init_graph_information_passing( """ Initializes a set of message passing layers to perform message passing of between graph elements described in an input token id sequence """ - assert element_type in ['nodes', 'edges'], 'unsupported message passing type' + assert element_type in ['node_correspondence', 'edge'], 'unsupported message passing type' self.message_passing_type = element_type self.graph_token_ids = graph_token_ids self.num_gnn_layers = ( @@ -1105,7 +1105,7 @@ def forward( extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids ] if self.message_passing_type == 'nodes': - get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing + get_matrices = GatedCausalMessagePassingLayer.build_node_correspondence_information_passing else: get_matrices = GatedCausalMessagePassingLayer.build_edge_information_passing message_passing_dicts = get_matrices(edge_sequences, self.device)