Skip to content

Commit

Permalink
Adds override mode for aprun
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonb5 committed Dec 7, 2022
1 parent c7f9ea6 commit 53891d2
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 11 deletions.
33 changes: 32 additions & 1 deletion CIME/XML/env_mach_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def _find_best_mpirun_match(self, attribs):

def get_aprun_mode(self, attribs):
default_mode = "ignore"
valid_modes = ("ignore", "default")
valid_modes = ("ignore", "default", "override")

try:
the_match = self._find_best_mpirun_match(attribs)
Expand All @@ -669,6 +669,37 @@ def get_aprun_mode(self, attribs):

return mode

def get_aprun_args(self, case, attribs, job, overrides=None):
args = {}

try:
the_match = self._find_best_mpirun_match(attribs)
except ValueError:
return None

arg_node = self.get_optional_child("arguments", root=the_match)

if arg_node:
arg_nodes = self.get_children("arg", root=arg_node)

for arg_node in arg_nodes:
position = self.get(arg_node, "position")

if position is None:
continue

arg_value = transform_vars(
self.text(arg_node),
case=case,
subgroup=job,
overrides=overrides,
default=self.get(arg_node, "default"),
)

args[arg_value] = dict(position=position)

return args

def get_mpirun(self, case, attribs, job, exe_only=False, overrides=None):
"""
Find best match, return (executable, {arg_name : text})
Expand Down
30 changes: 21 additions & 9 deletions CIME/aprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _get_aprun_cmd_for_case_impl(
compiler,
machine,
run_exe,
extra_args,
):
###############################################################################
"""
Expand All @@ -38,19 +39,22 @@ def _get_aprun_cmd_for_case_impl(
>>> compiler = "pgi"
>>> machine = "titan"
>>> run_exe = "e3sm.exe"
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 4 -n 680 -N 8 -d 2 e3sm.exe : -S 2 -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 4 -n 680 -N 8 -d 2 e3sm.exe : -S 2 -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> compiler = "intel"
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 4 -cc numa_node -n 680 -N 8 -d 2 e3sm.exe : -S 2 -cc numa_node -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 4 -cc numa_node -n 680 -N 8 -d 2 e3sm.exe : -S 2 -cc numa_node -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> ntasks = [64, 64, 64, 64, 64, 64, 64, 64, 1]
>>> nthreads = [1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> rootpes = [0, 0, 0, 0, 0, 0, 0, 0, 0]
>>> pstrids = [1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 8 -cc numa_node -n 64 -N 16 -d 1 e3sm.exe ', 4, 64, 16, 1)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 8 -cc numa_node -n 64 -N 16 -d 1 e3sm.exe ', 4, 64, 16, 1)
"""
if extra_args is None:
extra_args = {}

max_tasks_per_node = 1 if max_tasks_per_node < 1 else max_tasks_per_node

total_tasks = 0
Expand Down Expand Up @@ -78,6 +82,12 @@ def _get_aprun_cmd_for_case_impl(
if maxt[c1] < 1:
maxt[c1] = 1

global_flags = " ".join(
[x for x, y in extra_args.items() if y["position"] == "global"]
)

per_flags = " ".join([x for x, y in extra_args.items() if y["position"] == "per"])

# Compute task and thread settings for batch commands
(
tasks_per_node,
Expand All @@ -88,7 +98,7 @@ def _get_aprun_cmd_for_case_impl(
total_node_count,
total_task_count,
aprun_args,
) = (0, max_mpitasks_per_node, 1, maxt[0], maxt[0], 0, 0, "")
) = (0, max_mpitasks_per_node, 1, maxt[0], maxt[0], 0, 0, f" {global_flags}")
c1list = list(range(1, total_tasks))
c1list.append(None)
for c1 in c1list:
Expand All @@ -107,10 +117,11 @@ def _get_aprun_cmd_for_case_impl(
if compiler == "intel":
aprun_args += " -cc numa_node"

aprun_args += " -n {:d} -N {:d} -d {:d} {} {}".format(
aprun_args += " -n {:d} -N {:d} -d {:d} {} {} {}".format(
task_count,
tasks_per_node,
thread_count,
per_flags,
run_exe,
"" if c1 is None else ":",
)
Expand Down Expand Up @@ -140,7 +151,7 @@ def _get_aprun_cmd_for_case_impl(


###############################################################################
def get_aprun_cmd_for_case(case, run_exe, overrides=None):
def get_aprun_cmd_for_case(case, run_exe, overrides=None, extra_args=None):
###############################################################################
"""
Given a case, construct and return the aprun command and optimized node count
Expand Down Expand Up @@ -179,4 +190,5 @@ def get_aprun_cmd_for_case(case, run_exe, overrides=None):
case.get_value("COMPILER"),
case.get_value("MACH"),
run_exe,
extra_args,
)
5 changes: 5 additions & 0 deletions CIME/case/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,10 +2020,15 @@ def get_mpirun_cmd(self, job=None, allow_unresolved_envvars=True, overrides=None
and aprun_mode != "ignore"
# and not "theta" in self.get_value("MACH")
):
extra_args = env_mach_specific.get_aprun_args(
self, mpi_attribs, job, overrides=overrides
)

aprun_args, num_nodes, _, _, _ = get_aprun_cmd_for_case(
self,
run_exe,
overrides=overrides,
extra_args=extra_args,
)
if job in ("case.run", "case.test"):
expect(
Expand Down
64 changes: 64 additions & 0 deletions CIME/tests/test_unit_aprun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
from unittest import mock

from CIME import aprun

# NTASKS, NTHRDS, ROOTPE, PSTRID
DEFAULT_COMP_ATTRS = [
512, 2, 0, 1,
675, 2, 0, 1,
168, 2, 512, 1,
512, 2, 0, 1,
128, 4, 680, 1,
168, 2, 512, 1,
168, 2, 512, 1,
512, 2, 0, 1,
1, 1, 0, 1,
]

# MAX_TASKS_PER_NODE, MAX_MPITASKS_PER_NODE, PIO_NUMTASKS, PIO_ASYNC_INTERFACE, COMPILER, MACH
DEFAULT_ARGS = [
16,
16,
-1,
False,
"gnu",
"docker",
]


class TestUnitAprun(unittest.TestCase):
def test_aprun_extra_args(self):
case = mock.MagicMock()

case.get_values.return_value = ["CPL", "ATM", "LND", "ICE", "OCN", "ROF", "GLC", "WAV", "IAC"]

case.get_value.side_effect = DEFAULT_COMP_ATTRS + DEFAULT_ARGS

extra_args = {
"-e DEBUG=true": {"position": "global"},
"-j 20": {"position": "per"},
}

aprun_args, total_node_count, total_task_count, min_tasks_per_node, max_thread_count = aprun.get_aprun_cmd_for_case(case, "e3sm.exe", extra_args=extra_args)

assert aprun_args == " -e DEBUG=true -n 680 -N 8 -d 2 -j 20 e3sm.exe : -n 128 -N 4 -d 4 -j 20 e3sm.exe "
assert total_node_count == 117
assert total_task_count == 808
assert min_tasks_per_node == 4
assert max_thread_count == 4

def test_aprun(self):
case = mock.MagicMock()

case.get_values.return_value = ["CPL", "ATM", "LND", "ICE", "OCN", "ROF", "GLC", "WAV", "IAC"]

case.get_value.side_effect = DEFAULT_COMP_ATTRS + DEFAULT_ARGS

aprun_args, total_node_count, total_task_count, min_tasks_per_node, max_thread_count = aprun.get_aprun_cmd_for_case(case, "e3sm.exe")

assert aprun_args == " -n 680 -N 8 -d 2 e3sm.exe : -n 128 -N 4 -d 4 e3sm.exe "
assert total_node_count == 117
assert total_task_count == 808
assert min_tasks_per_node == 4
assert max_thread_count == 4
58 changes: 57 additions & 1 deletion CIME/tests/test_unit_xml_env_mach_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,62 @@


class TestXMLEnvMachSpecific(unittest.TestCase):
def test_aprun_get_args(self):
with tempfile.NamedTemporaryFile() as temp:
temp.write(
b"""<?xml version="1.0"?>
<file id="env_mach_specific.xml" version="2.0">
<header>
These variables control the machine dependent environment including
the paths to compilers and libraries external to cime such as netcdf,
environment variables for use in the running job should also be set here.
</header>
<group id="compliant_values">
<entry id="run_exe" value="${EXEROOT}/e3sm.exe ">
<type>char</type>
<desc>executable name</desc>
</entry>
<entry id="run_misc_suffix" value=" &gt;&gt; e3sm.log.$LID 2&gt;&amp;1 ">
<type>char</type>
<desc>redirect for job output</desc>
</entry>
</group>
<module_system type="none"/>
<environment_variables>
<env name="OMPI_ALLOW_RUN_AS_ROOT">1</env>
<env name="OMPI_ALLOW_RUN_AS_ROOT_CONFIRM">1</env>
</environment_variables>
<mpirun mpilib="openmpi">
<aprun_mode>override</aprun_mode>
<executable>aprun</executable>
<arguments>
<arg name="skipped">should be skipped</arg>
<arg name="ntasks" position="global">-n {{ total_tasks }}</arg>
<arg name="oversubscribe" position="per">--oversubscribe</arg>
</arguments>
</mpirun>
</file>
"""
)
temp.seek(0)

mach_specific = EnvMachSpecific(infile=temp.name)

attribs = {"compiler": "gnu", "mpilib": "openmpi", "threaded": False}

case = mock.MagicMock()

type(case).total_tasks = mock.PropertyMock(return_value=4)

extra_args = mach_specific.get_aprun_args(case, attribs, "case.run")

expected_args = {
"--oversubscribe": {"position": "per"},
"-n 4": {"position": "global"},
}

assert extra_args == expected_args

def test_get_aprun_mode_not_valid(self):
with tempfile.NamedTemporaryFile() as temp:
temp.write(
Expand Down Expand Up @@ -58,7 +114,7 @@ def test_get_aprun_mode_not_valid(self):

assert (
str(e.exception)
== "ERROR: Value 'custom' for \"aprun_mode\" is not valid, options are 'ignore, default'"
== "ERROR: Value 'custom' for \"aprun_mode\" is not valid, options are 'ignore, default, override'"
)

def test_get_aprun_mode_user_defined(self):
Expand Down

0 comments on commit 53891d2

Please sign in to comment.