Skip to content

Commit

Permalink
compiler: Remove useless IndexDerivative properties
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Feb 20, 2024
1 parent 54d0fdb commit 88c39ba
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 18 deletions.
17 changes: 4 additions & 13 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,9 @@ def _xreplace(self, rule):

class IndexDerivative(IndexSum):

__rargs__ = ('expr', 'mapper', 'order')
__rargs__ = ('expr', 'mapper')

def __new__(cls, expr, mapper, order, **kwargs):
def __new__(cls, expr, mapper, **kwargs):
dimensions = as_tuple(set(mapper.values()))

# Detect the Weights among the arguments
Expand All @@ -758,12 +758,11 @@ def __new__(cls, expr, mapper, order, **kwargs):
obj = super().__new__(cls, expr, dimensions)
obj._weights = weights
obj._mapper = frozendict(mapper)
obj._order = order

return obj

def _hashable_content(self):
return super()._hashable_content() + (self.mapper, self.order)
return super()._hashable_content() + (self.mapper,)

def compare(self, other):
if self is other:
Expand All @@ -787,14 +786,6 @@ def weights(self):
def mapper(self):
return self._mapper

@property
def order(self):
return self._order

@property
def scaling(self):
return Mul(*[d.spacing**self.order for d in self.mapper])

@property
def depth(self):
iderivs = self.expr.find(IndexDerivative)
Expand Down Expand Up @@ -936,7 +927,7 @@ def _diff2sympy(obj):

# Handle special objects
if isinstance(obj, DiffDerivative):
return IndexDerivative(*args, obj.mapper, obj.order), True
return IndexDerivative(*args, obj.mapper), True

# Handle generic objects such as arithmetic operations
try:
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic
# Pure number
pass

deriv = DiffDerivative(expr*weights, {dim: indices.free_dim}, deriv_order)
deriv = DiffDerivative(expr*weights, {dim: indices.free_dim})
else:
terms = []
for i, c in zip(indices, weights):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ConditionalDimension, left, right, centered, div, grad)
from devito.finite_differences import Derivative, Differentiable
from devito.finite_differences.differentiable import (Add, EvalDerivative, IndexSum,
IndexDerivative, Weights, Pow)
IndexDerivative, Weights)
from devito.symbolics import indexify, retrieve_indexed
from devito.types.dimension import StencilDimension

Expand Down Expand Up @@ -716,9 +716,8 @@ def test_index_derivative(self):
ui = u.subs(x, x + i*x.spacing)
w = Weights(name='w0', dimensions=i, initvalue=[-0.5, 0, 0.5])

idxder = IndexDerivative(ui*w, {x: i}, so)
idxder = IndexDerivative(ui*w, {x: i})

assert idxder.scaling == Pow(x.spacing, so)
assert idxder.evaluate == -0.5*u + 0.5*ui.subs(i, 2)

# Make sure subs works as expected
Expand All @@ -727,7 +726,7 @@ def test_index_derivative(self):
vi0 = v.subs(x, x + i*x.spacing)
vi1 = idxder.subs(ui, vi0)

assert IndexDerivative(vi0*w, {x: i}, so) == vi1
assert IndexDerivative(vi0*w, {x: i}) == vi1

def test_dx2(self):
grid = Grid(shape=(4, 4))
Expand Down

0 comments on commit 88c39ba

Please sign in to comment.