Skip to content

Commit

Permalink
compiler: fix dtype for mpi routines
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 17, 2025
1 parent dcd0bd4 commit dbf1f78
Showing 6 changed files with 18 additions and 9 deletions.
6 changes: 3 additions & 3 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype,
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype,
flatten, generator, is_integer, split)
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
CompositeObject, CustomDimension)
@@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None):
entry.sizes = (c_int*len(shape))(*shape)

# Allocate the send/recv buffers
size = reduce(mul, shape)*dtype_len(self.target.dtype)
ctype = dtype_to_ctype(f.dtype)
ctype, c_scale = dtype_alloc_ctype(f.dtype)
size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype)
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)

4 changes: 2 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
@@ -1118,7 +1118,7 @@ def __setstate__(self, state):
self._lib.name = soname

self._allocator = default_allocator(
'%s.%s.%s' % (self._compiler.name, self._language, self._platform)
'%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform)
)


@@ -1404,7 +1404,7 @@ def parse_kwargs(**kwargs):

# `allocator`
kwargs['allocator'] = default_allocator(
'%s.%s.%s' % (kwargs['compiler'].name,
'%s.%s.%s' % (kwargs['compiler'].__class__.__name__,
kwargs['language'],
kwargs['platform'])
)
3 changes: 2 additions & 1 deletion devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

from sympy import S
import numpy as np

from devito.finite_differences import IndexDerivative
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
@@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
# NOTE: created before recurring so that we ultimately get a sound ordering
try:
s = reusables.pop()
assert s.dtype is dtype
assert np.can_cast(s.dtype, dtype)
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=dtype)
2 changes: 1 addition & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
@@ -393,7 +393,7 @@ def normalize_args(args):
for k, v in args.items():
try:
retval[k] = sympify(v, strict=True)
except SympifyError:
except (TypeError, SympifyError):
continue

return retval
6 changes: 5 additions & 1 deletion devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
# NOTE: the following is inspired by pyopencl.cltypes

mapper = {
"half": np.float16,
"int": np.int32,
"float": np.float32,
"double": np.float64
@@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype):
np.int32: 'MPI_INT',
np.float32: 'MPI_FLOAT',
np.int64: 'MPI_LONG',
np.float64: 'MPI_DOUBLE'
np.float64: 'MPI_DOUBLE',
np.float16: 'MPI_UNSIGNED_SHORT'
}[dtype]


@@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p):

ctypes_vector_mapper = {}
for base_name, base_dtype in mapper.items():
if base_dtype is np.float16:
continue
base_ctype = dtype_to_ctype(base_dtype)

for count in counts:
6 changes: 5 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
@@ -792,17 +792,21 @@ def _halo_exchange(self):
# Gather send data
data = self._data_in_region(OWNED, d, i)
sendbuf = np.ascontiguousarray(data)
if self.dtype == np.float16:
sendbuf = sendbuf.view(np.uint16)

# Setup recv buffer
shape = self._data_in_region(HALO, d, i.flip()).shape
recvbuf = np.ndarray(shape=shape, dtype=self.dtype)
if self.dtype == np.float16:
recvbuf = recvbuf.view(np.uint16)

# Communication
comm.Sendrecv(sendbuf, dest=dest, recvbuf=recvbuf, source=source)

# Scatter received data
if recvbuf is not None and source != MPI.PROC_NULL:
self._data_in_region(HALO, d, i.flip())[:] = recvbuf
self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype)

self._is_halo_dirty = False

0 comments on commit dbf1f78

Please sign in to comment.