Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,35 @@ def assert_correct_results(self, tmpdirpath):
pd.testing.assert_frame_equal(res[0], r)


class MPITestAssertAllRanksEqual(MPITestWrapper):
"""
Assert that the results from all ranks are equal, independent of number of ranks.
"""

def assert_correct_results(self, tmpdirpath):
self.collect_results(tmpdirpath)

all_res = []
if self._spike:
raise NotImplementedError("SPIKE data not supported by MPITestAssertAllRanksEqual")

if self._multi:
raise NotImplementedError("MULTI data not supported by MPITestAssertAllRanksEqual")

if self._other:
all_res = list(self._other.values()) # need to get away from dict_values to allow indexing below

assert len(all_res) == len(self._procs_lst), "Missing data for some process numbers"
assert len(all_res[0]) == self._procs_lst[0], "Data for first proc number does not match number of procs"

reference = all_res[0][0]
for res, num_ranks in zip(all_res, self._procs_lst):
assert len(res) == num_ranks, f"Got data for {len(res)} ranks, expected {num_ranks}."

for r in res:
pd.testing.assert_frame_equal(r, reference)


class MPITestAssertCompletes(MPITestWrapper):
"""
Test class that just confirms that the test code completes.
Expand Down
50 changes: 50 additions & 0 deletions testsuite/pytests/sli2py_mpi/test_global_rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
#
# test_global_rng.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import pandas as pd
import pytest
from mpi_test_wrapper import MPITestAssertAllRanksEqual


# Parametrization over the number of nodes here only to show hat it works
@pytest.mark.skipif_incompatible_mpi
@MPITestAssertAllRanksEqual([1, 2, 4], debug=False)
def test_global_rng():
"""
Confirm that NEST random parameter used from the Python level uses globally sync'ed RNG correctly.
All ranks must report identical random number sequences independent of the number of ranks.

The test compares connection data written to OTHER_LABEL.
"""

import nest

nest.rng_seed = 12
p = nest.CreateParameter("uniform", {"min": 0, "max": 1})

# Uncomment one of the two for loops to provoke failure
# for _ in range(nest.num_processes):
# p.GetValue()
# for _ in range(nest.Rank()):
# p.GetValue()

vals = pd.DataFrame([p.GetValue() for _ in range(5)])
vals.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), sep="\t") # noqa: F821
Loading