Skip to content

Commit d60101b

Browse files
committed
sparsity: add test
1 parent 48a54d3 commit d60101b

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

test/unit/test_matrices.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pyop2 import op2
4040
from pyop2.exceptions import MapValueError, ModeValueError
4141
from pyop2.mpi import COMM_WORLD
42+
from pyop2.datatypes import IntType
4243

4344
from petsc4py.PETSc import ScalarType
4445

@@ -941,6 +942,45 @@ def test_assemble_mixed_rhs_vector(self, mset, mmap, mvdat):
941942
assert_allclose(dat[1].data_ro, exp, eps)
942943

943944

945+
def test_matrices_sparsity_blockwise_specification():
946+
#
947+
# 0 1 2 3 nodesetA
948+
# x----x----x----x
949+
# 0 1 2 setA
950+
#
951+
# 0 1 2 nodesetB
952+
# x----x----x
953+
# 0 1 setB
954+
#
955+
# 0 1 2 3 | 0 1 2
956+
# 0 x |
957+
# 1 x | x x
958+
# 2 x | x x x
959+
# 3 x | x x sparsity
960+
# ----------+------
961+
# 0 x x | x
962+
# 1 x x x | x
963+
# 2 x x | x
964+
#
965+
arity = 2
966+
setA = op2.Set(3)
967+
nodesetA = op2.Set(4)
968+
setB = op2.Set(2)
969+
nodesetB = op2.Set(3)
970+
nodesetAB = op2.MixedSet((nodesetA, nodesetB))
971+
datasetAB = nodesetAB ** 1
972+
mapA = op2.Map(setA, nodesetA, arity, values=[[0, 1], [1, 2], [2, 3]])
973+
mapB = op2.Map(setB, nodesetB, arity, values=[[0, 1], [1, 2]])
974+
mapBA = op2.Map(setB, setA, 1, values=[1, 2])
975+
mapAB = op2.Map(setA, setB, 1, values=[-1, 0, 1]) # "inverse" map
976+
s = op2.Sparsity((datasetAB, datasetAB), {(1, 0): [(mapB, op2.ComposedMap(mapA, mapBA), None)],
977+
(0, 1): [(mapA, op2.ComposedMap(mapB, mapAB), None)]})
978+
assert np.all(s._blocks[0][0].nnz == np.array([1, 1, 1, 1], dtype=IntType))
979+
assert np.all(s._blocks[0][1].nnz == np.array([0, 2, 3, 2], dtype=IntType))
980+
assert np.all(s._blocks[1][0].nnz == np.array([2, 3, 2], dtype=IntType))
981+
assert np.all(s._blocks[1][1].nnz == np.array([1, 1, 1], dtype=IntType))
982+
983+
944984
if __name__ == '__main__':
945985
import os
946986
pytest.main(os.path.abspath(__file__))

0 commit comments

Comments
 (0)