Skip to content

Commit 25ba64e

Browse files
committed
restore deleted validation_results
1 parent afd7f99 commit 25ba64e

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from IPython.display import display, Math, HTML # type: ignore
2+
import hypernetx as hnx # type: ignore
3+
import matplotlib.pyplot as plt # type: ignore
4+
import re
5+
6+
7+
def display_full_results(validation_results, requirements=None, show_hypergraph=True):
8+
"""Display equation validation results optimized for dark theme notebooks."""
9+
validations = validation_results.get("validations", {})
10+
11+
matching = []
12+
non_matching = []
13+
14+
for eq_name, value in validations.items():
15+
equation_data = {
16+
"name": eq_name,
17+
"latex": value.get("original_format", ""),
18+
"lhs": value.get("lhs_value"),
19+
"rhs": value.get("rhs_value"),
20+
"diff": abs(value.get("lhs_value", 0) - value.get("rhs_value", 0)),
21+
"percent_diff": abs(value.get("lhs_value", 0) - value.get("rhs_value", 0))
22+
/ max(abs(value.get("rhs_value", 0)), 1e-10)
23+
* 100,
24+
"used_values": value.get("used_values", {}),
25+
}
26+
if value.get("is_valid"):
27+
matching.append(equation_data)
28+
else:
29+
non_matching.append(equation_data)
30+
31+
# Summary header with dark theme
32+
total = len(validations)
33+
display(
34+
HTML(
35+
'<div style="background-color:#1e1e1e; padding:20px; border-radius:10px; margin:20px 0; '
36+
'border:1px solid #3e3e3e;">'
37+
f'<h2 style="font-family:Arial; color:#e0e0e0; margin-bottom:15px">Equation Validation Analysis</h2>'
38+
f'<p style="font-family:Arial; font-size:16px; color:#e0e0e0">'
39+
f"<b>Total equations analyzed:</b> {total}<br>"
40+
f'<span style="color:#4caf50">✅ Matching equations: {len(matching)}</span><br>'
41+
f'<span style="color:#ff5252">❌ Non-matching equations: {len(non_matching)}</span></p>'
42+
"</div>"
43+
)
44+
)
45+
46+
# Non-matching equations
47+
if non_matching:
48+
display(
49+
HTML(
50+
'<div style="background-color:#2d1f1f; padding:20px; border-radius:10px; margin:20px 0; '
51+
'border:1px solid #4a2f2f;">'
52+
'<h3 style="color:#ff5252; font-family:Arial">⚠️ Equations Not Satisfied</h3>'
53+
)
54+
)
55+
56+
for eq in non_matching:
57+
display(HTML(f'<h4 style="color:#e0e0e0; font-family:Arial">{eq["name"]}</h4>'))
58+
display(Math(eq["latex"]))
59+
display(
60+
HTML(
61+
'<div style="font-family:monospace; margin-left:20px; margin-bottom:20px; '
62+
"background-color:#2a2a2a; color:#e0e0e0; padding:15px; border-radius:5px; "
63+
'border-left:4px solid #ff5252">'
64+
f"Left side = {eq['lhs']:.6g}<br>"
65+
f"Right side = {eq['rhs']:.6g}<br>"
66+
f"Absolute difference = {eq['diff']:.6g}<br>"
67+
f"Relative difference = {eq['percent_diff']:.2f}%<br>"
68+
"<br>Used values:<br>"
69+
+ "<br>".join([f"{k} = {v:.6g}" for k, v in eq["used_values"].items()])
70+
+ "</div>"
71+
)
72+
)
73+
74+
display(HTML("</div>"))
75+
76+
# Matching equations
77+
if matching:
78+
display(
79+
HTML(
80+
'<div style="background-color:#1f2d1f; padding:20px; border-radius:10px; margin:20px 0; '
81+
'border:1px solid #2f4a2f;">'
82+
'<h3 style="color:#4caf50; font-family:Arial">✅ Satisfied Equations</h3>'
83+
)
84+
)
85+
86+
for eq in matching:
87+
display(HTML(f'<h4 style="color:#e0e0e0; font-family:Arial">{eq["name"]}</h4>'))
88+
display(Math(eq["latex"]))
89+
display(
90+
HTML(
91+
'<div style="font-family:monospace; margin-left:20px; margin-bottom:20px; '
92+
"background-color:#2a2a2a; color:#e0e0e0; padding:15px; border-radius:5px; "
93+
'border-left:4px solid #4caf50">'
94+
f"Value = {eq['lhs']:.6g}<br>"
95+
"<br>Used values:<br>"
96+
+ "<br>".join([f"{k} = {v:.6g}" for k, v in eq["used_values"].items()])
97+
+ "</div>"
98+
)
99+
)
100+
101+
display(HTML("</div>"))
102+
103+
# Hypergraph visualization
104+
if show_hypergraph and requirements:
105+
display(
106+
HTML(
107+
'<div style="background-color:#1e1e1e; padding:20px; border-radius:10px; margin:20px 0; '
108+
'border:1px solid #3e3e3e;">'
109+
'<h3 style="color:#e0e0e0; font-family:Arial">🔍 Equation Relationship Analysis</h3>'
110+
'<p style="font-family:Arial; color:#e0e0e0">The following graph shows how variables are connected through equations:</p>'
111+
"</div>"
112+
)
113+
)
114+
115+
list_api_requirements = requirements
116+
117+
# Match get_eq_hypergraph settings exactly
118+
plt.rcParams["text.usetex"] = False
119+
plt.rcParams["mathtext.fontset"] = "stix"
120+
plt.rcParams["font.family"] = "serif"
121+
122+
# Prepare hypergraph data
123+
hyperedges = {}
124+
for eq_name, details in validations.items():
125+
# Create a set of variables used in this equation
126+
used_vars = {f"${var}$" for var in details["used_values"].keys()}
127+
hyperedges[_get_latex_string_format(details["original_format"])] = used_vars
128+
129+
# Create and plot the hypergraph
130+
H = hnx.Hypergraph(hyperedges)
131+
plt.figure(figsize=(16, 12))
132+
133+
# Draw hypergraph with exact same settings as get_eq_hypergraph
134+
hnx.draw(
135+
H,
136+
with_edge_labels=True,
137+
edge_labels_on_edge=False,
138+
node_labels_kwargs={"fontsize": 14},
139+
edge_labels_kwargs={"fontsize": 14},
140+
layout_kwargs={"seed": 42, "scale": 2.5},
141+
)
142+
143+
node_labels = list(H.nodes)
144+
symbol_explanations = _get_node_names_for_node_lables(node_labels, list_api_requirements)
145+
146+
explanation_text = "\n".join([f"${symbol}$: {desc}" for symbol, desc in symbol_explanations])
147+
plt.annotate(
148+
explanation_text,
149+
xy=(1.05, 0.5),
150+
xycoords="axes fraction",
151+
fontsize=14,
152+
verticalalignment="center",
153+
)
154+
155+
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
156+
plt.show()
157+
158+
return None
159+
160+
161+
def _get_node_names_for_node_lables(node_labels, api_requirements):
162+
"""
163+
Creates mapping between symbols and their descriptions.
164+
165+
Args:
166+
node_labels: List of node labels (symbols) from the hypergraph
167+
api_requirements: Can be either:
168+
- List of dicts with {"latex_symbol": str, "requirement_name": str}
169+
- Dictionary mapping variable names to their descriptions
170+
"""
171+
node_names = []
172+
173+
# Handle case where api_requirements is a dictionary
174+
if isinstance(api_requirements, dict):
175+
for symbol in node_labels:
176+
clean_symbol = symbol.replace("$", "")
177+
if clean_symbol in api_requirements:
178+
node_names.append((clean_symbol, api_requirements[clean_symbol]))
179+
return node_names
180+
181+
# Handle case where api_requirements is a list of dicts
182+
for symbol in node_labels:
183+
clean_symbol = symbol.replace("$", "")
184+
for req in api_requirements:
185+
if isinstance(req, dict) and req.get("latex_symbol") == clean_symbol:
186+
node_names.append((req["latex_symbol"], req["requirement_name"]))
187+
break
188+
189+
return node_names
190+
191+
192+
def _get_latex_string_format(input_string):
193+
"""
194+
Properly formats LaTeX strings for matplotlib when text.usetex is False.
195+
No escaping needed since mathtext handles backslashes properly.
196+
"""
197+
return f"${input_string}$" # No backslash escaping required
198+
199+
200+
def _get_requirements_set(requirements):
201+
variable_set = set()
202+
for req in requirements:
203+
variable_set.add(req["latex_symbol"])
204+
205+
return variable_set
206+
207+
208+
def _find_vars_in_eq(equation, variable_set):
209+
patterns = [re.escape(var) for var in variable_set]
210+
combined_pattern = r"|".join(patterns)
211+
matches = re.findall(combined_pattern, equation)
212+
return {rf"${match}$" for match in matches}

0 commit comments

Comments
 (0)