Skip to content

Commit 2094c58

Browse files
ebrahimebrahimltetrel
authored andcommitted
Change Solution.focus to a list of foci (#125)
This includes * changing Solution.focus to Solution.foci and making that be a *list* of Points rather than a Point. * changing the xarray simulation data to have an additional coord "focal_point_index" * adding an axis to the delay and apodization arrays in Solution to index focal points. no need for instance checking anymore since foci will always be a list solution output analysis now takes into account foci list
1 parent 17a6983 commit 2094c58

File tree

6 files changed

+92
-71
lines changed

6 files changed

+92
-71
lines changed

src/openlifu/plan/solution.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import asdict, dataclass, field
44
from datetime import datetime
55
from pathlib import Path
6-
from typing import Optional
6+
from typing import List, Optional
77

88
import numpy as np
99
import xarray as xa
@@ -82,16 +82,22 @@ class Solution:
8282
"""Description of this solution"""
8383

8484
delays: Optional[np.ndarray] = None
85-
"""Vector of time delays to steer the beam"""
85+
"""Vectors of time delays to steer the beam. Shape is (number of foci, number of transducer elements)."""
8686

8787
apodizations: Optional[np.ndarray] = None
88-
"""Vector of apodizations to steer the beam"""
88+
"""Vectors of apodizations to steer the beam. Shape is (number of foci, number of transducer elements)."""
89+
8990
pulse: Pulse = field(default_factory=Pulse)
9091
"""Pulse to send to the transducer when running sonication"""
92+
9193
sequence: Sequence = field(default_factory=Sequence)
9294
"""Pulse sequence to use when running sonication"""
93-
focus: Optional[Point] = None
94-
"""Point that is being focused on in this Solution; part of the focal pattern of the target"""
95+
96+
foci: List[Point] = field(default_factory=list)
97+
"""Points that are focused on in this Solution due to the focal pattern around the target.
98+
Each item in this list is a unique point from the focal pattern, and the pulse sequence is
99+
what determines how many times each point will be used.
100+
"""
95101

96102
# there was "target_id" in the matlab software, but here we do not have the concept of a target ID.
97103
# I believe this was only needed in the matlab software because solutions were organized by target rather
@@ -110,16 +116,9 @@ class Solution:
110116
"""Approval state of this solution as a sonication plan. `True` means the user has provided some
111117
kind of confirmation that the solution is safe and acceptable to be executed."""
112118

113-
def num_foci(self):
119+
def num_foci(self) -> int:
114120
"""Get the number of foci"""
115-
if isinstance(self.focus, list):
116-
nfoc = len(self.focus)
117-
elif isinstance(self.focus, Point):
118-
nfoc = 1
119-
else:
120-
raise ValueError("Cannot get number of foci for types other than Point.")
121-
122-
return nfoc
121+
return len(self.foci)
123122

124123
def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOptions()) -> SolutionAnalysis:
125124
"""Analyzes the treatment solution.
@@ -166,10 +165,7 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt
166165
# power_W = np.zeros(self.num_foci())
167166
# TIC = np.zeros(self.num_foci())
168167
for focus_index in range(self.num_foci()):
169-
if isinstance(self.focus, list):
170-
foc = self.focus[focus_index]
171-
elif isinstance(self.focus, Point):
172-
foc = self.focus
168+
foc = self.foci[focus_index]
173169
# output_signal = []
174170
# output_signal = np.zeros((transducer.numelements(), len(input_signal)))
175171
# for i in range(transducer.numelements()):
@@ -203,8 +199,8 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt
203199
# distance=options.beamwidth_radius,
204200
# options=mask_options)
205201

206-
pk = np.max(pnp_MPa.data * mainlobe_mask) #TODO: pnp_MPa supposed to be a list for each focus: pnp_MPa(focus_index)
207-
solution_analysis.mainlobe_pnp_MPa = pk
202+
pk = np.max(pnp_MPa.data[focus_index] * mainlobe_mask) #TODO: pnp_MPa supposed to be a list for each focus: pnp_MPa(focus_index)
203+
solution_analysis.mainlobe_pnp_MPa += [pk]
208204

209205
# thresh_m3dB = pk*10**(-3 / 20)
210206
# thresh_m6dB = pk*10**(-6 / 20)
@@ -357,10 +353,13 @@ def from_json(json_string : str, simulation_result: Optional[xa.Dataset]=None) -
357353
solution_dict["delays"] = np.array(solution_dict["delays"])
358354
if solution_dict["apodizations"] is not None:
359355
solution_dict["apodizations"] = np.array(solution_dict["apodizations"], ndmin=2)
356+
solution_dict["apodizations"] = np.array(solution_dict["apodizations"], ndmin=2)
360357
solution_dict["pulse"] = Pulse.from_dict(solution_dict["pulse"])
361358
solution_dict["sequence"] = Sequence.from_dict(solution_dict["sequence"])
362-
if solution_dict["focus"] is not None:
363-
solution_dict["focus"] = Point.from_dict(solution_dict["focus"]) #TODO: Solution analysis needs a list, to interface with FocalPattern ?
359+
solution_dict["foci"] = [
360+
Point.from_dict(focus_dict)
361+
for focus_dict in solution_dict["foci"]
362+
]
364363
if solution_dict["target"] is not None:
365364
solution_dict["target"] = Point.from_dict(solution_dict["target"])
366365

src/openlifu/util/units.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,12 @@ def rescale_coords(data_arr: Dataset, units: str) -> Dataset:
211211
rescaled = data_arr.copy(deep=True)
212212
for coord_key in data_arr.coords:
213213
curr_coord_attrs = rescaled[coord_key].attrs
214-
curr_coord_units = curr_coord_attrs['units']
215-
scale = getunitconversion(curr_coord_units, units)
216-
curr_coord_rescaled = scale*rescaled[coord_key].data
217-
rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)})
218-
rescaled[coord_key].attrs['units'] = units
214+
if 'units' in curr_coord_attrs:
215+
curr_coord_units = curr_coord_attrs['units']
216+
scale = getunitconversion(curr_coord_units, units)
217+
curr_coord_rescaled = scale*rescaled[coord_key].data
218+
rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)})
219+
rescaled[coord_key].attrs['units'] = units
219220

220221
return rescaled
221222

@@ -235,7 +236,8 @@ def get_ndgrid_from_arr(data_arr: Dataset) -> np.ndarray:
235236
ordered_key = data_arr[first_data_key].dims
236237
all_coord = []
237238
for coord_key in ordered_key:
238-
all_coord += [data_arr.coords[coord_key].data]
239+
if 'units' in data_arr[coord_key].attrs:
240+
all_coord += [data_arr.coords[coord_key].data]
239241
ndgrid = np.stack(np.meshgrid(*all_coord, indexing="ij"), axis=-1)
240242

241243
return ndgrid

tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,36 @@
66
"created_on": "2024-01-30T09:18:11",
77
"description": "Example plan created 30-Jan-2024 09:16:02",
88
"delays": [
9-
7.139258974920841e-7, 1.164095583074107e-6, 1.4321977043056202e-6,
10-
1.5143736925332023e-6, 1.4094191657896087e-6, 1.1188705305994071e-6,
11-
6.46894718323139e-7, 0.0, 1.2683760653622907e-6, 1.7251583088871662e-6,
12-
1.9972746364594766e-6, 2.0806926330138847e-6, 1.974152793990993e-6,
13-
1.6792617729461087e-6, 1.2003736682835394e-6, 5.442792226777689e-7,
14-
1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6,
15-
2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6,
16-
1.5738505533232074e-6, 9.114003043615498e-7, 1.8309933020916417e-6,
17-
2.29468834620898e-6, 2.571004767271362e-6, 2.655722845121427e-6,
18-
2.5475236155907678e-6, 2.248089500472459e-6, 1.7619762095269915e-6,
19-
1.0962779929139644e-6, 1.8309933020916417e-6, 2.29468834620898e-6,
20-
2.571004767271362e-6, 2.655722845121427e-6, 2.5475236155907678e-6,
21-
2.248089500472459e-6, 1.7619762095269915e-6, 1.0962779929139644e-6,
22-
1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6,
23-
2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6,
24-
1.5738505533232074e-6, 9.114003043615498e-7, 1.2683760653622907e-6,
25-
1.7251583088871662e-6, 1.9972746364594766e-6, 2.0806926330138847e-6,
26-
1.974152793990993e-6, 1.6792617729461087e-6, 1.2003736682835394e-6,
27-
5.442792226777689e-7, 7.139258974920841e-7, 1.164095583074107e-6,
28-
1.4321977043056202e-6, 1.5143736925332023e-6, 1.4094191657896087e-6,
29-
1.1188705305994071e-6, 6.46894718323139e-7, 0.0
9+
[
10+
7.139258974920841e-7, 1.164095583074107e-6, 1.4321977043056202e-6,
11+
1.5143736925332023e-6, 1.4094191657896087e-6, 1.1188705305994071e-6,
12+
6.46894718323139e-7, 0.0, 1.2683760653622907e-6, 1.7251583088871662e-6,
13+
1.9972746364594766e-6, 2.0806926330138847e-6, 1.974152793990993e-6,
14+
1.6792617729461087e-6, 1.2003736682835394e-6, 5.442792226777689e-7,
15+
1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6,
16+
2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6,
17+
1.5738505533232074e-6, 9.114003043615498e-7, 1.8309933020916417e-6,
18+
2.29468834620898e-6, 2.571004767271362e-6, 2.655722845121427e-6,
19+
2.5475236155907678e-6, 2.248089500472459e-6, 1.7619762095269915e-6,
20+
1.0962779929139644e-6, 1.8309933020916417e-6, 2.29468834620898e-6,
21+
2.571004767271362e-6, 2.655722845121427e-6, 2.5475236155907678e-6,
22+
2.248089500472459e-6, 1.7619762095269915e-6, 1.0962779929139644e-6,
23+
1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6,
24+
2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6,
25+
1.5738505533232074e-6, 9.114003043615498e-7, 1.2683760653622907e-6,
26+
1.7251583088871662e-6, 1.9972746364594766e-6, 2.0806926330138847e-6,
27+
1.974152793990993e-6, 1.6792617729461087e-6, 1.2003736682835394e-6,
28+
5.442792226777689e-7, 7.139258974920841e-7, 1.164095583074107e-6,
29+
1.4321977043056202e-6, 1.5143736925332023e-6, 1.4094191657896087e-6,
30+
1.1188705305994071e-6, 6.46894718323139e-7, 0.0
31+
]
3032
],
3133
"apodizations": [
32-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
33-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
34-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
34+
[
35+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
36+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
37+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
38+
]
3539
],
3640
"pulse": {
3741
"frequency": 500000,
@@ -44,15 +48,17 @@
4448
"pulse_train_interval": 1,
4549
"pulse_train_count": 1
4650
},
47-
"focus": {
48-
"id": "example_target",
49-
"name": "Example Target",
50-
"color": [1.0, 0.0, 0.0],
51-
"radius": 0.001,
52-
"position": [0.0, -0.0022437460888595447, 0.05518120697745499],
53-
"dims": ["lat", "ele", "ax"],
54-
"units": "m"
55-
},
51+
"foci": [
52+
{
53+
"id": "example_target",
54+
"name": "Example Target",
55+
"color": [1.0, 0.0, 0.0],
56+
"radius": 0.001,
57+
"position": [0.0, -0.0022437460888595447, 0.05518120697745499],
58+
"dims": ["lat", "ele", "ax"],
59+
"units": "m"
60+
}
61+
],
5662
"target": {
5763
"id": "example_target",
5864
"name": "Example Target",

tests/test_database.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ def test_load_solution(example_database:Database, example_session:Session):
283283
assert example_solution.name == "Example Solution"
284284
assert "p_min" in example_solution.simulation_result.data_vars # ensure the xarray dataset got loaded too
285285

286+
# ensure the simulation and beamform data was loaded for all foci
287+
assert len(example_solution.simulation_result['focal_point_index']) == len(example_solution.foci)
288+
assert example_solution.delays.shape[0] == len(example_solution.foci)
289+
assert example_solution.apodizations.shape[0] == len(example_solution.foci)
290+
286291
def test_write_solution(example_database:Database, example_session:Session):
287292
solution = Solution(name="bleh", id='new_solution')
288293

tests/test_solution.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,34 +36,35 @@ def example_solution() -> Solution:
3636
transducer_id="trans_456",
3737
created_on=datetime(2024, 1, 1, 12, 0),
3838
description="This is a test solution for a unit test.",
39-
delays=np.array([0.0, 1.0, 2.0, 3.0]),
40-
apodizations=np.array([0.5, 0.75, 1.0, 0.85]),
39+
delays=np.array([[0.0, 1.0, 2.0, 3.0]]),
40+
apodizations=np.array([[0.5, 0.75, 1.0, 0.85]]),
4141
pulse=Pulse(frequency=42),
4242
sequence=Sequence(pulse_count=27),
43-
focus=Point(id="test_focus_point"),
43+
foci=[Point(id="test_focus_point")],
4444
target=Point(id="test_target_point"),
4545
simulation_result=xa.Dataset(
4646
{
4747
'p_min': xa.DataArray(
48-
data=rng.random((3, 2, 3)),
49-
dims=["x", "y", "z"],
48+
data=rng.random((1, 3, 2, 3)),
49+
dims=["focal_point_index", "x", "y", "z"],
5050
attrs={'units': "Pa"}
5151
),
5252
'p_max': xa.DataArray(
53-
data=rng.random((3, 2, 3)),
54-
dims=["x", "y", "z"],
53+
data=rng.random((1, 3, 2, 3)),
54+
dims=["focal_point_index", "x", "y", "z"],
5555
attrs={'units': "Pa"}
5656
),
5757
'ita': xa.DataArray(
58-
data=rng.random((3, 2, 3)),
59-
dims=["x", "y", "z"],
58+
data=rng.random((1, 3, 2, 3)),
59+
dims=["focal_point_index", "x", "y", "z"],
6060
attrs={'units': "W/cm^2"}
6161
)
6262
},
6363
coords={
6464
'x': xa.DataArray(dims=["x"], data=np.linspace(0, 1, 3), attrs={'units': "m"}),
6565
'y': xa.DataArray(dims=["y"], data=np.linspace(0, 1, 2), attrs={'units': "m"}),
66-
'z': xa.DataArray(dims=["z"], data=np.linspace(0, 1, 3), attrs={'units': "m"})
66+
'z': xa.DataArray(dims=["z"], data=np.linspace(0, 1, 3), attrs={'units': "m"}),
67+
'focal_point_index': [0]
6768
}
6869
),
6970
)
@@ -110,6 +111,14 @@ def test_save_load_solution_custom_dataset_filepath(example_solution: Solution,
110111
assert dataclasses_are_equal(Solution.from_files(json_filepath, nc_filepath), example_solution)
111112

112113

114+
def test_num_foci(example_solution:Solution):
115+
"""Ensure that the number of foci in the test solution matches the number of foci provided in the simuluation and beamform data.
116+
(This is more checking correctness of the test example rather than correctness of code, but it is important.)"""
117+
assert len(example_solution.simulation_result['focal_point_index']) == len(example_solution.foci)
118+
assert example_solution.delays.shape[0] == len(example_solution.foci)
119+
assert example_solution.apodizations.shape[0] == len(example_solution.foci)
120+
121+
113122
def test_solution_analysis(example_solution: Solution, example_transducer: Transducer):
114123
"""Test that a solution output can be analyzed."""
115124
example_solution.analyze(example_transducer)

0 commit comments

Comments
 (0)