diff --git a/devito/__init__.py b/devito/__init__.py index b0a981dcfa..80cb63a8cc 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -64,7 +64,7 @@ def reinit_compiler(val): # Setup target platform and compiler configuration.add('platform', 'cpu64', list(platform_registry), callback=lambda i: platform_registry[i]()) -configuration.add('compiler', 'custom', list(compiler_registry), +configuration.add('compiler', 'custom', compiler_registry, callback=lambda i: compiler_registry[i]()) # Setup language for shared-memory parallelism diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 0d35fa5aeb..cd4f9d45ad 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -25,7 +25,8 @@ 'INTEL64', 'SNB', 'IVB', 'HSW', 'BDW', 'KNL', 'KNL7210', 'SKX', 'KLX', 'CLX', 'CLK', 'SPR', # ARM CPUs - 'AMD', 'ARM', 'AppleArm', 'M1', 'M2', 'M3', 'GRAVITON', + 'AMD', 'ARM', 'AppleArm', 'M1', 'M2', 'M3', + 'Graviton', 'GRAVITON2', 'GRAVITON3', 'GRAVITON4', # Other legacy CPUs 'POWER8', 'POWER9', # Generic GPUs @@ -764,6 +765,20 @@ def march(self): return min(mx, 'm2') +class Graviton(Arm): + + @property + def version(self): + return int(self.name.split('graviton')[-1]) + + @cached_property + def march(self): + if self.version >= 4: + return 'neoverse-n2' + else: + return 'neoverse-n1' + + class Amd(Cpu64): known_isas = ('cpp', 'sse', 'avx', 'avx2') @@ -912,7 +927,9 @@ def march(cls): SPR = IntelGoldenCove('spr') # Sapphire Rapids ARM = Arm('arm') -GRAVITON = Arm('graviton') +GRAVITON2 = Graviton('graviton2') +GRAVITON3 = Graviton('graviton3') +GRAVITON4 = Graviton('graviton4') M1 = AppleArm('m1') M2 = AppleArm('m2') M3 = AppleArm('m3') diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 9cd94ed597..151bba868a 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -13,7 +13,7 @@ from codepy.toolchain import (GCCToolchain, call_capture_output as _call_capture_output) -from devito.arch import (AMDGPUX, Cpu64, AppleArm, NVIDIAX, POWER8, POWER9, GRAVITON, +from devito.arch import (AMDGPUX, Cpu64, AppleArm, NVIDIAX, POWER8, POWER9, Graviton, IntelDevice, get_nvidia_cc, check_cuda_runtime, get_m1_llvm_path) from devito.exceptions import CompilationError @@ -434,6 +434,8 @@ def __init_finalize__(self, **kwargs): if platform in [POWER8, POWER9]: # -march isn't supported on power architectures, is -mtune needed? self.cflags = ['-mcpu=native'] + self.cflags + elif isinstance(platform, Graviton): + self.cflags = ['-mcpu=%s' % platform.march] + self.cflags else: self.cflags = ['-march=native'] + self.cflags @@ -462,8 +464,8 @@ def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) # Graviton flag - if platform is GRAVITON: - self.cflags += ['-mcpu=neoverse-n1'] + if isinstance(platform, Graviton): + self.cflags += ['-mcpu=%s' % platform.march] class ClangCompiler(Compiler): @@ -962,7 +964,23 @@ def __new_with__(self, **kwargs): return super().__new_with__(base=self._base, **kwargs) -compiler_registry = { +class CompilerRegistry(dict): + """ + Registry dict for deriving Compiler classes according to the environment variable + DEVITO_ARCH. Developers should add new compiler classes here. + """ + + def __getitem__(self, key): + if key.startswith('gcc-'): + i = key.split('-')[1] + return partial(GNUCompiler, suffix=i) + return super().__getitem__(key) + + def __contains__(self, k): + return k in self.keys() or k.startswith('gcc-') + + +_compiler_registry = { 'custom': CustomCompiler, 'gnu': GNUCompiler, 'gcc': GNUCompiler, @@ -989,10 +1007,6 @@ def __new_with__(self, **kwargs): 'knl': IntelKNLCompiler, 'dpcpp': DPCPPCompiler, } -""" -Registry dict for deriving Compiler classes according to the environment variable -DEVITO_ARCH. Developers should add new compiler classes here. -""" -compiler_registry.update({'gcc-%s' % i: partial(GNUCompiler, suffix=i) - for i in ['4.9', '5', '6', '7', '8', '9', '10', - '11', '12', '13']}) + + +compiler_registry = CompilerRegistry(**_compiler_registry) diff --git a/tests/test_arch.py b/tests/test_arch.py index fd00bd70da..501ff6d8d8 100644 --- a/tests/test_arch.py +++ b/tests/test_arch.py @@ -1,6 +1,6 @@ import pytest -from devito.arch.compiler import sniff_compiler_version +from devito.arch.compiler import sniff_compiler_version, compiler_registry @pytest.mark.parametrize("cc", [ @@ -10,3 +10,8 @@ def test_sniff_compiler_version(cc): with pytest.raises(RuntimeError, match=cc): sniff_compiler_version(cc) + + +@pytest.mark.parametrize("cc", ['gcc-4.9', 'gcc-11', 'gcc', 'gcc-14', 'gcc-123']) +def test_gcc(cc): + assert cc in compiler_registry