Skip to content

Commit 95349c5

Browse files
committed
Add documentation in data_processing files
1 parent 152c4ab commit 95349c5

File tree

4 files changed

+82
-170
lines changed

4 files changed

+82
-170
lines changed

README.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
## Running the Code
12
In order to extract the features from the corpus proto files, run:
23
python data_generation.py
34

@@ -7,8 +8,30 @@ python train.py --model_name="lstm_gcn_to_lstm_attention" --device=cuda:0 --prin
78
All the possible options when running a model can be seen by running:
89
python train.py --help
910

11+
## Pretrained Models
1012
A pretrained version of the best performing model (as a state dictionary) can be downloaded at
1113
https://drive.google.com/file/d/1fm7hGzr-tziNhUMh8duc8s4j5gWW3uKm/view?usp=sharing
1214

13-
15+
## High-Level Code Structure
16+
- data_processing/: contains the code for extracting, storing, analysing and processing data
17+
- data_analysis.ipynb: notebook containing analysis of the extracted data
18+
- data_extraction.py: contains the logic to extract the features data from the proto files of
19+
the corpus
20+
- data_generation.py: file to be called to generate the features data
21+
- data_util.py: contains utilities to work with data
22+
- text_util.py: contains utilities to work with text
23+
- models/: contains all the code for the different models
24+
- full_model.py: class of the complete methodNaming model
25+
- gat_encoder.py: class for the Graph Attention Network encoder
26+
- gcn_encoder.py: class for the Graph Convolutional Network encoder
27+
- graph_attention_layer.py: class for the Graph Attention Layer used by the Graph Attention
28+
Network
29+
- graph_convolutional_layer.py: class for the Graph Convolutional Layer used by the Graph
30+
Convolutional Network
31+
- lstm_decoder.py: class for the LSTM sequence decoder
32+
- lstm_encoder.py: class for the LSTM sequence encoder
33+
- training.py: contains code to train and evaluate the models
34+
- evaluation_util.py: contains utilities to compute evaluation metrics
35+
- train.py: entry-point for training the models
36+
- train_model.py: contains logic to train the models
1437

data_processing/data_extraction.py

Lines changed: 24 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010

1111

1212
def get_dataset_from_dir(dir="../corpus/r252-corpus-features/"):
13+
"""
14+
Extract methods source code, names and graphs structure.
15+
:param dir: directory where to look for proto files
16+
:return: (methods_source, methods_names, methods_graphs)
17+
"""
1318
methods_source = []
1419
methods_names = []
1520
methods_graphs = []
1621

1722
proto_files = list(Path(dir).rglob("*.proto"))
1823
print("A total of {} files have been found".format(len(proto_files)))
1924

20-
# proto_files = [Path("../features-javac-master/Test.java.proto")]
21-
2225
for i, file in enumerate(proto_files):
23-
# nx_graph = get_nx_graph(file)
24-
# if i % 100 == 0:
25-
print("Extracting data from file {}".format(i+1))
2626
file_methods_source, file_methods_names, file_methods_graph = get_file_methods_data(
2727
file)
2828
methods_source += file_methods_source
@@ -34,9 +34,8 @@ def get_dataset_from_dir(dir="../corpus/r252-corpus-features/"):
3434

3535
def get_file_methods_data(file):
3636
"""
37-
Extract the source code tokens, identifier names and graph for methods in a source file
38-
represented by a graph. Identifier tokens are split into subtokens. Constructors are not
39-
included in the methods.
37+
Extract the source code tokens, identifier names and graph for methods in a source file.
38+
Identifier tokens are split into subtokens. Constructors are not included in the methods.
4039
:param file: file
4140
:return: (methods_source, methods_names, methods_graph) where methods_source[i] is a list of the tokens for
4241
the source of ith method in the file, methods_names[i] is a list of tokens for name of the
@@ -68,23 +67,6 @@ def get_file_methods_data(file):
6867
methods_graph.append((method_edges, non_tokens_nodes_features))
6968
methods_names.append(split_identifier_into_parts(method_name_node.contents))
7069

71-
# start_line_number = node.startLineNumber
72-
# end_line_number = node.endLineNumber
73-
# method_source = []
74-
# for other_node in g.node:
75-
# if other_node.startLineNumber >= start_line_number and other_node.endLineNumber \
76-
# <= end_line_number:
77-
# # if other_node.type == FeatureNode.TOKEN:
78-
# # method_source.append(other_node.contents)
79-
# # elif other_node.type == FeatureNode.IDENTIFIER_TOKEN:
80-
# # sub_identifiers = split_identifier_into_parts(other_node.contents)
81-
# # method_source += sub_identifiers
82-
# if other_node.id == method_name_node.id:
83-
# method_source.append('_')
84-
# elif other_node.type == FeatureNode.TOKEN or other_node.type == \
85-
# FeatureNode.IDENTIFIER_TOKEN:
86-
# method_source.append(other_node.contents)
87-
8870
method_source = []
8971

9072
for other_node in method_nodes.values():
@@ -101,6 +83,9 @@ def get_file_methods_data(file):
10183

10284

10385
def get_file_graph(file):
86+
"""
87+
Compute graph for the given file.
88+
"""
10489
with file.open('rb') as f:
10590
g = Graph()
10691
g.ParseFromString(f.read())
@@ -117,6 +102,9 @@ def get_file_graph(file):
117102

118103

119104
def get_method_edges(method_node_id, file_adj_list, file_nodes):
105+
"""
106+
Compute edges of a method graph for a method starting at the node 'method_node_id'.
107+
"""
120108
method_nodes_ids = []
121109

122110
get_method_nodes_rec(method_node_id, method_nodes_ids, file_adj_list)
@@ -136,6 +124,9 @@ def get_method_edges(method_node_id, file_adj_list, file_nodes):
136124

137125

138126
def get_method_nodes_rec(node_id, method_nodes_ids, file_adj_list):
127+
"""
128+
Utilities to recursively retrieve all edges of a method graph.
129+
"""
139130
method_nodes_ids.append(node_id)
140131

141132
for edge in file_adj_list[node_id]:
@@ -144,6 +135,9 @@ def get_method_nodes_rec(node_id, method_nodes_ids, file_adj_list):
144135

145136

146137
def remap_edges(edges, nodes):
138+
"""
139+
Remap edges so that ids start from 0 and are consecutive.
140+
"""
147141
old_id_to_new_id = {}
148142
i = 0
149143
nodes_values = sorted(nodes.values(), key=lambda node: node.id)
@@ -174,26 +168,10 @@ def is_token(node_value):
174168
return node_value.type == FeatureNode.TOKEN or node_value.type == FeatureNode.IDENTIFIER_TOKEN
175169

176170

177-
def get_tokens(g):
171+
def get_method_name_node(g, method_node):
178172
"""
179-
Get the tokens for a file. Identifiers are split in subtokens.
180-
:param g: graph representing the file
181-
:return: list of tokens
173+
Return the node corresponding to the name of a method.
182174
"""
183-
token_nodes = list(filter(lambda n: n.type in (FeatureNode.TOKEN, FeatureNode.IDENTIFIER_TOKEN),
184-
g.node))
185-
tokens = []
186-
for token_node in token_nodes:
187-
if token_node.type == FeatureNode.IDENTIFIER_TOKEN:
188-
sub_identifiers = split_identifier_into_parts(token_node.contents)
189-
tokens += sub_identifiers
190-
else:
191-
tokens.append(token_node.contents)
192-
193-
return tokens
194-
195-
196-
def get_method_name_node(g, method_node):
197175
method_id = method_node.id
198176
method_name_node_id = 0
199177

@@ -221,6 +199,9 @@ def get_class_name_node(g):
221199

222200

223201
def get_nx_graph(file):
202+
"""
203+
Get networkx graph corresponding to a file.
204+
"""
224205
nx_graph = nx.DiGraph()
225206
with file.open('rb') as f:
226207
g = Graph()
@@ -231,116 +212,3 @@ def get_nx_graph(file):
231212
edge.type][0]
232213
nx_graph.add_edge(edge.sourceId, edge.destinationId, edge_type=edge_type)
233214
return nx_graph
234-
235-
236-
def get_tokens_dataset_from_dir(dir="../corpus/r252-corpus-features/"):
237-
methods_source = []
238-
methods_names = []
239-
methods_graphs = []
240-
241-
proto_files = list(Path(dir).rglob("*.proto"))
242-
print("A total of {} files have been found".format(len(proto_files)))
243-
244-
# proto_files = [Path("../features-javac-master/Test.java.proto")]
245-
246-
for i, file in enumerate(proto_files):
247-
# nx_graph = get_nx_graph(file)
248-
if i % 10 == 0:
249-
print("Extracting data from file {}".format(i+1))
250-
file_methods_source, file_methods_names, file_methods_graph = \
251-
get_file_methods_data(file)
252-
methods_source += file_methods_source
253-
methods_names += file_methods_names
254-
methods_graphs += file_methods_graph
255-
256-
return methods_source, methods_names, methods_graphs
257-
258-
259-
def get_method_nodes(method_node, file_graph):
260-
method_nodes = [method_node]
261-
get_method_nodes_rec(method_node, file_graph, method_nodes)
262-
263-
return method_nodes
264-
265-
266-
# def get_method_nodes_rec(node, file_graph, method_nodes):
267-
# print(len(method_nodes))
268-
# for e in file_graph.edge:
269-
# neighbour = e.destinationId
270-
# if neighbour not in method_nodes:
271-
# method_nodes.append(neighbour)
272-
# get_method_nodes(neighbour, nx_graph, method_nodes)
273-
274-
275-
def get_augmented_graph(file):
276-
# TODO: Does each method in a file have a different graph?
277-
with file.open('rb') as f:
278-
g = Graph()
279-
g.ParseFromString(f.read())
280-
281-
augmented_graph = nx.Graph()
282-
new_node_id = max([node.id for node in g.node]) + 1
283-
284-
split_identifiers_node = [node for node in g.node if node.type == FeatureNode.IDENTIFIER_TOKEN
285-
and len(split_identifier_into_parts(node.contents)) > 1]
286-
287-
# Add all edges
288-
for edge in g.edge:
289-
edge_type = [name for name, value in list(vars(FeatureEdge).items())[8:] if value ==
290-
edge.type][0]
291-
augmented_graph.add_edge(edge.sourceId, edge.destinationId, edge_type=edge_type)
292-
293-
# Add new edges for split identifiers and sub identifiers
294-
for node in split_identifiers_node:
295-
sub_identifiers = split_identifier_into_parts(node.contents)
296-
sub_identifiers_ids = list(range(new_node_id, new_node_id + len(sub_identifiers)))
297-
new_node_id += len(sub_identifiers)
298-
299-
# ADD NEXT_TOKEN edge from node before identifier to first sub-identifier
300-
previous_token_node_id = find_previous_token_node_id(node, g)
301-
augmented_graph.add_edge(previous_token_node_id, sub_identifiers_ids[0],
302-
edge_type="NEXT_TOKEN")
303-
304-
# ADD NEXT_TOKEN edge from last sub-identifier to node after identifier
305-
next_token_node_id = find_next_token_node_id(node, g)
306-
augmented_graph.add_edge(sub_identifiers_ids[-1], next_token_node_id,
307-
edge_type="NEXT_TOKEN")
308-
309-
# ADD AST_CHILD edge from ast parent of node to first sub-identifier
310-
# ast_parent_node_id = find_ast_parent_node_id(node, g)
311-
# augmented_graph.add_edge(ast_parent_node_id, sub_identifiers_ids[0],
312-
# edge_type="ASSOCIATED_TOKEN")
313-
314-
for i, sub_identifier_id in enumerate(sub_identifiers_ids):
315-
# Add IN_TOKEN edges from sub-identifiers to identifier
316-
augmented_graph.add_edge(sub_identifier_id, node.id, edge_type="IN_TOKEN")
317-
318-
# ADD NEXT_TOKEN edges from sub-identifier to next sub-identifier
319-
if i < len(sub_identifiers_ids) - 1:
320-
augmented_graph.add_edge(sub_identifiers_ids[i], sub_identifiers_ids[i + 1],
321-
edge_type="NEXT_TOKEN")
322-
return augmented_graph
323-
324-
325-
def find_previous_token_node_id(node, g):
326-
for edge in g.edge:
327-
if edge.destinationId == node.id and edge.type == FeatureEdge.NEXT_TOKEN:
328-
return edge.sourceId
329-
330-
return None
331-
332-
333-
def find_next_token_node_id(node, g):
334-
for edge in g.edge:
335-
if edge.sourceId == node.id and edge.type == FeatureEdge.NEXT_TOKEN:
336-
return edge.destinationId
337-
338-
return None
339-
340-
341-
def find_ast_parent_node_id(node, g):
342-
for edge in g.edge:
343-
if edge.destinationId == node.id and edge.type == FeatureEdge.ASSOCIATED_TOKEN:
344-
return edge.sourceId
345-
346-
return None

data_processing/data_generation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from data_processing.data_extraction import get_dataset_from_dir
22
import pickle
33

4+
# Generate data
45
methods_source, methods_names, methods_graphs = get_dataset_from_dir(
56
"../corpus/r252-corpus-features/")
67

8+
# Store data
79
pickle.dump({'methods_source': methods_source, 'methods_names': methods_names, 'methods_graphs':
810
methods_graphs}, open('data/methods_tokens_graphs.pkl', 'wb'))

0 commit comments

Comments
 (0)