Skip to content

Commit

Permalink
Merge pull request #2523 from firedrakeproject/connorjward/fix-assign…
Browse files Browse the repository at this point in the history
…-hashing

Fix hashing for Assign rvalues
  • Loading branch information
connorjward authored Sep 16, 2022
2 parents 8b15438 + b75c1a6 commit ee720d6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
22 changes: 7 additions & 15 deletions firedrake/assemble_expressions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import itertools
import os
import tempfile
import weakref
from collections import OrderedDict, defaultdict
from functools import singledispatch
Expand All @@ -13,7 +11,7 @@
from gem.node import MemoizerArg
from gem.node import traversal as gem_traversal
from pyop2 import op2
from pyop2.caching import disk_cached
from pyop2.caching import cached
from pyop2.parloop import GlobalLegacyArg, DatLegacyArg
from tsfc import ufl2gem
from tsfc.loopy import generate
Expand Down Expand Up @@ -416,25 +414,19 @@ def compile_to_gem(expr, translator):
return preprocess_gem([lvalue, rvalue])


try:
_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"]
except KeyError:
_cachedir = os.path.join(tempfile.gettempdir(),
f"firedrake-pointwise-expression-kernel-cache-uid{os.getuid()}")
"""Storage location for the kernel cache."""
_pointwise_expression_cache = {}
"""In-memory cache for pointwise expression kernels."""


def _pointwise_expression_key(exprs, scalar_type, is_logging):
"""Return a cache key for use with :func:`pointwise_expression_kernel`."""
# Since this cache is collective this function must return a 2-tuple of
# communicator and cache key.
comm = exprs[0].lvalue.node_set.comm
key = tuple(e.slow_key for e in exprs) + (scalar_type, is_logging)
return comm, key
from firedrake.interpolation import hash_expr
return (tuple((e.__class__, hash(e.lvalue), hash_expr(e.rvalue)) for e in exprs)
+ (scalar_type, is_logging))


@PETSc.Log.EventDecorator()
@disk_cached({}, _cachedir, key=_pointwise_expression_key, collective=True)
@cached(_pointwise_expression_cache, key=_pointwise_expression_key)
def pointwise_expression_kernel(exprs, scalar_type, is_logging):
"""Compile a kernel for pointwise expressions.
Expand Down
4 changes: 2 additions & 2 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet
# Since the caching is collective, this function must return a 2-tuple of
# the form (comm, key) where comm is the communicator the cache is collective over.
# FIXME FInAT elements are not safely hashable so we ignore them here
key = _hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
return comm, key


Expand Down Expand Up @@ -517,7 +517,7 @@ def __init__(self, glob):
self.ufl_domain = lambda: None


def _hash_expr(expr):
def hash_expr(expr):
"""Return a numbering-invariant hash of a UFL expression.
:arg expr: A UFL expression.
Expand Down
20 changes: 20 additions & 0 deletions tests/regression/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,26 @@ def test_expression_cache():
assert len(u._expression_cache) == 5


def test_global_expression_cache():
from firedrake.assemble_expressions import _pointwise_expression_cache

mesh = UnitSquareMesh(1, 1)
V = VectorFunctionSpace(mesh, "CG", 1)
u = Function(V)

_pointwise_expression_cache.clear()
assert len(_pointwise_expression_cache) == 0

u.assign(Constant(1))
assert len(_pointwise_expression_cache) == 1

u.assign(Constant(2))
assert len(_pointwise_expression_cache) == 1

u.assign(1)
assert len(_pointwise_expression_cache) == 2


def test_augmented_assignment_broadcast():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "BDM", 1)
Expand Down

0 comments on commit ee720d6

Please sign in to comment.