From de4543acad40b3094b9add2a71f2d4c4a23230e1 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 18 Dec 2024 16:49:35 +0000 Subject: [PATCH] compiler: Ensure Weights always get printed before any other expr --- devito/finite_differences/differentiable.py | 6 ++++++ devito/types/array.py | 6 +++++- devito/types/basic.py | 6 ++++++ tests/test_symbolics.py | 24 +++++++++++++++++++-- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 1fcf3e0b44..8a32df7289 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -749,6 +749,12 @@ def __init_finalize__(self, *args, **kwargs): super().__init_finalize__(*args, **kwargs) + @classmethod + def class_key(cls): + # Ensure Weights appear before any other AbstractFunction + p, v, _ = Array.class_key() + return p, v - 1, cls.__name__ + def __eq__(self, other): return (isinstance(other, Weights) and self.name == other.name and diff --git a/devito/types/array.py b/devito/types/array.py index 95e1b10ce2..c482c2849f 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -2,7 +2,7 @@ from functools import cached_property import numpy as np -from sympy import Expr +from sympy import Expr, cacheit from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p, dtype_to_ctype, dtypes_vector_mapper, is_integer) @@ -556,6 +556,10 @@ def indices(self): def dtype(self): return self.function.dtype + @cacheit + def sort_key(self, order=None): + return self.base.sort_key(order=order) + # Default assumptions correspond to those of the `base` for i in ('is_real', 'is_imaginary', 'is_commutative'): locals()[i] = property(lambda self, v=i: getattr(self.base, v)) diff --git a/devito/types/basic.py b/devito/types/basic.py index f2c2179147..4dcf1dad95 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1527,6 +1527,12 @@ def __new__(cls, label, shape, function=None): func = Pickable._rebuild + @sympy.cacheit + def sort_key(self, order=None): + class_key, args, exp, coeff = super().sort_key(order=order) + args = (self.function.class_key(), *args) + return class_key, args, exp, coeff + def __getitem__(self, indices, **kwargs): """Produce a types.Indexed, rather than a sympy.Indexed.""" return Indexed(self, *as_tuple(indices)) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 509b0cb9d3..8b1ae20098 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -8,7 +8,7 @@ from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max) -from devito.finite_differences.differentiable import SafeInv +from devito.finite_differences.differentiable import SafeInv, Weights from devito.ir import Expression, FindNodes from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, DefFunction, FieldFromPointer, @@ -17,7 +17,7 @@ retrieve_derivatives) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, - Symbol as dSymbol) + ComponentAccess, StencilDimension, Symbol as dSymbol) from devito.types.basic import AbstractSymbol @@ -448,6 +448,26 @@ def test_findexed(): assert new_fi.strides_map == strides_map +def test_canonical_ordering_of_weights(): + grid = Grid(shape=(3, 3, 3)) + x, y, z = grid.dimensions + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + i = StencilDimension('i0', 0, 2) + w = Weights(name='w0', dimensions=i, initvalue=[1.0, 2.0, 3.0]) + + fi = f[x, y + i, z] + wi = w[i] + cf = ComponentAccess(fi, 0) + + assert (ccode(1.0*f[x, y, z] + 2.0*f[x, y + 1, z] + 3.0*f[x, y + 2, z]) == + '1.0F*f[x][y][z] + 2.0F*f[x][y + 1][z] + 3.0F*f[x][y + 2][z]') + assert ccode(fi*wi) == 'w0[i0]*f[x][y + i0][z]' + assert ccode(cf*wi) == 'w0[i0]*f[x][y + i0][z].x' + + def test_symbolic_printing(): b = Symbol('b')