Skip to content

Commit df0e553

Browse files
authored
Merge pull request #67 from qutech/feature/len_and_getitem
Add __len__ and __getitem__ methods to PulseSequence
2 parents 8bf6159 + d43e184 commit df0e553

File tree

3 files changed

+73
-7
lines changed

3 files changed

+73
-7
lines changed

doc/source/examples/advanced_concatenation.ipynb

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,12 @@
164164
"FF_X2 = {key: val.get_filter_function(omega[key]) for key, val in X2.items()}\n",
165165
"FF_Y2 = {key: val.get_filter_function(omega[key]) for key, val in Y2.items()}\n",
166166
"H = {key: ff.concatenate((Y2, X2, X2), calc_pulse_correlation_FF=True)\n",
167-
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}"
167+
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}\n",
168+
"\n",
169+
"# Note that we can also slice PulseSequence objects, eg\n",
170+
"# X = H['primitive'][1:]\n",
171+
"# or\n",
172+
"# segments = [segment for segment in H['primitive']]"
168173
]
169174
},
170175
{
@@ -246,7 +251,7 @@
246251
"name": "python",
247252
"nbconvert_exporter": "python",
248253
"pygments_lexer": "ipython3",
249-
"version": "3.8.5"
254+
"version": "3.9.4"
250255
}
251256
},
252257
"nbformat": 4,

filter_functions/pulse_sequence.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,28 @@ def __eq__(self, other: object) -> bool:
368368

369369
return True
370370

371+
def __len__(self) -> int:
372+
return len(self.dt)
373+
374+
def __getitem__(self, key) -> 'PulseSequence':
375+
"""Return a slice of the PulseSequence."""
376+
new_dt = np.atleast_1d(self.dt[key])
377+
if not new_dt.size:
378+
raise IndexError('Cannot create empty PulseSequence')
379+
380+
new = self.__class__(
381+
c_opers=self.c_opers,
382+
n_opers=self.n_opers,
383+
c_oper_identifiers=self.c_oper_identifiers,
384+
n_oper_identifiers=self.n_oper_identifiers,
385+
c_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T,
386+
n_coeffs=np.atleast_2d(self.n_coeffs.T[key]).T,
387+
dt=new_dt,
388+
d=self.d,
389+
basis=self.basis
390+
)
391+
return new
392+
371393
def __copy__(self) -> 'PulseSequence':
372394
"""Return shallow copy of self"""
373395
cls = self.__class__
@@ -1488,14 +1510,15 @@ def concatenate_without_filter_function(pulses: Iterable[PulseSequence],
14881510
concatenate: Concatenate PulseSequences including filter functions.
14891511
concatenate_periodic: Concatenate PulseSequences periodically.
14901512
"""
1491-
pulses = tuple(pulses)
14921513
try:
1493-
# Do awkward checking for type
1494-
if not all(hasattr(pls, 'c_opers') for pls in pulses):
1495-
raise TypeError('Can only concatenate PulseSequences!')
1514+
pulses = tuple(pulses)
14961515
except TypeError:
14971516
raise TypeError(f'Expected pulses to be iterable, not {type(pulses)}')
14981517

1518+
if not all(hasattr(pls, 'c_opers') for pls in pulses):
1519+
# Do awkward checking for type
1520+
raise TypeError('Can only concatenate PulseSequences!')
1521+
14991522
# Check if the Hamiltonians' shapes are compatible, ie the set of all
15001523
# shapes has length 1
15011524
if len(set(pulse.c_opers.shape[1:] for pulse in pulses)) != 1:

tests/test_sequencing.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,44 @@ def test_concatenate_base(self):
6262
pulse_2.omega = [3, 4]
6363
ff.concatenate([pulse_1, pulse_2], calc_filter_function=True)
6464

65+
def test_slicing(self):
66+
"""Tests _getitem__."""
67+
for d, n in zip(rng.integers(2, 5, 20), rng.integers(3, 51, 20)):
68+
pulse = testutil.rand_pulse_sequence(d, n)
69+
parts = np.array([part for part in pulse], dtype=object).squeeze()
70+
71+
# Iterable
72+
self.assertEqual(pulse, ff.concatenate(parts))
73+
self.assertEqual(len(pulse), n)
74+
75+
# Slices
76+
ix = rng.integers(1, n-1)
77+
part = pulse[ix]
78+
self.assertEqual(part, parts[ix])
79+
self.assertEqual(pulse, ff.concatenate([pulse[:ix], pulse[ix:]]))
80+
81+
# More complicated slices
82+
self.assertEqual(pulse[:len(pulse) // 2 * 2],
83+
ff.concatenate([p for zipped in zip(pulse[::2], pulse[1::2])
84+
for p in zipped]))
85+
self.assertEqual(pulse[::-1], ff.concatenate(parts[::-1]))
86+
87+
# Boolean indices
88+
ix = rng.integers(0, 2, size=n, dtype=bool)
89+
if not ix.any():
90+
with self.assertRaises(IndexError):
91+
pulse[ix]
92+
else:
93+
self.assertEqual(pulse[ix], ff.concatenate(parts[ix]))
94+
95+
# Raises
96+
with self.assertRaises(IndexError):
97+
pulse[:0]
98+
with self.assertRaises(IndexError):
99+
pulse[1, 3]
100+
with self.assertRaises(IndexError):
101+
pulse['a']
102+
65103
def test_concatenate_without_filter_function(self):
66104
"""Concatenate two Spin Echos without filter functions."""
67105
tau = 10
@@ -103,7 +141,7 @@ def test_concatenate_without_filter_function(self):
103141

104142
with self.assertRaises(TypeError):
105143
# Not iterable
106-
pulse_sequence.concatenate_without_filter_function(pulse)
144+
pulse_sequence.concatenate_without_filter_function(1)
107145

108146
with self.assertRaises(ValueError):
109147
# Incompatible Hamiltonian shapes

0 commit comments

Comments
 (0)