Skip to content

Commit

Permalink
api: fix gpu-fit for tensorfunction
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Dec 18, 2023
1 parent 8f45ba0 commit dbd0064
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
17 changes: 11 additions & 6 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _normalize_kwargs(cls, **kwargs):
o['par-dynamic-work'] = np.inf # Always use static scheduling
o['par-nested'] = np.inf # Never use nested parallelism
o['par-disabled'] = oo.pop('par-disabled', True) # No host parallelism by default
o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs)))
o['gpu-fit'] = cls._normalize_gpu_fit(oo, **kwargs)
o['gpu-create'] = as_tuple(oo.pop('gpu-create', ()))

# Distributed parallelism
Expand All @@ -95,11 +95,16 @@ def _normalize_kwargs(cls, **kwargs):
return kwargs

@classmethod
def _normalize_gpu_fit(cls, **kwargs):
if any(i in kwargs['mode'] for i in ['tasking', 'streaming']):
return None
else:
return cls.GPU_FIT
def _normalize_gpu_fit(cls, oo, **kwargs):
try:
gfit = as_tuple(oo.pop('gpu-fit'))
gfit = set().union([f.values() if f.is_AbstractTensor else f for f in gfit])
return tuple(gfit)
except KeyError:
if any(i in kwargs['mode'] for i in ['tasking', 'streaming']):
return (None,)
else:
return as_tuple(cls.GPU_FIT)

@classmethod
def _rcompile_wrapper(cls, **kwargs0):
Expand Down
1 change: 1 addition & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ class Basic(CodeSymbol):

# Top hierarchy
is_AbstractFunction = False
is_AbstractTensor = False
is_AbstractObject = False

# Symbolic objects created internally by Devito
Expand Down
14 changes: 13 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension,
Dimension, MatrixSparseTimeFunction, SparseTimeFunction,
SubDimension, SubDomain, SubDomainSet, TimeFunction,
Operator, configuration, switchconfig)
Operator, configuration, switchconfig, TensorTimeFunction)
from devito.arch import get_gpu_info
from devito.exceptions import InvalidArgument
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1423,6 +1423,18 @@ def test_npthreads(self):
with pytest.raises(InvalidArgument):
assert op.arguments(time_M=2, npthreads0=5)

def test_gpu_fit_w_tensor_functions(self):
grid = Grid(shape=(10, 10))

u = TensorTimeFunction(name='u', grid=grid)
usave = TensorTimeFunction(name="usave", grid=grid, save=10)

eqns = [Eq(u.forward, u + 1),
Eq(usave, u.forward)]

op = Operator(eqns, opt=('noop', {'gpu-fit': usave}))
assert set(op._options['gpu-fit']) - set(usave.values()) == set()


class TestMisc(object):

Expand Down

0 comments on commit dbd0064

Please sign in to comment.