Skip to content

Commit aab8e15

Browse files
Add test case
1 parent 746d1e0 commit aab8e15

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

python/tests/test_extract_bug.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
Tests extraction with a DAG-based cost model.
3+
from https://github.com/egraphs-good/egglog-python/issues/387#issuecomment-3628927075
4+
"""
5+
6+
from dataclasses import dataclass, field
7+
8+
from egglog import *
9+
from egglog import bindings
10+
11+
# A cost model, approximately equivalent to, greedy_dag_cost_model,
12+
# which operates purely on the `bindings` level, for the sake of
13+
# minimization.
14+
15+
ENode = tuple[str, tuple[bindings.Value, ...]]
16+
17+
18+
@dataclass
19+
class DAGCostValue:
20+
"""Cost value for DAG-based extraction."""
21+
22+
cost: int
23+
_values: dict[ENode, int]
24+
25+
def __eq__(self, rhs: object) -> bool:
26+
if not isinstance(rhs, DAGCostValue):
27+
return False
28+
return self.cost == rhs.cost
29+
30+
def __lt__(self, other: "DAGCostValue") -> bool:
31+
return self.cost < other.cost
32+
33+
def __le__(self, other: "DAGCostValue") -> bool:
34+
return self.cost <= other.cost
35+
36+
def __gt__(self, other: "DAGCostValue") -> bool:
37+
return self.cost > other.cost
38+
39+
def __ge__(self, other: "DAGCostValue") -> bool:
40+
return self.cost >= other.cost
41+
42+
def __hash__(self) -> int:
43+
return hash(self.cost)
44+
45+
def __str__(self) -> str:
46+
return f"DAGCostValue(cost={self.cost})"
47+
48+
def __repr__(self) -> str:
49+
return f"DAGCostValue(cost={self.cost}, nchildren={len(self._values)})"
50+
51+
52+
@dataclass
53+
class DAGCost:
54+
"""
55+
DAG-based cost model for e-graph extraction.
56+
57+
This cost model counts each unique e-node once, implementing
58+
a greedy DAG extraction strategy.
59+
"""
60+
61+
graph: bindings.EGraph
62+
cache: dict[ENode, DAGCostValue] = field(default_factory=dict)
63+
64+
def merge_costs(self, costs: list[DAGCostValue], node: ENode, self_cost: int = 0) -> DAGCostValue:
65+
# if node in self.cache:
66+
# return self.cache[node]
67+
values: dict[ENode, int] = {}
68+
for child in costs:
69+
values.update(child._values)
70+
cost = DAGCostValue(cost=sum(values.values(), start=self_cost), _values=values)
71+
cost._values[node] = self_cost
72+
# self.cache[node] = cost
73+
# print(f"merge {costs=} out={cost}")
74+
return cost
75+
76+
def cost_fold(self, fn: str, enode: ENode, children_costs: list[DAGCostValue]) -> DAGCostValue:
77+
return self.merge_costs(children_costs, enode, 1)
78+
# print(f"fold {fn=} {out=}")
79+
80+
def enode_cost(self, name: str, args: list[bindings.Value]) -> ENode:
81+
return (name, tuple(args))
82+
83+
def container_cost(self, tp: str, value: bindings.Value, element_costs: list[DAGCostValue]) -> DAGCostValue:
84+
return self.merge_costs(element_costs, (tp, (value,)), 1)
85+
86+
def base_value_cost(self, tp: str, value: bindings.Value) -> DAGCostValue:
87+
return self.merge_costs([], (tp, (value,)), 1)
88+
89+
@property
90+
def egg_cost_model(self) -> bindings.CostModel:
91+
return bindings.CostModel(
92+
fold=self.cost_fold,
93+
enode_cost=self.enode_cost,
94+
container_cost=self.container_cost,
95+
base_value_cost=self.base_value_cost,
96+
)
97+
98+
99+
def test_dag_cost_model():
100+
graph = EGraph()
101+
102+
commands = graph._egraph.parse_program("""
103+
(sort S)
104+
105+
(constructor Si (i64) S)
106+
(constructor Swide (S S S S S S S S) S )
107+
(constructor Ssa (S) S)
108+
(constructor Ssb (S) S)
109+
(constructor Ssc (S) S)
110+
(constructor Sp (S S) S)
111+
112+
113+
(let w
114+
(Swide (Si 0) (Si 1) (Si 2) (Si 3) (Si 4) (Si 5) (Si 6) (Si 7)))
115+
116+
(let l (Ssa (Ssb (Ssc (Si 0)))))
117+
(let x (Ssa w))
118+
(let v (Sp w x))
119+
120+
(union x l)
121+
""")
122+
graph._egraph.run_program(*commands)
123+
124+
cost_model = DAGCost(graph._egraph)
125+
extractor = bindings.Extractor(["S"], graph._egraph, cost_model.egg_cost_model)
126+
termdag = bindings.TermDag()
127+
value = graph._egraph.lookup_function("v", [])
128+
assert value is not None
129+
cost, _term = extractor.extract_best(graph._egraph, termdag, value, "S")
130+
131+
assert cost.cost in {19, 21}

0 commit comments

Comments
 (0)