Skip to content

Commit

Permalink
add support for pymbolic.EqualityMapper
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed May 4, 2022
1 parent c80701f commit fff5af6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
DistributeMapperBase)
from pymbolic.mapper.stringifier import (StringifyMapper as
StringifyMapperBase)
from pymbolic.mapper.equality import (EqualityMapper as
EqualityMapperBase)
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
import pymbolic.primitives as prim
import numpy as np
Expand Down Expand Up @@ -169,6 +171,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
bounds_expr = "{" + bounds_expr + "}"
return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")


class EqualityMapper(EqualityMapperBase):
def map_reduce(self, expr: Reduce, other: Reduce) -> bool:
return (
len(expr.bounds) == len(other.bounds)
and all(k == other_k
and self.rec(lb, other_lb) and self.rec(ub, other_ub)
for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip(
sorted(expr.bounds.items()),
sorted(other.bounds.items())))
and expr.op == other.op
and self.rec(expr.inner_expr, other.inner_expr)
)

# }}}


Expand Down Expand Up @@ -225,6 +241,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
# {{{ custom scalar expression nodes

class ExpressionBase(prim.Expression):
def make_equality_mapper(self) -> EqualityMapper:
return EqualityMapper()

def make_stringifier(self, originating_stringifier: Any = None) -> str:
return StringifyMapper()

Expand Down
2 changes: 2 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ def test_userscollector():


def test_asciidag():
pytest.importorskip("asciidag")

n = pt.make_size_param("n")
array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
stack = pt.stack([array, 2*array, array + 6])
Expand Down

0 comments on commit fff5af6

Please sign in to comment.