Skip to content

Commit 324603e

Browse files
Convert VF transforms to correct type (#179)
Virtual fit transforms were being read in as lists and not properly converted to arrays. Fixed, and added a test that would have caught the issue.
1 parent 79cb01a commit 324603e

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/openlifu/db/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def from_dict(d:Dict):
139139
for target_id,(approval,transforms) in d['virtual_fit_results'].items():
140140
d['virtual_fit_results'][target_id] = (
141141
approval,
142-
[ArrayTransform(t_dict["matrix"], t_dict["units"]) for t_dict in transforms],
142+
[ArrayTransform(np.array(t_dict["matrix"]), t_dict["units"]) for t_dict in transforms],
143143
)
144144
if isinstance(d['markers'], list):
145145
if len(d['markers'])>0 and isinstance(d['markers'][0], dict):

tests/test_database.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def test_write_session_mismatched_id(example_database: Database, example_subject
360360
(["an_existing_target_id"], [0], pytest.raises(ValueError, match="provides no transforms")),
361361
]
362362
)
363-
def test_write_session_with_invalid_fit_approval(
363+
def test_write_session_with_invalid_fit_results(
364364
example_database: Database,
365365
example_subject: Subject,
366366
target_ids: List[str],
@@ -384,6 +384,13 @@ def test_write_session_with_invalid_fit_approval(
384384
with expectation:
385385
example_database.write_session(example_subject, session)
386386

387+
def test_session_arrays_read_correctly(example_session:Session):
388+
"""Verify that session data that is supposed to be array type is actually array type after reading from json"""
389+
assert isinstance(example_session.array_transform.matrix, np.ndarray)
390+
for _, (_, array_transforms) in example_session.virtual_fit_results.items():
391+
for array_transform in array_transforms:
392+
assert isinstance(array_transform.matrix, np.ndarray)
393+
387394
@pytest.mark.parametrize("compact_representation", [True, False])
388395
def test_serialize_deserialize_session(example_session : Session, compact_representation:bool):
389396
reconstructed_session = example_session.from_json(example_session.to_json(compact_representation))

0 commit comments

Comments
 (0)