Skip to content

Commit

Permalink
symbolics: fix printer and arithmetic for sympy 1.13
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 9, 2024
1 parent dee09dd commit 4d460a0
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 53 deletions.
2 changes: 1 addition & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _symbolic_functions(self):
@cached_property
def function(self):
if len(self._functions) == 1:
return self._functions.pop()
return set(self._functions).pop()
else:
return None

Expand Down
8 changes: 6 additions & 2 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ def _print_Float(self, expr):
rv = to_str(expr._mpf_, dps, strip_zeros=strip, max_fixed=-2, min_fixed=2)

if rv.startswith('-.0'):
rv = '-0.' + rv[3:]
rv = "-0." + rv[3:]
elif rv.startswith('.0'):
rv = '0.' + rv[2:]
rv = "0." + rv[2:]

# Remove trailing zero except first one to avoid 1. instead of 1.0
if 'e' not in rv:
rv = rv.rstrip('0') + "0"

if self.single_prec():
rv = '%sF' % rv
Expand Down
9 changes: 8 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,18 @@ def adjoint(self, inner=True):
def __add__(self, other):
try:
# Most case support sympy add
return super().__add__(other)
tsum = super().__add__(other)
except TypeError:
# Sympy doesn't support add with scalars
tsum = self.applyfunc(lambda x: x + other)

# As of sympy 1.13, super does not throw an exception but
# only returns NotImplemented for some internal dispatch.
if tsum is NotImplemented:
return self.applyfunc(lambda x: x + other)

return tsum

def _eval_matrix_mul(self, other):
"""
Copy paste from sympy to avoid explicit call to sympy.Add
Expand Down
12 changes: 6 additions & 6 deletions examples/seismic/tutorials/05_staggered_acoustic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}v_x(t + dt, x + h_x/2, z)\\\\v_z(t + dt, x, z + h_z/2)\\end{matrix}\\right] = \\left[\\begin{matrix}dt \\left(\\frac{\\partial}{\\partial x} p(t, x, z) + \\frac{v_x(t, x + h_x/2, z)}{dt}\\right)\\\\dt \\left(\\frac{\\partial}{\\partial z} p(t, x, z) + \\frac{v_z(t, x, z + h_z/2)}{dt}\\right)\\end{matrix}\\right]$"
"$\\displaystyle \\left[\\begin{matrix}v_x(t + dt, x + h_x/2, z)\\\\v_z(t + dt, x, z + h_z/2)\\end{matrix}\\right] = \\left[\\begin{matrix}dt \\left(1.0 \\frac{\\partial}{\\partial x} p(t, x, z) + \\frac{v_x(t, x + h_x/2, z)}{dt}\\right)\\\\dt \\left(1.0 \\frac{\\partial}{\\partial z} p(t, x, z) + \\frac{v_z(t, x, z + h_z/2)}{dt}\\right)\\end{matrix}\\right]$"
],
"text/plain": [
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(1.0*Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(1.0*Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
]
},
"execution_count": 7,
Expand Down Expand Up @@ -190,9 +190,9 @@
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.003492999999999997, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005126, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.0014679999999999954, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.0023490000000000017, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 10,
Expand Down Expand Up @@ -296,9 +296,9 @@
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.004075000000000002, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005326999999999997, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.0012249999999999995, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.0017349999999999978, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 14,
Expand Down
82 changes: 43 additions & 39 deletions examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Large diffs are not rendered by default.

19 changes: 15 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
from setuptools import setup, find_packages


def min_max(pkgs, pkg_name):
pkg = [p for p in pkgs if pkg_name in p][0]
minsign = '>=' if '>=' in pkg else '>'
maxsign = '<=' if '<=' in pkg else '<'
vmin = pkg.split(minsign)[1].split(',')[0]
vmax = pkg.split(maxsign)[-1]
return vmin, vmax


def numpy_compat(required):
new_reqs = [r for r in required if "numpy" not in r and "sympy" not in r]
sympy_lb, sympy_ub = min_max(required, "sympy")
numpy_lb, numpy_ub = min_max(required, "numpy")
if sys.version_info < (3, 9):
# Numpy 2.0 requires python > 3.8
new_reqs.extend(["sympy>=1.9,<1.13", "numpy>1.16,<2.0"])
new_reqs.extend([f"sympy>={sympy_lb},<1.12.1", f"numpy>{numpy_lb},<2.0"])
return new_reqs

# Due to api changes in numpy 2.0, it requires sympy 1.12.1 at the minimum
Expand All @@ -20,11 +31,11 @@ def numpy_compat(required):
sympy_version = pkg_resources.get_distribution("sympy").version
min_ver2 = pkg_resources.parse_version("1.12.1")
if pkg_resources.parse_version(sympy_version) < min_ver2:
new_reqs.append("numpy>1.16,<2.0")
new_reqs.append(f"numpy>{numpy_lb},<2.0")
else:
new_reqs.append("numpy>=2.0")
new_reqs.append(f"numpy>=2.0,<{numpy_ub}")
except pkg_resources.DistributionNotFound:
new_reqs.extend(["sympy>=1.12.1", "numpy>=2.0"])
new_reqs.extend([f"sympy>=1.12.1,<{sympy_ub}", f"numpy>=2.0,<{numpy_ub}"])

return new_reqs

Expand Down
13 changes: 13 additions & 0 deletions tests/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,16 @@ def test_shifted_lap_of_tensor(shift, ndim):
type(shift) is tuple else d + shift * d.spacing)
ref += getattr(v[j, i], 'd%s2' % d.name)(x0=x0, fd_order=order)
assert df[j] == ref


def test_basic_arithmetic():
grid = Grid(tuple([5]*3))
tau = TensorFunction(name="tau", grid=grid)

# Scalar operations
t1 = tau + 1
print(t1)
assert all(t1i == ti + 1 for (t1i, ti) in zip(t1, tau))

t1 = tau * 2
assert all(t1i == ti * 2 for (t1i, ti) in zip(t1, tau))

0 comments on commit 4d460a0

Please sign in to comment.