Skip to content

Commit

Permalink
compiler: Ensure Weights always get printed before any other expr
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Jan 24, 2025
1 parent b5ed317 commit de4543a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
6 changes: 6 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,12 @@ def __init_finalize__(self, *args, **kwargs):

super().__init_finalize__(*args, **kwargs)

@classmethod
def class_key(cls):
# Ensure Weights appear before any other AbstractFunction
p, v, _ = Array.class_key()
return p, v - 1, cls.__name__

def __eq__(self, other):
return (isinstance(other, Weights) and
self.name == other.name and
Expand Down
6 changes: 5 additions & 1 deletion devito/types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property

import numpy as np
from sympy import Expr
from sympy import Expr, cacheit

from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p,
dtype_to_ctype, dtypes_vector_mapper, is_integer)
Expand Down Expand Up @@ -556,6 +556,10 @@ def indices(self):
def dtype(self):
return self.function.dtype

@cacheit
def sort_key(self, order=None):
return self.base.sort_key(order=order)

# Default assumptions correspond to those of the `base`
for i in ('is_real', 'is_imaginary', 'is_commutative'):
locals()[i] = property(lambda self, v=i: getattr(self.base, v))
6 changes: 6 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,12 @@ def __new__(cls, label, shape, function=None):

func = Pickable._rebuild

@sympy.cacheit
def sort_key(self, order=None):
class_key, args, exp, coeff = super().sort_key(order=order)
args = (self.function.class_key(), *args)
return class_key, args, exp, coeff

def __getitem__(self, indices, **kwargs):
"""Produce a types.Indexed, rather than a sympy.Indexed."""
return Indexed(self, *as_tuple(indices))
Expand Down
24 changes: 22 additions & 2 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
Min, Max)
from devito.finite_differences.differentiable import SafeInv
from devito.finite_differences.differentiable import SafeInv, Weights
from devito.ir import Expression, FindNodes
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, DefFunction, FieldFromPointer,
Expand All @@ -17,7 +17,7 @@
retrieve_derivatives)
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
ComponentAccess, StencilDimension, Symbol as dSymbol)
from devito.types.basic import AbstractSymbol


Expand Down Expand Up @@ -448,6 +448,26 @@ def test_findexed():
assert new_fi.strides_map == strides_map


def test_canonical_ordering_of_weights():
grid = Grid(shape=(3, 3, 3))
x, y, z = grid.dimensions

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)

i = StencilDimension('i0', 0, 2)
w = Weights(name='w0', dimensions=i, initvalue=[1.0, 2.0, 3.0])

fi = f[x, y + i, z]
wi = w[i]
cf = ComponentAccess(fi, 0)

assert (ccode(1.0*f[x, y, z] + 2.0*f[x, y + 1, z] + 3.0*f[x, y + 2, z]) ==
'1.0F*f[x][y][z] + 2.0F*f[x][y + 1][z] + 3.0F*f[x][y + 2][z]')
assert ccode(fi*wi) == 'w0[i0]*f[x][y + i0][z]'
assert ccode(cf*wi) == 'w0[i0]*f[x][y + i0][z].x'


def test_symbolic_printing():
b = Symbol('b')

Expand Down

0 comments on commit de4543a

Please sign in to comment.