Skip to content

Commit 0134012

Browse files
authored
1138 plot functions for ast in generation package (#1140)
Improved handling and visualization of ast in memilio-generation
1 parent 86eb344 commit 0134012

File tree

8 files changed

+346
-195
lines changed

8 files changed

+346
-195
lines changed

pycode/memilio-generation/memilio/generation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
from .intermediate_representation import IntermediateRepresentation
2626
from .scanner import Scanner
2727
from .scanner_config import ScannerConfig
28+
from .ast import AST
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2024 MEmilio
3+
#
4+
# Authors: Maximilian Betz, Daniel Richter
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
"""
21+
@file ast.py
22+
@brief Create the ast and assign ids. Get ids and nodes.
23+
"""
24+
import subprocess
25+
import tempfile
26+
import logging
27+
from clang.cindex import Cursor, TranslationUnit, Index, CompilationDatabase
28+
from typing import TYPE_CHECKING
29+
from memilio.generation import utility
30+
31+
32+
if TYPE_CHECKING:
33+
from memilio.generation import ScannerConfig
34+
35+
from typing_extensions import Self
36+
37+
38+
class AST:
    """! Create the ast and assign ids.
    Functions for getting nodes and node ids.

    Every cursor of the translation unit gets a consecutive integer id, handed
    out parent-first while the tree is walked. `id_to_val` maps an id to its
    cursor; `val_to_id` maps a cursor hash to the list of ids registered under
    that hash.
    """

    def __init__(self: Self, conf: "ScannerConfig") -> None:
        """! Store the configuration and build the ast right away.

        @param conf: Scanner configuration; `source_file` and
            `skbuild_path_to_database` are read from it in create_ast.
        """
        self.config = conf
        self.cursor_id = -1
        self.id_to_val = dict()
        self.val_to_id = dict()
        self.cursor = None
        self.translation_unit = self.create_ast()

    def create_ast(self: Self) -> TranslationUnit:
        """! Create an abstract syntax tree for the main model.cpp file with a corresponding CompilationDatabase.
        A compile_commands.json is required (automatically generated in the build process).

        Resets the id bookkeeping, runs clang to emit a serialized ast, reads it
        back with libclang and assigns an id to every node.

        @return The TranslationUnit read from the generated ast.
        @exception RuntimeError if clang exits with a non-zero return code.
        """
        # Start from a clean state so the method can be called repeatedly.
        self.cursor_id = -1
        self.id_to_val.clear()
        self.val_to_id.clear()

        idx = Index.create()

        file_args = []

        unwanted_arguments = [
            '-Wno-unknown-warning', "--driver-mode=g++", "-O3", "-Werror", "-Wshadow"
        ]

        dirname = utility.try_get_compilation_database_path(
            self.config.skbuild_path_to_database)
        compdb = CompilationDatabase.fromDirectory(dirname)
        commands = compdb.getCompileCommands(self.config.source_file)
        for command in commands:
            for argument in command.arguments:
                if argument not in unwanted_arguments:
                    file_args.append(argument)
        # NOTE(review): presumably strips the compiler executable at the front
        # and the trailing output/input arguments of the recorded command --
        # confirm against the generated compile_commands.json.
        file_args = file_args[1:-4]

        clang_cmd = [
            "clang-14", self.config.source_file,
            "-std=c++17", '-emit-ast', '-o', '-']
        clang_cmd.extend(file_args)

        ast_bytes = self._run_clang(clang_cmd)

        # Since `clang.Index.read` expects a file path, write generated abstract syntax tree to a
        # temporary named file. This file will be automatically deleted when closed.
        with tempfile.NamedTemporaryFile() as ast_file:
            ast_file.write(ast_bytes)
            # The file is re-opened by name below, so the buffered bytes have
            # to be on disk before libclang reads them.
            ast_file.flush()
            translation_unit = idx.read(ast_file.name)

        self._assing_ast_with_ids(translation_unit.cursor)

        logging.info("AST generation completed successfully.")

        return translation_unit

    @staticmethod
    def _run_clang(clang_cmd: list) -> bytes:
        """! Run the given command and return the bytes it wrote to stdout.

        stderr is captured so that it can be included in the error log; when
        the command succeeds but still wrote diagnostics, they are logged as a
        warning instead of being dropped.

        @param clang_cmd: Command line as a list of arguments.
        @return Raw stdout of the process.
        @exception RuntimeError if the process exits with a non-zero return code.
        """
        try:
            clang_cmd_result = subprocess.run(
                clang_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                check=True)
        except subprocess.CalledProcessError as e:
            stderr_text = e.stderr.decode(errors="replace") if e.stderr else ""
            logging.error(
                f"Clang failed with return code {e.returncode}. Error: {stderr_text}")
            raise RuntimeError(
                "Clang AST generation failed. See error log for details.") from e

        if clang_cmd_result.stderr:
            logging.warning(
                "Clang diagnostics: %s",
                clang_cmd_result.stderr.decode(errors="replace"))
        return clang_cmd_result.stdout

    def _assing_ast_with_ids(self, cursor: Cursor) -> None:
        """! Traverse the AST and assign a unique ID to each node during traversal.

        The node itself is registered before its children, so ids follow the
        order in which the nodes are visited.

        @param cursor: The current node (Cursor) in the AST to traverse.
        """
        # NOTE (translated from a German leftover note): rewrite assing_ids -> mapping
        self.cursor_id += 1
        node_id = self.cursor_id
        self.id_to_val[node_id] = cursor

        # Ids are collected per hash; get_node_id later picks the right one
        # by comparing the stored cursor with the requested cursor.
        self.val_to_id.setdefault(cursor.hash, []).append(node_id)

        # Lazy %-formatting: the message is only built if INFO is enabled.
        logging.info(
            "Node %s assigned ID %s", cursor.spelling or cursor.kind, node_id)

        for child in cursor.get_children():
            self._assing_ast_with_ids(child)

    @property
    def root_cursor(self):
        """! Cursor of the translation unit, i.e. the root node of the ast."""
        return self.translation_unit.cursor

    def get_node_id(self, cursor: Cursor) -> int:
        """! Returns the id of the current node.

        Looks up the ids registered for the cursor's hash in the dictionary
        val_to_id and returns the first one whose stored cursor equals the
        given cursor.

        @param cursor: The current node of the AST as a cursor object from libclang.
        @return The id assigned to the cursor.
        @exception KeyError if no node with the cursor's hash was registered.
        @exception IndexError if the hash is known but no stored cursor equals the given one.
        """
        for cursor_id in self.val_to_id[cursor.hash]:
            if self.id_to_val[cursor_id] == cursor:
                return cursor_id
        raise IndexError(f"Cursor {cursor} is out of bounds.")

    def get_node_by_index(self, index: int) -> Cursor:
        """! Returns the node at the specified index position.

        @param index: Node_id from the ast.
        @return The cursor registered under the given id.
        @exception IndexError if the index is negative or not smaller than the number of nodes.
        """
        if index < 0 or index >= len(self.id_to_val):
            raise IndexError(f"Index {index} is out of bounds.")
        return self.id_to_val[index]
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2024 MEmilio
3+
#
4+
# Authors: Maximilian Betz, Daniel Richter
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
21+
import os
22+
import logging
23+
from typing import Callable
24+
from graphviz import Digraph
25+
from clang.cindex import Cursor
26+
from memilio.generation.ast import AST
27+
28+
29+
class Visualization:
    """! Class for plotting the abstract syntax tree in different formats.

    All methods are static; the class only groups the output functions.
    """
    @staticmethod
    def output_ast_terminal(ast: AST, cursor: Cursor) -> None:
        """! Output the abstract syntax tree to terminal.

        @param ast: ast object from AST class.
        @param cursor: The current node of the AST as a cursor object from libclang.
        """

        def terminal_writer(level: int, cursor_label: str) -> None:
            print(indent(level) + cursor_label)

        _output_cursor_and_children(cursor, ast, terminal_writer)

        logging.info("AST-Terminal written.")

    @staticmethod
    def output_ast_png(cursor: Cursor, max_depth: int, output_file_name: str = 'ast_graph') -> None:
        """! Output the abstract syntax tree to a .png. Set the starting node and the max depth.

        To save the abstract syntax tree as a png with a starting node and a depth, you can use the following command.

        Example command: aviz.output_ast_png(ast.get_node_by_index(1), 2)

        aviz -> instance of the Visualization class.

        ast -> instance of the AST class.

        .get_node_by_index -> get a specific node by id (use .output_ast_formatted to see node ids)

        The number 2 is an example for the depth the graph will show.

        @param cursor: The current node of the AST as a cursor object from libclang.
        @param max_depth: Maximal depth the graph displays.
        @param output_file_name: File name passed to the graph renderer; ".png" is appended for the logged path.
        """

        graph = Digraph(format='png')

        _output_cursor_and_children_graphviz_digraph(
            cursor, graph, max_depth, 0)

        graph.render(filename=output_file_name, view=False)

        output_path = os.path.abspath(f"{output_file_name}.png")
        logging.info(f"AST-png written to {output_path}")

    @staticmethod
    def output_ast_formatted(ast: AST, cursor: Cursor, output_file_name: str = 'ast_formated.txt') -> None:
        """! Output the abstract syntax tree to a file.

        @param ast: ast object from AST class.
        @param cursor: The current node of the AST as a cursor object from libclang.
        @param output_file_name: Path of the text file to write.
        """

        # The indentation uses box-drawing characters, which the platform
        # default encoding cannot always represent; write UTF-8 explicitly.
        with open(output_file_name, 'w', encoding="utf-8") as f:
            def file_writer(level: int, cursor_label: str) -> None:
                f.write(indent(level) + cursor_label + newline())
            _output_cursor_and_children(cursor, ast, file_writer)

        output_path = os.path.abspath(f"{output_file_name}")
        logging.info(f"AST-formated written to {output_path}")
92+
93+
94+
def indent(level: int) -> str:
    """! Build the tree prefix for a node at the given depth.

    One vertical guide is emitted per level, followed by the branch marker
    that precedes the node label.

    @param level: Depth of the node in the tree.
    @return Prefix string for one output line.
    """
    guides = '│ ' * level
    branch = '├── '
    return f"{guides}{branch}"
98+
99+
100+
def newline() -> str:
    """! Return the line break used when writing the ast.

    @return A single line-feed character.
    """
    line_break = "\n"
    return line_break
104+
105+
106+
def _output_cursor_and_children(cursor: Cursor, ast: AST, writer: Callable[[int, str], None], level: int = 0) -> None:
107+
"""!Generic function to output the cursor and its children with a specified writer.
108+
109+
@param cursor: The current node of the AST as a libclang cursor object.
110+
@param ast: AST object from the AST class.
111+
@param writer: Function that takes `level` and `cursor_label` and handles output.
112+
@param level: The current depth in the AST for indentation purposes.
113+
"""
114+
115+
cursor_id = ast.get_node_id(cursor)
116+
117+
cursor_kind = f"<CursorKind.{cursor.kind.name}>"
118+
file_path = cursor.location.file.name if cursor.location.file else ""
119+
120+
if cursor.spelling:
121+
cursor_label = (f'ID:{cursor_id} {cursor.spelling} '
122+
f'{cursor_kind} '
123+
f'{file_path}')
124+
else:
125+
cursor_label = f'ID:{cursor_id} {cursor_kind} {file_path}'
126+
127+
writer(level, cursor_label)
128+
129+
for child in cursor.get_children():
130+
_output_cursor_and_children(
131+
child, ast, writer, level + 1)
132+
133+
134+
def _output_cursor_and_children_graphviz_digraph(cursor: Cursor, graph: Digraph, max_d: int, current_d: int, parent_node: str = None) -> None:
135+
"""! Output the cursor and its children as a graph using Graphviz.
136+
137+
@param cursor: The current node of the AST as a Cursor object from libclang.
138+
@param graph: Graphviz Digraph object where the nodes and edges will be added.
139+
@param max_d: Maximal depth.
140+
@param current_d: Current depth.
141+
@param parent_node: Name of the parent node in the graph (None for the root node).
142+
"""
143+
144+
if current_d > max_d:
145+
return
146+
147+
node_label = f"{cursor.kind.name}{newline()}({cursor.spelling})" if cursor.spelling else cursor.kind.name
148+
149+
current_node = f"{cursor.kind.name}_{cursor.hash}"
150+
151+
graph.node(current_node, label=node_label)
152+
153+
if parent_node:
154+
graph.edge(parent_node, current_node)
155+
156+
if cursor.kind.is_reference():
157+
referenced_label = f"ref_to_{cursor.referenced.kind.name}{newline()}({cursor.referenced.spelling})"
158+
referenced_node = f"ref_{cursor.referenced.hash}"
159+
graph.node(referenced_node, label=referenced_label)
160+
graph.edge(current_node, referenced_node)
161+
162+
for child in cursor.get_children():
163+
_output_cursor_and_children_graphviz_digraph(
164+
child, graph, max_d, current_d + 1, current_node)

0 commit comments

Comments
 (0)