Skip to content

Commit 6cfd14a

Browse files
authored
Merge pull request #269 from imr-framework/dev_schuenke
Add convert_to_arbitrary option to make_extended_trapezoid_area function
2 parents 689615a + a7c5534 commit 6cfd14a

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

src/pypulseq/make_extended_trapezoid_area.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def make_extended_trapezoid_area(
1414
channel: str,
1515
grad_start: float,
1616
grad_end: float,
17+
convert_to_arbitrary: bool = False,
1718
system: Union[Opts, None] = None,
1819
) -> Tuple[SimpleNamespace, np.array, np.array]:
1920
"""Make the shortest possible extended trapezoid for given area and gradient start and end point.
@@ -28,6 +29,8 @@ def make_extended_trapezoid_area(
2829
Starting non-zero gradient value.
2930
grad_end : float
3031
Ending non-zero gradient value.
32+
convert_to_arbitrary : bool, default=False
33+
Boolean flag to enable converting the extended trapezoid gradient into an arbitrary gradient.
3134
system: Opts, optional
3235
System limits.
3336
@@ -214,7 +217,9 @@ def binary_search(fun, lower_limit, upper_limit):
214217
times = cumsum(0, time_ramp_up, time_ramp_down)
215218
amplitudes = np.array([grad_start, grad_amp, grad_end])
216219

217-
grad = make_extended_trapezoid(channel=channel, system=system, times=times, amplitudes=amplitudes)
220+
grad = make_extended_trapezoid(
221+
channel=channel, amplitudes=amplitudes, convert_to_arbitrary=convert_to_arbitrary, system=system, times=times
222+
)
218223

219224
# Overwrite trace
220225
if trace_enabled():
@@ -223,4 +228,4 @@ def binary_search(fun, lower_limit, upper_limit):
223228
if not abs(grad.area - area) < 1e-8:
224229
raise ValueError(f'Could not find a solution for area={area}.')
225230

226-
return grad, np.array(times), amplitudes
231+
return grad, grad.tt, grad.waveform

tests/test_make_extended_trapezoid_area.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
]
4747

4848

49-
@pytest.mark.parametrize('grad_start,grad_end,area', test_zoo)
49+
@pytest.mark.parametrize('grad_start, grad_end, area', test_zoo)
5050
def test_make_extended_trapezoid_area(grad_start, grad_end, area):
5151
g, _, _ = make_extended_trapezoid_area(
5252
channel='x', grad_start=grad_start, grad_end=grad_end, area=area, system=system
@@ -71,7 +71,7 @@ def test_make_extended_trapezoid_area(grad_start, grad_end, area):
7171
]
7272

7373

74-
@pytest.mark.parametrize('grad_start,grad_end,area', test_zoo_random)
74+
@pytest.mark.parametrize('grad_start, grad_end, area', test_zoo_random)
7575
def test_make_extended_trapezoid_area_random_cases(grad_start, grad_end, area):
7676
g, _, _ = make_extended_trapezoid_area(
7777
channel='x', grad_start=grad_start, grad_end=grad_end, area=area, system=system
@@ -85,6 +85,34 @@ def test_make_extended_trapezoid_area_random_cases(grad_start, grad_end, area):
8585
assert slew_ok, 'Maximum slew rate violated'
8686

8787

88+
@pytest.mark.parametrize('grad_start, grad_end, area', test_zoo_random)
89+
def test_make_extended_trapezoid_area_convert_to_arb(grad_start, grad_end, area):
90+
g, _, _ = make_extended_trapezoid_area(
91+
channel='x', grad_start=grad_start, grad_end=grad_end, area=area, system=system
92+
)
93+
94+
g_arb, _, _ = make_extended_trapezoid_area(
95+
channel='x', grad_start=grad_start, grad_end=grad_end, area=area, convert_to_arbitrary=True, system=system
96+
)
97+
98+
grad_ok = all(abs(g.waveform) <= system.max_grad)
99+
slew_ok = all(abs(np.diff(g.waveform) / np.diff(g.tt)) <= system.max_slew)
100+
101+
grad_arb_ok = all(abs(g_arb.waveform) <= system.max_grad)
102+
slew_arb_ok = all(abs(np.diff(g_arb.waveform) / np.diff(g_arb.tt)) <= system.max_slew)
103+
104+
assert pytest.approx(g.area) == g_arb.area, 'Area of extended trapz and arb gradient do not match'
105+
assert pytest.approx(g.shape_dur) == g_arb.shape_dur, 'Duration of extended trapz and arb gradient do not match'
106+
assert g.tt.shape[0] <= g_arb.tt.shape[0], (
107+
'Extended trapezoid should have less or equal number of points than arb gradient'
108+
)
109+
assert g.waveform.shape[0] <= g_arb.waveform.shape[0], (
110+
'Extended trapezoid should have less or equal number of points than arb gradient'
111+
)
112+
assert grad_ok == grad_arb_ok, 'Gradient strength violation between extended trapz and arb gradient'
113+
assert slew_ok == slew_arb_ok, 'Slew rate violation between extended trapz and arb gradient'
114+
115+
88116
random.seed(0)
89117
test_zoo_random = [
90118
(
@@ -96,7 +124,7 @@ def test_make_extended_trapezoid_area_random_cases(grad_start, grad_end, area):
96124
]
97125

98126

99-
@pytest.mark.parametrize('grad_start,grad_end,grad_amp', test_zoo_random)
127+
@pytest.mark.parametrize('grad_start, grad_end, grad_amp', test_zoo_random)
100128
def test_make_extended_trapezoid_area_recreate(grad_start, grad_end, grad_amp):
101129
def _to_raster(time: float) -> float:
102130
return np.ceil(time / system.grad_raster_time) * system.grad_raster_time

0 commit comments

Comments
 (0)