-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtask_base.py
186 lines (143 loc) · 7.9 KB
/
task_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import re
import string
from typing import (
    Optional,
    Union,
    Callable,
)

import numpy as np
import torch
from torch.utils.data import Dataset

from TAGLAS import get_task
def generate_random_node_order(num_nodes: int, max_ids: int = 676):
    """Draw ``num_nodes`` distinct positions from ``range(max_ids)``, sorted ascending.

    The default ``max_ids`` of 676 (26 * 26) keeps the previous sampling range and
    therefore the same random-number consumption.

    Args:
        num_nodes: Number of distinct positions to draw.
        max_ids: Size of the range the positions are drawn from.

    Returns:
        A 1-D int64 tensor of ``num_nodes`` distinct, strictly increasing values.

    Raises:
        ValueError: If ``num_nodes`` is negative or larger than ``max_ids``. The old
            slice silently returned fewer than ``num_nodes`` entries in that case,
            which only surfaced later as an IndexError in the caller.
    """
    if not 0 <= num_nodes <= max_ids:
        raise ValueError(
            f"num_nodes must be in [0, {max_ids}], got {num_nodes}"
        )
    order, _ = torch.sort(torch.randperm(max_ids)[:num_nodes])
    return order
def generate_alphabet_id():
    """Return the lookup table of ``[NODEID.*]`` node-name tokens.

    The table is a numpy object array of 26 + 26 * 26 = 702 strings:
    ``[NODEID.A]`` .. ``[NODEID.Z]`` first, followed by every two-letter
    combination ``[NODEID.AA]`` .. ``[NODEID.ZZ]`` in lexicographic order.
    """
    letters = string.ascii_uppercase
    suffixes = list(letters)
    for first in letters:
        suffixes.extend(first + second for second in letters)
    return np.array([f"[NODEID.{suffix}]" for suffix in suffixes], dtype=object)
def generate_node_id(num_nodes):
    """Return ``num_nodes`` distinct, randomly chosen ``[NODEID.*]`` tokens.

    Positions are drawn by ``generate_random_node_order`` (sorted ascending) and
    looked up in the table from ``generate_alphabet_id``, so the returned tokens
    follow table order.
    """
    id_table = generate_alphabet_id()
    positions = generate_random_node_order(num_nodes).numpy()
    return id_table[positions]
def construct_prompt_graph(num_nodes: int, question_indexs: list, single_direction=False):
    r"""Build the prompt-node edges for a graph with ``num_nodes`` original nodes.

    Each entry of ``question_indexs`` is the list of original node indices that one
    question refers to. Question ``k`` gets a new prompt node with index
    ``num_nodes + k``, and every node of that question is linked to it.

    Args:
        num_nodes: Number of nodes in the original graph.
        question_indexs: One sequence of node indices per question.
        single_direction: If True, only node -> prompt edges are created. Otherwise
            the prompt -> node edges are appended after them.

    Returns:
        ``(prompt_edges, prompt_edge_map, prompt_node_index)`` as long tensors:
        ``prompt_edges`` has shape (2, E); ``prompt_edge_map`` is 0 for
        node -> prompt edges and 1 for prompt -> node edges; ``prompt_node_index``
        lists the new prompt node indices. Returns ``(None, None, None)`` when
        ``question_indexs`` is empty.
    """
    question_count = len(question_indexs)
    if question_count == 0:
        return None, None, None

    sources = []
    targets = []
    for offset, members in enumerate(question_indexs):
        prompt_id = num_nodes + offset
        sources.extend(members)
        targets.extend([prompt_id] * len(members))

    prompt_node_index = torch.arange(num_nodes, num_nodes + question_count, dtype=torch.long)
    forward_edges = torch.tensor([sources, targets], dtype=torch.long)
    edge_count = forward_edges.size(-1)
    forward_types = torch.zeros(edge_count, dtype=torch.long)

    if single_direction:
        return forward_edges, forward_types, prompt_node_index

    # Swapping the two rows turns every node -> prompt edge into prompt -> node.
    backward_edges = forward_edges.flip(0)
    backward_types = torch.ones(edge_count, dtype=torch.long)
    all_edges = torch.cat([forward_edges, backward_edges], dim=-1)
    all_types = torch.cat([forward_types, backward_types], dim=-1)
    return all_edges, all_types, prompt_node_index
def add_prompt_graph_to_data(
        data,
        prompt_edge_index=None,
        prompt_edge_map=None,
        prompt_edge_text=None,
        prompt_node_text=None):
    """Append prompt-graph edges and prompt-node texts to ``data`` in place and return it.

    Args:
        data: Graph object with ``edge_index``, ``edge_map`` (tensors) and
            ``edge_attr``, ``x`` (numpy-concatenatable text arrays). These are
            assumed to be (2, E) / per-edge / per-edge / per-node; TODO confirm
            against the TAGLAS data layout.
        prompt_edge_index: (2, P) long tensor of prompt edges, or None to add no edges.
        prompt_edge_map: Per-prompt-edge type id (0 or 1) that indexes ``prompt_edge_text``.
            Required when ``prompt_edge_index`` is given.
        prompt_edge_text: Array of edge texts indexed by ``prompt_edge_map``. Required
            when ``prompt_edge_index`` is given.
        prompt_node_text: Texts appended to ``data.x``, or None to leave ``data.x`` unchanged.

    Returns:
        The same ``data`` object.
    """
    if prompt_edge_index is not None:
        assert prompt_edge_map is not None
        assert prompt_edge_text is not None
        data.edge_index = torch.cat([data.edge_index, prompt_edge_index], dim=-1)
        edge_map = data.edge_map
        # Shift the prompt edge type ids past the existing ids (no shift for an empty map).
        # NOTE(review): the shift is edge_map.max(), not max() + 1, so prompt type 0 reuses
        # the largest existing id. Confirm whether that overlap is intended.
        offset = edge_map.max() if edge_map.numel() > 0 else 0
        data.edge_map = torch.cat([edge_map, prompt_edge_map + offset], dim=-1)
        data.edge_attr = np.concatenate(
            [data.edge_attr, prompt_edge_text[prompt_edge_map.numpy()]], axis=-1)
    # Guard on the value actually consumed. This previously checked prompt_edge_text,
    # which crashed on prompt_node_text=None.
    if prompt_node_text is not None:
        data.x = np.concatenate([data.x, prompt_node_text])
    return data
# Matches exactly the placeholders f"[NODE_INDEX {i}]" for non-negative ints i
# (ASCII digits, no leading zeros).
_NODE_INDEX_PATTERN = re.compile(r"\[NODE_INDEX (0|[1-9][0-9]*)\]")


def _replace_node_placeholders(texts, node_ids, num_nodes):
    """Replace, in place, each ``[NODE_INDEX i]`` in ``texts`` with ``node_ids[i]`` for ``i < num_nodes``."""
    def _substitute(match):
        idx = int(match.group(1))
        # Placeholders that point past the last node are left untouched.
        return node_ids[idx] if idx < num_nodes else match.group(0)

    for k, text in enumerate(texts):
        texts[k] = _NODE_INDEX_PATTERN.sub(_substitute, text)


def build_GOFA_task_graph(data, add_prompt_graph=True, is_pretrain=False, single_direction=False, **kwargs):
    r"""Build the GOFA task graph for ``data`` (modified in place and returned).

    Steps:
      1. Give every node a random ``[NODEID.*]`` token (stored in ``data.node_ids``)
         and prefix each node text with ``"This is node <id>."``.
      2. Replace ``[NODE_INDEX i]`` placeholders in ``data.question``, ``data.answer``
         and ``data.label`` with the token of node ``i``.
      3. Set ``data.question_index``, the node(s) generation is conditioned on: either
         the target node itself or a newly added prompt node linked to the target nodes.

    If ``is_pretrain`` is False (fine-tune mode), the graph is assumed to hold exactly
    one QA pair. Otherwise it may hold several questions and answers. In both modes a
    prompt node is always added for a question that involves more than one node. For
    single-node questions a prompt node is added only when ``add_prompt_graph`` is True.

    Args:
        data: Graph object with ``node_map``, ``x``, ``question``, ``answer``, ``label``
            and ``target_index``. In fine-tune mode ``target_index`` is a tensor of
            node indices; in pretrain mode it holds one index sequence per question.
        add_prompt_graph: Add a prompt node even for single-node questions.
        is_pretrain: Select pretrain (multi-QA) mode instead of fine-tune mode.
        single_direction: Forwarded to ``construct_prompt_graph``.
        **kwargs: Accepted and ignored.

    Returns:
        The same ``data`` object.
    """
    num_nodes = data.node_map.size(0)
    node_ids = generate_node_id(num_nodes)
    data.node_ids = node_ids
    for i in range(num_nodes):
        data.x[i] = f"This is node {node_ids[i]}." + data.x[i]
    # A single regex pass per string replaces num_nodes separate str.replace passes.
    if num_nodes > 0:
        for texts in (data.question, data.answer, data.label):
            _replace_node_placeholders(texts, node_ids, num_nodes)

    prompt_edge_text = np.array(['This edge connects the nodes in graph to a prompt node.',
                                 "This edge connects the prompt node to a node in the graph."], dtype=object)

    if not is_pretrain:
        # Fine-tune mode: a single question whose nodes are data.target_index.
        question_indexs = [data.target_index.tolist()]
        if len(data.target_index) > 1 or add_prompt_graph:
            prompt_edge_index, prompt_edge_map, prompt_node_index = construct_prompt_graph(
                num_nodes, question_indexs, single_direction)
            data.question_index = prompt_node_index
        else:
            prompt_edge_index = prompt_edge_map = None
            data.question_index = data.target_index
        # data.question is passed as prompt_node_text even when no prompt edges were built.
        return add_prompt_graph_to_data(data=data, prompt_edge_index=prompt_edge_index,
                                        prompt_edge_map=prompt_edge_map,
                                        prompt_edge_text=prompt_edge_text,
                                        prompt_node_text=data.question)

    # Pretrain mode: one node-index sequence per question.
    question_indexs = data.target_index
    if add_prompt_graph:
        prompt_edge_index, prompt_edge_map, prompt_node_index = construct_prompt_graph(
            num_nodes, question_indexs, single_direction)
        data.question_index = prompt_node_index
        return add_prompt_graph_to_data(data=data, prompt_edge_index=prompt_edge_index,
                                        prompt_edge_map=prompt_edge_map,
                                        prompt_edge_text=prompt_edge_text,
                                        prompt_node_text=data.question)

    # Only questions spanning several nodes get a prompt node; single-node questions
    # are answered at their own node.
    needs_prompt = [len(q) > 1 for q in question_indexs]
    prompt_question_indexs = [q for q, flag in zip(question_indexs, needs_prompt) if flag]
    prompt_questions = [q for i, q in enumerate(data.question) if needs_prompt[i]]
    prompt_edge_index, prompt_edge_map, prompt_node_index = construct_prompt_graph(
        num_nodes, prompt_question_indexs, single_direction)
    data = add_prompt_graph_to_data(data=data, prompt_edge_index=prompt_edge_index,
                                    prompt_edge_map=prompt_edge_map,
                                    prompt_edge_text=prompt_edge_text,
                                    prompt_node_text=prompt_questions)
    # prompt_node_index is None when no question needed a prompt node. Otherwise its
    # entries are consumed in order, one per flagged question.
    prompt_nodes = iter([] if prompt_node_index is None else prompt_node_index.tolist())
    final_question_index = [next(prompt_nodes) if flag else q[0]
                            for q, flag in zip(question_indexs, needs_prompt)]
    data.question_index = torch.tensor(final_question_index, dtype=torch.long)
    return data