Skip to content

Commit 7d74ce1

Browse files
Incorporate virtual fit results into Session (#179)
Virtual fit results take the form of a dictionary mapping target IDs to pairs `(approval, transforms)`, where: `approval` is a boolean indicating whether the virtual fit for that target has been approved, and `transforms` is a list of transducer transforms resulting from the virtual fit for that target. The idea is that the list of transforms would be ordered from best to worst, and should of course contain at least one transform.
1 parent 2f7cd69 commit 7d74ce1

File tree

4 files changed

+74
-25
lines changed

4 files changed

+74
-25
lines changed

src/openlifu/db/database.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,17 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict
153153
if session.subject_id != subject.id:
154154
raise ValueError("IDs do not match between the given subject and the subject referenced in the session.")
155155

156-
# Validate the approved fit target ID
157-
if (
158-
session.virtual_fit_approval_for_target_id is not None
159-
and session.virtual_fit_approval_for_target_id not in [target.id for target in session.targets]
160-
):
161-
raise ValueError(
162-
f"The provided virtual_fit_approval_for_target_id of {session.virtual_fit_approval_for_target_id} is not"
163-
" in this session's list of targets."
164-
)
156+
# Validate the virtual fit results
157+
for target_id, (_, transforms) in session.virtual_fit_results.items():
158+
if target_id not in [target.id for target in session.targets]:
159+
raise ValueError(
160+
f"The virtual_fit_results of session {session.id} references a target {target_id} that is not"
161+
" in the session's list of targets."
162+
)
163+
if len(transforms)<1:
164+
raise ValueError(
165+
f"The virtual_fit_results of session {session.id} provides no transforms for target {target_id}."
166+
)
165167

166168
# Check if the session already exists in the database
167169
session_ids = self.get_session_ids(subject.id)

src/openlifu/db/session.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import copy
12
import json
23
from dataclasses import asdict, dataclass, field
34
from datetime import datetime
45
from pathlib import Path
5-
from typing import Dict, List, Optional
6+
from typing import Dict, List, Optional, Tuple
67

78
import numpy as np
89

@@ -22,7 +23,7 @@ class ArrayTransform:
2223

2324
units : str
2425
"""The units of the space on which to apply the transform matrix , e.g. "mm"
25-
(In order to apply the transform to transducer points,.
26+
(In order to apply the transform to transducer points,
2627
first represent the points in these units.)"""
2728

2829
@dataclass
@@ -68,9 +69,15 @@ class Session:
6869
attrs: dict = field(default_factory=dict)
6970
"""Dictionary of additional custom attributes to save to the session"""
7071

71-
virtual_fit_approval_for_target_id: Optional[str] = None
72-
"""Approval state of virtual fit. `None` if there is no approval, otherwise this is the ID
73-
of the target for which virtual fitting has been marked approved."""
72+
virtual_fit_results: Dict[str,Tuple[bool,List[ArrayTransform]]] = field(default_factory=dict)
73+
"""Virtual fit results. This is a dictionary mapping target IDs to pairs `(approval, transforms)`,
74+
where:
75+
`approval` is a boolean indicating whether the virtual fit for that target has been approved, and
76+
`transforms` is a list of transducer transforms resulting from the virtual fit for that target.
77+
78+
The idea is that the list of transforms would be ordered from best to worst, and should of course
79+
contain at least one transform.
80+
"""
7481

7582
transducer_tracking_approved: Optional[bool] = False
7683
"""Approval state of transducer tracking. `True` means the user has provided some kind of
@@ -128,6 +135,12 @@ def from_dict(d:Dict):
128135
d['targets'] = [Point.from_dict(d['targets'])]
129136
elif isinstance(d['targets'], Point):
130137
d['targets'] = [d['targets']]
138+
if 'virtual_fit_results' in d:
139+
for target_id,(approval,transforms) in d['virtual_fit_results'].items():
140+
d['virtual_fit_results'][target_id] = (
141+
approval,
142+
[ArrayTransform(t_dict["matrix"], t_dict["units"]) for t_dict in transforms],
143+
)
131144
if isinstance(d['markers'], list):
132145
if len(d['markers'])>0 and isinstance(d['markers'][0], dict):
133146
d['markers'] = [Point.from_dict(p) for p in d['markers']]
@@ -143,14 +156,20 @@ def to_dict(self):
143156
144157
:returns: Dictionary of session parameters
145158
"""
146-
d = self.__dict__.copy()
159+
d = copy.deepcopy(self.__dict__) # Deep copy needed so that we don't modify the internals of self below
147160
d['date_created'] = d['date_created'].isoformat()
148161
d['date_modified'] = d['date_modified'].isoformat()
149162
d['targets'] = [p.to_dict() for p in d['targets']]
150163
d['markers'] = [p.to_dict() for p in d['markers']]
151164

152165
d['array_transform'] = asdict(d['array_transform'])
153166

167+
for target_id,(approval,transforms) in d['virtual_fit_results'].items():
168+
d['virtual_fit_results'][target_id] = (
169+
approval,
170+
[asdict(t) for t in transforms],
171+
)
172+
154173
return d
155174

156175
@staticmethod

tests/resources/example_db/subjects/example_subject/sessions/example_session/example_session.json

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@
2727
},
2828
"attrs": {},
2929
"date_modified": "2024-04-09 11:27:30",
30-
"virtual_fit_approval_for_target_id": "example_target",
30+
"virtual_fit_results": {
31+
"example_target": [
32+
true,
33+
[
34+
{
35+
"matrix": [
36+
[1.1, 0, 0, 0],
37+
[0, 1.2, 0, 0],
38+
[0, 0, 1.3, 0],
39+
[0, 0.05, 0, 1.4]
40+
],
41+
"units": "mm"
42+
}
43+
]
44+
]
45+
},
3146
"transducer_tracking_approved": false
3247
}

tests/test_database.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
from contextlib import nullcontext as does_not_raise
44
from datetime import datetime, timedelta
55
from pathlib import Path
6-
from typing import Optional
6+
from typing import List
77
from unittest.mock import patch
88

9+
import numpy as np
910
import pytest
1011
from helpers import dataclasses_are_equal
1112
from vtk import vtkImageData, vtkPolyData
1213

1314
from openlifu import Point, Solution
1415
from openlifu.db import Session, Subject, User
1516
from openlifu.db.database import Database, OnConflictOpts
17+
from openlifu.db.session import ArrayTransform
1618
from openlifu.photoscan import Photoscan
1719
from openlifu.plan import Protocol, Run
1820
from openlifu.xdc import Transducer
@@ -347,26 +349,37 @@ def test_write_session_mismatched_id(example_database: Database, example_subject
347349
example_database.write_session(example_subject, session)
348350

349351
@pytest.mark.parametrize(
350-
("virtual_fit_approval_for_target_id", "expectation"),
352+
("target_ids", "numbers_of_transforms", "expectation"),
351353
[
352-
(None, does_not_raise()), # see https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising
353-
("an_existing_target_id", does_not_raise()),
354-
("bogus_target_id", pytest.raises(ValueError, match="virtual_fit_approval_for_target_id.*not in")),
354+
# see https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising
355+
([], [], does_not_raise()),
356+
(["an_existing_target_id"], [1], does_not_raise()),
357+
(["an_existing_target_id"], [2], does_not_raise()),
358+
(["bogus_target_id"], [1], pytest.raises(ValueError, match="references a target")),
359+
(["an_existing_target_id", "bogus_target_id"], [1,1], pytest.raises(ValueError, match="references a target")),
360+
(["an_existing_target_id"], [0], pytest.raises(ValueError, match="provides no transforms")),
355361
]
356362
)
357363
def test_write_session_with_invalid_fit_approval(
358364
example_database: Database,
359365
example_subject: Subject,
360-
virtual_fit_approval_for_target_id: Optional[str],
366+
target_ids: List[str],
367+
numbers_of_transforms: List[int],
361368
expectation,
362369
):
363-
"""Verify that writing a session with fit approval raises the invalid target error if and only if the
364-
target being approved does not exist."""
370+
"""Verify that write_session complains appropriately about invalid virtual fit results"""
371+
rng = np.random.default_rng()
365372
session = Session(
366373
id="unique_id_2764592837465",
367374
subject_id=example_subject.id,
368375
targets=[Point(id="an_existing_target_id")],
369-
virtual_fit_approval_for_target_id=virtual_fit_approval_for_target_id,
376+
virtual_fit_results={
377+
target_id : (
378+
True,
379+
[ArrayTransform(matrix=rng.random(size=(4,4)),units="mm") for _ in range(num_transforms)],
380+
)
381+
for target_id, num_transforms in zip(target_ids, numbers_of_transforms)
382+
},
370383
)
371384
with expectation:
372385
example_database.write_session(example_subject, session)

0 commit comments

Comments
 (0)