Skip to content

Commit 12e8f1e

Browse files
authored
Merge pull request #18 from Axiomatic-AI/print-statements
Print statements nicely
2 parents c19ba00 + 4b8a13f commit 12e8f1e

File tree

1 file changed

+101
-1
lines changed

1 file changed

+101
-1
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ipywidgets import interactive, IntSlider # type: ignore
44
from typing import List, Optional
55

6-
from . import Parameter
6+
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation
77

88

99
def plot_circuit(component):
@@ -201,3 +201,103 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
201201
]
202202
)
203203
plt.show()
204+
205+
206+
def print_statements(statements: StatementDictionary, validation: Optional[StatementValidationDictionary] = None):
207+
"""
208+
Print a list of statements in nice readable format.
209+
"""
210+
211+
validation = StatementValidationDictionary(
212+
cost_functions=(validation.cost_functions if validation is not None else None) or [StatementValidation()]*len(statements.cost_functions or []),
213+
parameter_constraints=(validation.parameter_constraints if validation is not None else None) or [StatementValidation()]*len(statements.parameter_constraints or []),
214+
structure_constraints=(validation.structure_constraints if validation is not None else None) or [StatementValidation()]*len(statements.structure_constraints or []),
215+
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None) or [StatementValidation()]*len(statements.unformalizable_statements or [])
216+
)
217+
218+
if len(validation.cost_functions or []) != len(statements.cost_functions or []):
219+
raise ValueError("Number of cost functions and validations do not match.")
220+
if len(validation.parameter_constraints or []) != len(statements.parameter_constraints or []):
221+
raise ValueError("Number of parameter constraints and validations do not match.")
222+
if len(validation.structure_constraints or []) != len(statements.structure_constraints or []):
223+
raise ValueError("Number of structure constraints and validations do not match.")
224+
if len(validation.unformalizable_statements or []) != len(statements.unformalizable_statements or []):
225+
raise ValueError("Number of unformalizable statements and validations do not match.")
226+
227+
print("-----------------------------------\n")
228+
for cost_stmt, cost_val in zip(statements.cost_functions or [], validation.cost_functions or []):
229+
print("Type:", cost_stmt.type)
230+
print("Statement:", cost_stmt.text)
231+
print("Formalization:", end=" ")
232+
if cost_stmt.formalization is None:
233+
print("UNFORMALIZED")
234+
else:
235+
code = cost_stmt.formalization.code
236+
if cost_stmt.formalization.mapping is not None:
237+
for var_name, computation in cost_stmt.formalization.mapping.items():
238+
if computation is not None:
239+
args_str = ", ".join(
240+
[
241+
f"{argname}="
242+
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
243+
for argname, argvalue in computation.arguments.items()
244+
]
245+
)
246+
code = code.replace(var_name, f"{computation.name}({args_str})")
247+
print(code)
248+
val = cost_stmt.validation or cost_val
249+
if val.satisfiable is not None and val.message is not None:
250+
print(f"Satisfiable: {val.satisfiable}")
251+
print(val.message)
252+
print("\n-----------------------------------\n")
253+
for param_stmt, param_val in zip(statements.cost_functions or [], validation.cost_functions or []):
254+
print("Type:", param_stmt.type)
255+
print("Statement:", param_stmt.text)
256+
print("Formalization:", end=" ")
257+
if param_stmt.formalization is None:
258+
print("UNFORMALIZED")
259+
else:
260+
code = param_stmt.formalization.code
261+
if param_stmt.formalization.mapping is not None:
262+
for var_name, computation in param_stmt.formalization.mapping.items():
263+
if computation is not None:
264+
args_str = ", ".join(
265+
[
266+
f"{argname}="
267+
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
268+
for argname, argvalue in computation.arguments.items()
269+
]
270+
)
271+
code = code.replace(var_name, f"{computation.name}({args_str})")
272+
print(code)
273+
val = param_stmt.validation or param_val
274+
if val.satisfiable is not None and val.message is not None and val.holds is not None:
275+
print(f"Satisfiable: {val.satisfiable}")
276+
print(f"Holds: {val.holds} ({val.message})")
277+
print("\n-----------------------------------\n")
278+
for struct_stmt, struct_val in zip(statements.structure_constraints or [], validation.structure_constraints or []):
279+
print("Type:", struct_stmt.type)
280+
print("Statement:", struct_stmt.text)
281+
print("Formalization:", end=" ")
282+
if struct_stmt.formalization is None:
283+
print("UNFORMALIZED")
284+
else:
285+
func_constr = struct_stmt.formalization
286+
args_str = ", ".join(
287+
[
288+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
289+
for argname, argvalue in func_constr.arguments.items()
290+
]
291+
)
292+
func_str = f"{func_constr.function_name}({args_str}) == {func_constr.expected_result}"
293+
print(func_str)
294+
val = struct_stmt.validation or struct_val
295+
if val.satisfiable is not None and val.holds is not None:
296+
print(f"Satisfiable: {val.satisfiable}")
297+
print(f"Holds: {val.holds}")
298+
print("\n-----------------------------------\n")
299+
for unf_stmt in statements.unformalizable_statements or []:
300+
print("Type:", unf_stmt.type)
301+
print("Statement:", unf_stmt.text)
302+
print("Formalization: UNFORMALIZABLE")
303+
print("\n-----------------------------------\n")

0 commit comments

Comments
 (0)