Skip to content

Commit 42dbb98

Browse files
committed
Allow for 2pi phase multiples in oper_equiv tests
1 parent 067309e commit 42dbb98

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/test_util.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,28 +453,31 @@ def test_oper_equiv(self):
453453

454454
result = util.oper_equiv(psi, psi*np.exp(1j*phase))
455455
self.assertTrue(result[0])
456-
self.assertAlmostEqual(result[1], phase, places=5)
456+
self.assertAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(phase, 2*np.pi), places=5)
457457

458458
result = util.oper_equiv(psi*np.exp(1j*phase), psi)
459459
self.assertTrue(result[0])
460-
self.assertAlmostEqual(result[1], -phase, places=5)
460+
self.assertAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(-phase, 2*np.pi), places=5)
461461

462462
psi /= np.linalg.norm(psi, ord=2)
463463

464464
result = util.oper_equiv(psi, psi*np.exp(1j*phase), normalized=True, eps=1e-13)
465465
self.assertTrue(result[0])
466-
self.assertArrayAlmostEqual(result[1], phase, atol=1e-5)
466+
self.assertArrayAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(phase, 2*np.pi),
467+
atol=1e-5)
467468

468469
result = util.oper_equiv(psi, psi+1)
469470
self.assertFalse(result[0])
470471

471472
result = util.oper_equiv(U, U*np.exp(1j*phase))
472473
self.assertTrue(np.all(result[0]))
473-
self.assertArrayAlmostEqual(result[1], phase, atol=1e-5)
474+
self.assertArrayAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(phase, 2*np.pi),
475+
atol=1e-5)
474476

475477
result = util.oper_equiv(U*np.exp(1j*phase), U)
476478
self.assertTrue(np.all(result[0]))
477-
self.assertArrayAlmostEqual(result[1], -phase, atol=1e-5)
479+
self.assertArrayAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(-phase, 2*np.pi),
480+
atol=1e-5)
478481

479482
norm = np.sqrt(util.dot_HS(U, U))
480483
norm = norm[:, None, None] if U.ndim == 3 else norm
@@ -483,7 +486,7 @@ def test_oper_equiv(self):
483486
# U /= np.expand_dims(np.sqrt(util.dot_HS(U, U)), axis=(-1, -2))
484487
result = util.oper_equiv(U, U*np.exp(1j*phase), normalized=True, eps=1e-10)
485488
self.assertTrue(np.all(result[0]))
486-
self.assertArrayAlmostEqual(result[1], phase)
489+
self.assertArrayAlmostEqual(np.mod(result[1], 2*np.pi), np.mod(phase, 2*np.pi))
487490

488491
result = util.oper_equiv(U, U+1)
489492
self.assertFalse(np.all(result[0]))
@@ -496,7 +499,7 @@ def test_dot_HS(self):
496499

497500
for d in rng.integers(2, 10, (5,)):
498501
U, V = testutil.rand_herm(d, 2)
499-
self.assertArrayAlmostEqual(util.dot_HS(U, V), (U.conj().T @ V).trace())
502+
self.assertArrayAlmostEqual(util.dot_HS(U, V), (U.conj().T @ V).trace())
500503

501504
U = testutil.rand_unit(d).squeeze()
502505
self.assertEqual(util.dot_HS(U, U), d)

0 commit comments

Comments
 (0)