4 changes: 2 additions & 2 deletions db_dvc.dvc
@@ -1,6 +1,6 @@
 outs:
-- md5: f9fe2e59c4e854484b4e922a719bdfa7.dir
-  size: 803478753
+- md5: d82c778a093a87d824ea111b6afa62fe.dir
+  size: 803480757
   nfiles: 83
   hash: md5
   path: db_dvc
87 changes: 87 additions & 0 deletions scripts/standardize_database.py
@@ -0,0 +1,87 @@
"""This is a utility script for developers to read in and write back out the dvc database.
It is useful for standardizing the format of the example dvc data, and also for checking that the database
mostly still works.

To use this script, install openlifu to a python environment and then run the script providing the database folder as an argument:

```
python standardize_database.py db_dvc/
```

A couple of known issues to watch out for:
- The date_modified of Sessions gets updated, as it should, when this is run. But we don't care about that change.
- The netCDF simulation output files (.nc files) are modified for some reason each time they are written out. It's probably a similar
thing going on with some kind of timestamp being embedded in the file.
"""

import pathlib
import shutil
import sys
import tempfile

from openlifu.db import Database
from openlifu.db.database import OnConflictOpts

if len(sys.argv) != 2:
    raise RuntimeError("Provide exactly one argument: the path to the database folder.")
db = Database(sys.argv[1])

db.write_protocol_ids(db.get_protocol_ids())
for protocol_id in db.get_protocol_ids():
    protocol = db.load_protocol(protocol_id)
    assert protocol_id == protocol.id
    db.write_protocol(protocol, on_conflict=OnConflictOpts.OVERWRITE)

db.write_transducer_ids(db.get_transducer_ids())
for transducer_id in db.get_transducer_ids():
    transducer = db.load_transducer(transducer_id)
    assert transducer_id == transducer.id
    db.write_transducer(transducer, on_conflict=OnConflictOpts.OVERWRITE)

db.write_subject_ids(db.get_subject_ids())
for subject_id in db.get_subject_ids():
    subject = db.load_subject(subject_id)
    assert subject_id == subject.id
    db.write_subject(subject, on_conflict=OnConflictOpts.OVERWRITE)

    db.write_volume_ids(subject_id, db.get_volume_ids(subject_id))
    for volume_id in db.get_volume_ids(subject_id):
        volume_info = db.get_volume_info(subject_id, volume_id)
        assert volume_info["id"] == volume_id
        volume_data_abspath = pathlib.Path(volume_info["data_abspath"])

        # The file move here works around a quirk in Database:
        # - you can't just edit the volume metadata; the metadata json and the volume data file must be written together
        # - if you provide volume_data_abspath itself as the data path, shutil raises a SameFileError and refuses
        #   to do the copy. These things could be fixed, but it's a niche use case, so it's easier to work around it here.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            moved_vol_abspath = tmpdir / volume_data_abspath.name
            shutil.move(volume_data_abspath, moved_vol_abspath)
            db.write_volume(subject_id, volume_id, volume_info["name"], moved_vol_abspath, on_conflict=OnConflictOpts.OVERWRITE)

    session_ids = db.get_session_ids(subject.id)
    db.write_session_ids(subject_id, session_ids)
    for session_id in session_ids:
        session = db.load_session(subject, session_id)
        assert session.id == session_id
        assert session.subject_id == subject.id
        db.write_session(subject, session, on_conflict=OnConflictOpts.OVERWRITE)

        solution_ids = db.get_solution_ids(session.subject_id, session.id)
        db.write_solution_ids(session, solution_ids)
        for solution_id in solution_ids:
            solution = db.load_solution(session, solution_id)
            assert solution.id == solution_id
            assert solution.simulation_result['p_min'].shape[0] == solution.num_foci()
            db.write_solution(session, solution, on_conflict=OnConflictOpts.OVERWRITE)

        run_ids = db.get_run_ids(subject_id, session_id)
        db.write_run_ids(subject_id, session_id, run_ids)
        # (Runs are read-only at the moment, so it's just the runs.json and no individual runs to standardize)

db.write_user_ids(db.get_user_ids())
for user_id in db.get_user_ids():
    user = db.load_user(user_id)
    assert user_id == user.id
    db.write_user(user, on_conflict=OnConflictOpts.OVERWRITE)
20 changes: 11 additions & 9 deletions src/openlifu/db/database.py
@@ -153,15 +153,17 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict
         if session.subject_id != subject.id:
             raise ValueError("IDs do not match between the given subject and the subject referenced in the session.")

-        # Validate the approved fit target ID
-        if (
-            session.virtual_fit_approval_for_target_id is not None
-            and session.virtual_fit_approval_for_target_id not in [target.id for target in session.targets]
-        ):
-            raise ValueError(
-                f"The provided virtual_fit_approval_for_target_id of {session.virtual_fit_approval_for_target_id} is not"
-                " in this session's list of targets."
-            )
+        # Validate the virtual fit results
+        for target_id, (_, transforms) in session.virtual_fit_results.items():
+            if target_id not in [target.id for target in session.targets]:
+                raise ValueError(
+                    f"The virtual_fit_results of session {session.id} references a target {target_id} that is not"
+                    " in the session's list of targets."
+                )
+            if len(transforms) < 1:
+                raise ValueError(
+                    f"The virtual_fit_results of session {session.id} provides no transforms for target {target_id}."
+                )

         # Check if the session already exists in the database
         session_ids = self.get_session_ids(subject.id)
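To illustrate the new validation, a minimal sketch (hypothetical IDs; it uses the `Session` and `Point` constructions exercised in the tests further down):

```
import numpy as np

from openlifu import Point
from openlifu.db import Session

# A session whose virtual fit result is approved but contains an empty
# transform list -- this is exactly the case the new check rejects.
session = Session(
    id="demo_session",
    subject_id="demo_subject",
    targets=[Point(id="left_target")],
    virtual_fit_results={
        "left_target": (True, []),  # approved, but no transforms
    },
)
# Writing this session via Database.write_session would now raise:
# ValueError: The virtual_fit_results of session demo_session provides no transforms for target left_target.
```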
35 changes: 27 additions & 8 deletions src/openlifu/db/session.py
@@ -1,8 +1,9 @@
import copy
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import numpy as np

@@ -22,7 +23,7 @@ class ArrayTransform:

     units : str
     """The units of the space on which to apply the transform matrix , e.g. "mm"
-    (In order to apply the transform to transducer points,.
+    (In order to apply the transform to transducer points,
     first represent the points in these units.)"""

 @dataclass
@@ -68,9 +69,17 @@ class Session:
     attrs: dict = field(default_factory=dict)
     """Dictionary of additional custom attributes to save to the session"""

-    virtual_fit_approval_for_target_id: Optional[str] = None
-    """Approval state of virtual fit. `None` if there is no approval, otherwise this is the ID
-    of the target for which virtual fitting has been marked approved."""
+    virtual_fit_results: Dict[str,Tuple[bool,List[ArrayTransform]]] = field(default_factory=dict)
+    """Virtual fit results. This is a dictionary mapping target IDs to pairs `(approval, transforms)`,
+    where:
+
+    `approval` is a boolean indicating whether the virtual fit for that target has been approved, and
+    `transforms` is a list of transducer transforms resulting from the virtual fit for that target.
+
+    The list of transforms should be ordered from best to worst and should of course contain at least
+    one transform. The approval is intended to apply to the first transform in the list only; none of
+    the other transforms in the list are considered to be approved.
+    """

     transducer_tracking_approved: Optional[bool] = False
     """Approval state of transducer tracking. `True` means the user has provided some kind of
@@ -98,7 +107,6 @@ def from_file(filename):
         Create a Session from a file

         :param filename: Name of the file to read
-        :param db: Database object
         :returns: Session object
         """
         with open(filename) as f:
@@ -110,7 +118,6 @@ def from_dict(d:Dict):
         Create a session from a dictionary

         :param d: Dictionary of session parameters
-        :param db: Database object
         :returns: Session object
         """
         if 'date_created' in d:
@@ -128,6 +135,12 @@ def from_dict(d:Dict):
             d['targets'] = [Point.from_dict(d['targets'])]
         elif isinstance(d['targets'], Point):
             d['targets'] = [d['targets']]
+        if 'virtual_fit_results' in d:
+            for target_id, (approval, transforms) in d['virtual_fit_results'].items():
+                d['virtual_fit_results'][target_id] = (
+                    approval,
+                    [ArrayTransform(np.array(t_dict["matrix"]), t_dict["units"]) for t_dict in transforms],
+                )
         if isinstance(d['markers'], list):
             if len(d['markers'])>0 and isinstance(d['markers'][0], dict):
                 d['markers'] = [Point.from_dict(p) for p in d['markers']]
@@ -143,14 +156,20 @@ def to_dict(self):

         :returns: Dictionary of session parameters
         """
-        d = self.__dict__.copy()
+        d = copy.deepcopy(self.__dict__)  # Deep copy needed so that we don't modify the internals of self below
         d['date_created'] = d['date_created'].isoformat()
         d['date_modified'] = d['date_modified'].isoformat()
         d['targets'] = [p.to_dict() for p in d['targets']]
         d['markers'] = [p.to_dict() for p in d['markers']]

         d['array_transform'] = asdict(d['array_transform'])

+        for target_id, (approval, transforms) in d['virtual_fit_results'].items():
+            d['virtual_fit_results'][target_id] = (
+                approval,
+                [asdict(t) for t in transforms],
+            )
+
         return d

     @staticmethod
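To make the new (de)serialization path concrete, a hedged round-trip sketch; the IDs are hypothetical, it mirrors the constructions used in the tests below, and it assumes Session's remaining fields have usable defaults:

```
import numpy as np

from openlifu import Point
from openlifu.db import Session
from openlifu.db.session import ArrayTransform

session = Session(
    id="roundtrip_demo",
    subject_id="demo_subject",
    targets=[Point(id="example_target")],
    virtual_fit_results={
        "example_target": (True, [ArrayTransform(matrix=np.eye(4), units="mm")]),
    },
)
d = session.to_dict()  # each ArrayTransform becomes a plain dict via asdict()
restored = Session.from_dict(d)  # from_dict rebuilds the ArrayTransform objects
matrix = restored.virtual_fit_results["example_target"][1][0].matrix
assert isinstance(matrix, np.ndarray)  # matrices come back as numpy arrays
```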
@@ -27,6 +27,21 @@
     },
     "attrs": {},
     "date_modified": "2024-04-09 11:27:30",
-    "virtual_fit_approval_for_target_id": "example_target",
+    "virtual_fit_results": {
+        "example_target": [
+            true,
+            [
+                {
+                    "matrix": [
+                        [1.1, 0, 0, 0],
+                        [0, 1.2, 0, 0],
+                        [0, 0, 1.3, 0],
+                        [0, 0.05, 0, 1.4]
+                    ],
+                    "units": "mm"
+                }
+            ]
+        ]
+    },
     "transducer_tracking_approved": false
 }

Review thread on the matrix row `[0, 0.05, 0, 1.4]`:

Contributor: Curious -- why did you make the 4th row of this matrix like this? Just to test, or is this an extra scaling transformation across all dimensions? I assume you apply this like $matrix @ (x,y,z,1)^T$. I take it that in this case you would then divide everything by $0.05y + 1.4$ after the scale and translation (in this case, no translation)?

Collaborator (author): Correct, you would apply it like that. Typically you would want to make the last row [0,0,0,1], which is probably what you were expecting. In this case I wasn't doing anything fancy: I just wanted to stick any old 4x4 array in here and have tests that the array data is handled correctly, without caring about the interpretation as an "affine transform".
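To ground the review discussion above, a hedged sketch of that application convention in plain numpy (not an openlifu API; `apply_homogeneous` is a hypothetical helper):

```
import numpy as np

# The example matrix from the session JSON above.
matrix = np.array([
    [1.1, 0,    0,   0],
    [0,   1.2,  0,   0],
    [0,   0,    1.3, 0],
    [0,   0.05, 0,   1.4],
])

def apply_homogeneous(matrix: np.ndarray, point_xyz: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous transform to a 3D point: matrix @ (x, y, z, 1)^T,
    then divide by the resulting w component (here w = 0.05*y + 1.4)."""
    x, y, z, w = matrix @ np.append(point_xyz, 1.0)
    return np.array([x, y, z]) / w

print(apply_homogeneous(matrix, np.array([1.0, 2.0, 3.0])))
# w = 0.05*2 + 1.4 = 1.5, so the result is [1.1, 2.4, 3.9] / 1.5 = [0.7333..., 1.6, 2.6]
```

With the conventional last row [0,0,0,1], w is always 1 and the divide is a no-op; the example matrix deliberately breaks that convention just to exercise the array handling.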
40 changes: 30 additions & 10 deletions tests/test_database.py
@@ -3,16 +3,18 @@
 from contextlib import nullcontext as does_not_raise
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch

 import numpy as np
 import pytest
 from helpers import dataclasses_are_equal
 from vtk import vtkImageData, vtkPolyData

 from openlifu import Point, Solution
 from openlifu.db import Session, Subject, User
 from openlifu.db.database import Database, OnConflictOpts
+from openlifu.db.session import ArrayTransform
 from openlifu.photoscan import Photoscan
 from openlifu.plan import Protocol, Run
 from openlifu.xdc import Transducer
@@ -347,30 +349,48 @@ def test_write_session_mismatched_id(example_database: Database, example_subject
         example_database.write_session(example_subject, session)

 @pytest.mark.parametrize(
-    ("virtual_fit_approval_for_target_id", "expectation"),
+    ("target_ids", "numbers_of_transforms", "expectation"),
     [
-        (None, does_not_raise()), # see https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising
-        ("an_existing_target_id", does_not_raise()),
-        ("bogus_target_id", pytest.raises(ValueError, match="virtual_fit_approval_for_target_id.*not in")),
+        # see https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising
+        ([], [], does_not_raise()),
+        (["an_existing_target_id"], [1], does_not_raise()),
+        (["an_existing_target_id"], [2], does_not_raise()),
+        (["bogus_target_id"], [1], pytest.raises(ValueError, match="references a target")),
+        (["an_existing_target_id", "bogus_target_id"], [1, 1], pytest.raises(ValueError, match="references a target")),
+        (["an_existing_target_id"], [0], pytest.raises(ValueError, match="provides no transforms")),
     ]
 )
-def test_write_session_with_invalid_fit_approval(
+def test_write_session_with_invalid_fit_results(
     example_database: Database,
     example_subject: Subject,
-    virtual_fit_approval_for_target_id: Optional[str],
+    target_ids: List[str],
+    numbers_of_transforms: List[int],
     expectation,
 ):
-    """Verify that writing a session with fit approval raises the invalid target error if and only if the
-    target being approved does not exist."""
+    """Verify that write_session complains appropriately about invalid virtual fit results"""
+    rng = np.random.default_rng()
     session = Session(
         id="unique_id_2764592837465",
         subject_id=example_subject.id,
         targets=[Point(id="an_existing_target_id")],
-        virtual_fit_approval_for_target_id=virtual_fit_approval_for_target_id,
+        virtual_fit_results={
+            target_id: (
+                True,
+                [ArrayTransform(matrix=rng.random(size=(4,4)), units="mm") for _ in range(num_transforms)],
+            )
+            for target_id, num_transforms in zip(target_ids, numbers_of_transforms)
+        },
     )
     with expectation:
         example_database.write_session(example_subject, session)

+def test_session_arrays_read_correctly(example_session: Session):
+    """Verify that session data that is supposed to be array type is actually array type after reading from json"""
+    assert isinstance(example_session.array_transform.matrix, np.ndarray)
+    for _, (_, array_transforms) in example_session.virtual_fit_results.items():
+        for array_transform in array_transforms:
+            assert isinstance(array_transform.matrix, np.ndarray)
+
 @pytest.mark.parametrize("compact_representation", [True, False])
 def test_serialize_deserialize_session(example_session : Session, compact_representation:bool):
     reconstructed_session = example_session.from_json(example_session.to_json(compact_representation))