Skip to content

Commit 8f86a40

Browse files
Change Solution.focus to a list of foci (#125)
This includes * changing Solution.focus to Solution.foci and making that be a *list* of Points rather than a Point. * changing the xarray simulation data to have an additional coord "focal_point_index"
1 parent c53b029 commit 8f86a40

File tree

5 files changed

+30
-17
lines changed

5 files changed

+30
-17
lines changed

src/openlifu/plan/solution.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import asdict, dataclass, field
44
from datetime import datetime
55
from pathlib import Path
6-
from typing import Optional
6+
from typing import List, Optional
77

88
import numpy as np
99
import xarray
@@ -51,8 +51,11 @@ class Solution:
5151
"""Pulse to send to the transducer when running sonication"""
5252
sequence: Sequence = field(default_factory=Sequence)
5353
"""Pulse sequence to use when running sonication"""
54-
focus: Optional[Point] = None
55-
"""Point that is being focused on in this Solution; part of the focal pattern of the target"""
54+
foci: List[Point] = field(default_factory=list)
55+
"""Points that are focused on in this Solution due to the focal pattern around the target.
56+
Each item in this list is a unique point from the focal pattern, and the pulse sequence is
57+
what determines how many times each point will be used.
58+
"""
5659

5760
# there was "target_id" in the matlab software, but here we do not have the concept of a target ID.
5861
# I believe this was only needed in the matlab software because solutions were organized by target rather
@@ -109,8 +112,10 @@ def from_json(json_string : str, simulation_result: Optional[xarray.Dataset]=Non
109112
solution_dict["apodizations"] = np.array(solution_dict["apodizations"])
110113
solution_dict["pulse"] = Pulse.from_dict(solution_dict["pulse"])
111114
solution_dict["sequence"] = Sequence.from_dict(solution_dict["sequence"])
112-
if solution_dict["focus"] is not None:
113-
solution_dict["focus"] = Point.from_dict(solution_dict["focus"])
115+
solution_dict["foci"] = [
116+
Point.from_dict(focus_dict)
117+
for focus_dict in solution_dict["foci"]
118+
]
114119
if solution_dict["target"] is not None:
115120
solution_dict["target"] = Point.from_dict(solution_dict["target"])
116121

tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,17 @@
4444
"pulse_train_interval": 1,
4545
"pulse_train_count": 1
4646
},
47-
"focus": {
48-
"id": "example_target",
49-
"name": "Example Target",
50-
"color": [1.0, 0.0, 0.0],
51-
"radius": 0.001,
52-
"position": [0.0, -0.0022437460888595447, 0.05518120697745499],
53-
"dims": ["lat", "ele", "ax"],
54-
"units": "m"
55-
},
47+
"foci": [
48+
{
49+
"id": "example_target",
50+
"name": "Example Target",
51+
"color": [1.0, 0.0, 0.0],
52+
"radius": 0.001,
53+
"position": [0.0, -0.0022437460888595447, 0.05518120697745499],
54+
"dims": ["lat", "ele", "ax"],
55+
"units": "m"
56+
}
57+
],
5658
"target": {
5759
"id": "example_target",
5860
"name": "Example Target",

tests/test_database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def test_load_solution(example_database:Database, example_session:Session):
178178
example_solution = example_database.load_solution(example_session, "example_solution")
179179
assert example_solution.name == "Example Solution"
180180
assert "p_min" in example_solution.simulation_result.data_vars # ensure the xarray dataset got loaded too
181+
assert len(example_solution.simulation_result['focal_point_index']) == len(example_solution.foci) # ensure simulation data was loaded for all foci
181182

182183
def test_write_solution(example_database:Database, example_session:Session):
183184
solution = Solution(name="bleh", id='new_solution')

tests/test_solution.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def example_solution() -> Solution:
2323
apodizations=np.array([0.5, 0.75, 1.0, 0.85]),
2424
pulse=Pulse(frequency=42),
2525
sequence=Sequence(pulse_count=27),
26-
focus=Point(id = "test_focus_point"),
26+
foci=[Point(id = "test_focus_point")],
2727
target=Point(id = "test_target_point"),
2828
simulation_result=xarray.Dataset(
29-
{'pnp': (['x', 'y', 'z'], rng.random((3, 2, 3)))},
30-
coords={'x': np.linspace(0, 1, 3), 'y': np.linspace(0, 1, 2), 'z': np.linspace(0, 1, 3)},
29+
{'p_min': (['focal_point_index', 'x', 'y', 'z'], rng.random((1, 3, 2, 3)))},
30+
coords={'x': np.linspace(0, 1, 3), 'y': np.linspace(0, 1, 2), 'z': np.linspace(0, 1, 3), 'focal_point_index': [0]},
3131
),
3232
)
3333

@@ -67,3 +67,8 @@ def test_save_load_solution_custom_dataset_filepath(example_solution:Solution, t
6767
nc_filepath = tmp_path/"some_other_directory"/"sim_output.nc"
6868
example_solution.to_files(json_filepath, nc_filepath)
6969
assert dataclasses_are_equal(Solution.from_files(json_filepath, nc_filepath), example_solution)
70+
71+
def test_num_foci(example_solution:Solution):
72+
"""Ensure that the number of foci in the test solution matches the number of foci provided in its simulation data.
73+
(This is more checking correctness of the test example rather than correctness of code, but it is important.)"""
74+
assert len(example_solution.simulation_result['focal_point_index']) == len(example_solution.foci)

0 commit comments

Comments
 (0)