|
| 1 | +""" |
| 2 | +Tests extraction with a DAG-based cost model. |
| 3 | +from https://github.com/egraphs-good/egglog-python/issues/387#issuecomment-3628927075 |
| 4 | +""" |
| 5 | + |
| 6 | +from dataclasses import dataclass, field |
| 7 | + |
| 8 | +from egglog import * |
| 9 | +from egglog import bindings |
| 10 | + |
| 11 | +# A cost model, approximately equivalent to, greedy_dag_cost_model, |
| 12 | +# which operates purely on the `bindings` level, for the sake of |
| 13 | +# minimization. |
| 14 | + |
| 15 | +ENode = tuple[str, tuple[bindings.Value, ...]] |
| 16 | + |
| 17 | + |
| 18 | +@dataclass |
| 19 | +class DAGCostValue: |
| 20 | + """Cost value for DAG-based extraction.""" |
| 21 | + |
| 22 | + cost: int |
| 23 | + _values: dict[ENode, int] |
| 24 | + |
| 25 | + def __eq__(self, rhs: object) -> bool: |
| 26 | + if not isinstance(rhs, DAGCostValue): |
| 27 | + return False |
| 28 | + return self.cost == rhs.cost |
| 29 | + |
| 30 | + def __lt__(self, other: "DAGCostValue") -> bool: |
| 31 | + return self.cost < other.cost |
| 32 | + |
| 33 | + def __le__(self, other: "DAGCostValue") -> bool: |
| 34 | + return self.cost <= other.cost |
| 35 | + |
| 36 | + def __gt__(self, other: "DAGCostValue") -> bool: |
| 37 | + return self.cost > other.cost |
| 38 | + |
| 39 | + def __ge__(self, other: "DAGCostValue") -> bool: |
| 40 | + return self.cost >= other.cost |
| 41 | + |
| 42 | + def __hash__(self) -> int: |
| 43 | + return hash(self.cost) |
| 44 | + |
| 45 | + def __str__(self) -> str: |
| 46 | + return f"DAGCostValue(cost={self.cost})" |
| 47 | + |
| 48 | + def __repr__(self) -> str: |
| 49 | + return f"DAGCostValue(cost={self.cost}, nchildren={len(self._values)})" |
| 50 | + |
| 51 | + |
| 52 | +@dataclass |
| 53 | +class DAGCost: |
| 54 | + """ |
| 55 | + DAG-based cost model for e-graph extraction. |
| 56 | +
|
| 57 | + This cost model counts each unique e-node once, implementing |
| 58 | + a greedy DAG extraction strategy. |
| 59 | + """ |
| 60 | + |
| 61 | + graph: bindings.EGraph |
| 62 | + cache: dict[ENode, DAGCostValue] = field(default_factory=dict) |
| 63 | + |
| 64 | + def merge_costs(self, costs: list[DAGCostValue], node: ENode, self_cost: int = 0) -> DAGCostValue: |
| 65 | + # if node in self.cache: |
| 66 | + # return self.cache[node] |
| 67 | + values: dict[ENode, int] = {} |
| 68 | + for child in costs: |
| 69 | + values.update(child._values) |
| 70 | + cost = DAGCostValue(cost=sum(values.values(), start=self_cost), _values=values) |
| 71 | + cost._values[node] = self_cost |
| 72 | + # self.cache[node] = cost |
| 73 | + # print(f"merge {costs=} out={cost}") |
| 74 | + return cost |
| 75 | + |
| 76 | + def cost_fold(self, fn: str, enode: ENode, children_costs: list[DAGCostValue]) -> DAGCostValue: |
| 77 | + return self.merge_costs(children_costs, enode, 1) |
| 78 | + # print(f"fold {fn=} {out=}") |
| 79 | + |
| 80 | + def enode_cost(self, name: str, args: list[bindings.Value]) -> ENode: |
| 81 | + return (name, tuple(args)) |
| 82 | + |
| 83 | + def container_cost(self, tp: str, value: bindings.Value, element_costs: list[DAGCostValue]) -> DAGCostValue: |
| 84 | + return self.merge_costs(element_costs, (tp, (value,)), 1) |
| 85 | + |
| 86 | + def base_value_cost(self, tp: str, value: bindings.Value) -> DAGCostValue: |
| 87 | + return self.merge_costs([], (tp, (value,)), 1) |
| 88 | + |
| 89 | + @property |
| 90 | + def egg_cost_model(self) -> bindings.CostModel: |
| 91 | + return bindings.CostModel( |
| 92 | + fold=self.cost_fold, |
| 93 | + enode_cost=self.enode_cost, |
| 94 | + container_cost=self.container_cost, |
| 95 | + base_value_cost=self.base_value_cost, |
| 96 | + ) |
| 97 | + |
| 98 | + |
| 99 | +def test_dag_cost_model(): |
| 100 | + graph = EGraph() |
| 101 | + |
| 102 | + commands = graph._egraph.parse_program(""" |
| 103 | + (sort S) |
| 104 | +
|
| 105 | + (constructor Si (i64) S) |
| 106 | + (constructor Swide (S S S S S S S S) S ) |
| 107 | + (constructor Ssa (S) S) |
| 108 | + (constructor Ssb (S) S) |
| 109 | + (constructor Ssc (S) S) |
| 110 | + (constructor Sp (S S) S) |
| 111 | +
|
| 112 | +
|
| 113 | + (let w |
| 114 | + (Swide (Si 0) (Si 1) (Si 2) (Si 3) (Si 4) (Si 5) (Si 6) (Si 7))) |
| 115 | +
|
| 116 | + (let l (Ssa (Ssb (Ssc (Si 0))))) |
| 117 | + (let x (Ssa w)) |
| 118 | + (let v (Sp w x)) |
| 119 | +
|
| 120 | + (union x l) |
| 121 | + """) |
| 122 | + graph._egraph.run_program(*commands) |
| 123 | + |
| 124 | + cost_model = DAGCost(graph._egraph) |
| 125 | + extractor = bindings.Extractor(["S"], graph._egraph, cost_model.egg_cost_model) |
| 126 | + termdag = bindings.TermDag() |
| 127 | + value = graph._egraph.lookup_function("v", []) |
| 128 | + assert value is not None |
| 129 | + cost, _term = extractor.extract_best(graph._egraph, termdag, value, "S") |
| 130 | + |
| 131 | + assert cost.cost in {19, 21} |
0 commit comments