Skip to content

Commit df168ed

Browse files
niksirbilochhhpre-commit-ci[bot]
committed
Qt widget for loading pose datasets as napari Points layers (#253)
* initialise napari plugin development * Create skeleton for napari plugin with collapsible widgets (#218) * initialise napari plugin development * initialise napari plugin development * create skeleton for napari plugin with collapsible widgets * add basic widget smoke tests and allow headless testing * do not depend on napari from pip * include napari option in install instructions * make meta_widget module private * pin atlasapi version to avoid unnecessary dependencies * pin napari >= 0.4.19 from conda-forge * switched to pip install of napari[all] * seperation of concerns in widget tests * add pytest-mock dev dependency * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * Added loader widget for poses * update widget tests * simplify dependency on brainglobe-utils * consistent monospace formatting for movement in public docstrings * get rid of code that's only relevant for displaying Tracks * enable visibility of napari layer tooltips * renamed widget to PosesLoader * make cmap optional in set_color_by method * wrote unit tests for napari convert module * wrote unit-tests for the layer styles module * linkcheck ignore zenodo redirects * move _sample_colormap out of PointsStyle class * small refactoring in the loader widget * Expand tests for loader widget * added comments and docstrings to napari plugin tests * refactored all napari tests into separate unit test folder * added napari-video to dependencies * replaced deprecated edge_width with border_width * got rid of widget pytest fixtures * remove duplicate word from docstring * remove napari-video dependency * include napari extras in docs requirements * add test for _on_browse_clicked method * getOpenFileName returns tuple, not str * simplify poses_to_napari_tracks Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * [pre-commit.ci] pre-commit autoupdate (#338) updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.7.2](astral-sh/ruff-pre-commit@v0.6.9...v0.7.2) - [github.com/pre-commit/mirrors-mypy: v1.11.2 → v1.13.0](pre-commit/mirrors-mypy@v1.11.2...v1.13.0) - [github.com/mgedmin/check-manifest: 0.49 → 0.50](mgedmin/check-manifest@0.49...0.50) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Implement `compute_speed` and `compute_path_length` (#280) * implement compute_speed and compute_path_length functions * added speed to existing kinematics unit test * rewrote compute_path_length with various nan policies * unit test compute_path_length across time ranges * fixed and refactor compute_path_length and its tests * fixed docstring for compute_path_length * Accept suggestion on docstring wording Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * Remove print statement from test Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * Ensure nan report is printed Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * adapt warning message match in test * change 'any' to 'all' * uniform wording across path length docstrings * (mostly) leave time range validation to xarray slice * refactored parameters for test across time ranges * simplified test for path lenght with nans * replace drop policy with ffill * remove B905 ruff rule * make pre-commit happy --------- Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * avoid redefining duplicate attributes in child dataclass * modify test case to match poses_to_napari_tracks simplification * expected_log_messages should be a subset of captured messages Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * fix typo Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * use names for Qwidgets * reorganised test_valid_poses_to_napari_tracks * parametrised layer style tests * delet integration test which was reintroduced after conflict resolution * added test about file filters * deleted obsolete loader widget file (had snuck back in due to conflict merging) * combine tests for button callouts Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> * Simplify test_layer_style_as_kwargs Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> --------- Co-authored-by: Chang Huan Lo <changhuan.lo@ucl.ac.uk> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9258710 commit df168ed

File tree

13 files changed

+711
-103
lines changed

13 files changed

+711
-103
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
-e .
1+
-e .[napari]
22
ablog
33
linkify-it-py
44
myst-parser

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
"https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error
182182
]
183183

184+
184185
myst_url_schemes = {
185186
"http": None,
186187
"https": None,

movement/napari/_loader_widget.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

movement/napari/_loader_widgets.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Widgets for loading movement datasets from file."""
2+
3+
import logging
4+
from pathlib import Path
5+
6+
from napari.settings import get_settings
7+
from napari.utils.notifications import show_warning
8+
from napari.viewer import Viewer
9+
from qtpy.QtWidgets import (
10+
QComboBox,
11+
QFileDialog,
12+
QFormLayout,
13+
QHBoxLayout,
14+
QLineEdit,
15+
QPushButton,
16+
QSpinBox,
17+
QWidget,
18+
)
19+
20+
from movement.io import load_poses
21+
from movement.napari.convert import poses_to_napari_tracks
22+
from movement.napari.layer_styles import PointsStyle
23+
24+
logger = logging.getLogger(__name__)
25+
26+
# Allowed poses file suffixes for each supported source software
27+
SUPPORTED_POSES_FILES = {
28+
"DeepLabCut": ["*.h5", "*.csv"],
29+
"LightningPose": ["*.csv"],
30+
"SLEAP": ["*.h5", "*.slp"],
31+
}
32+
33+
34+
class PosesLoader(QWidget):
35+
"""Widget for loading movement poses datasets from file."""
36+
37+
def __init__(self, napari_viewer: Viewer, parent=None):
38+
"""Initialize the loader widget."""
39+
super().__init__(parent=parent)
40+
self.viewer = napari_viewer
41+
self.setLayout(QFormLayout())
42+
# Create widgets
43+
self._create_source_software_widget()
44+
self._create_fps_widget()
45+
self._create_file_path_widget()
46+
self._create_load_button()
47+
# Enable layer tooltips from napari settings
48+
self._enable_layer_tooltips()
49+
50+
def _create_source_software_widget(self):
51+
"""Create a combo box for selecting the source software."""
52+
self.source_software_combo = QComboBox()
53+
self.source_software_combo.setObjectName("source_software_combo")
54+
self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys())
55+
self.layout().addRow("source software:", self.source_software_combo)
56+
57+
def _create_fps_widget(self):
58+
"""Create a spinbox for selecting the frames per second (fps)."""
59+
self.fps_spinbox = QSpinBox()
60+
self.fps_spinbox.setObjectName("fps_spinbox")
61+
self.fps_spinbox.setMinimum(1)
62+
self.fps_spinbox.setMaximum(1000)
63+
self.fps_spinbox.setValue(30)
64+
self.layout().addRow("fps:", self.fps_spinbox)
65+
66+
def _create_file_path_widget(self):
67+
"""Create a line edit and browse button for selecting the file path.
68+
69+
This allows the user to either browse the file system,
70+
or type the path directly into the line edit.
71+
"""
72+
# File path line edit and browse button
73+
self.file_path_edit = QLineEdit()
74+
self.file_path_edit.setObjectName("file_path_edit")
75+
self.browse_button = QPushButton("Browse")
76+
self.browse_button.setObjectName("browse_button")
77+
self.browse_button.clicked.connect(self._on_browse_clicked)
78+
# Layout for line edit and button
79+
self.file_path_layout = QHBoxLayout()
80+
self.file_path_layout.addWidget(self.file_path_edit)
81+
self.file_path_layout.addWidget(self.browse_button)
82+
self.layout().addRow("file path:", self.file_path_layout)
83+
84+
def _create_load_button(self):
85+
"""Create a button to load the file and add layers to the viewer."""
86+
self.load_button = QPushButton("Load")
87+
self.load_button.setObjectName("load_button")
88+
self.load_button.clicked.connect(lambda: self._on_load_clicked())
89+
self.layout().addRow(self.load_button)
90+
91+
def _on_browse_clicked(self):
92+
"""Open a file dialog to select a file."""
93+
file_suffixes = SUPPORTED_POSES_FILES[
94+
self.source_software_combo.currentText()
95+
]
96+
97+
file_path, _ = QFileDialog.getOpenFileName(
98+
self,
99+
caption="Open file containing predicted poses",
100+
filter=f"Poses files ({' '.join(file_suffixes)})",
101+
)
102+
103+
# A blank string is returned if the user cancels the dialog
104+
if not file_path:
105+
return
106+
107+
# Add the file path to the line edit (text field)
108+
self.file_path_edit.setText(file_path)
109+
110+
def _on_load_clicked(self):
111+
"""Load the file and add as a Points layer to the viewer."""
112+
fps = self.fps_spinbox.value()
113+
source_software = self.source_software_combo.currentText()
114+
file_path = self.file_path_edit.text()
115+
if file_path == "":
116+
show_warning("No file path specified.")
117+
return
118+
ds = load_poses.from_file(file_path, source_software, fps)
119+
120+
self.data, self.props = poses_to_napari_tracks(ds)
121+
logger.info("Converted poses dataset to a napari Tracks array.")
122+
logger.debug(f"Tracks array shape: {self.data.shape}")
123+
124+
self.file_name = Path(file_path).name
125+
self._add_points_layer()
126+
127+
self._set_playback_fps(fps)
128+
logger.debug(f"Set napari playback speed to {fps} fps.")
129+
130+
def _add_points_layer(self):
131+
"""Add the predicted poses to the viewer as a Points layer."""
132+
# Style properties for the napari Points layer
133+
points_style = PointsStyle(
134+
name=f"poses: {self.file_name}",
135+
properties=self.props,
136+
)
137+
# Color the points by individual if there are multiple individuals
138+
# Otherwise, color by keypoint
139+
n_individuals = len(self.props["individual"].unique())
140+
points_style.set_color_by(
141+
prop="individual" if n_individuals > 1 else "keypoint"
142+
)
143+
# Add the points layer to the viewer
144+
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())
145+
logger.info("Added poses dataset as a napari Points layer.")
146+
147+
@staticmethod
148+
def _set_playback_fps(fps: int):
149+
"""Set the playback speed for the napari viewer."""
150+
settings = get_settings()
151+
settings.application.playback_fps = fps
152+
153+
@staticmethod
154+
def _enable_layer_tooltips():
155+
"""Toggle on tooltip visibility for napari layers.
156+
157+
This nicely displays the layer properties as a tooltip
158+
when hovering over the layer in the napari viewer.
159+
"""
160+
settings = get_settings()
161+
settings.appearance.layer_tooltip_visibility = True

movement/napari/_meta_widget.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
44
from napari.viewer import Viewer
55

6-
from movement.napari._loader_widget import Loader
6+
from movement.napari._loader_widgets import PosesLoader
77

88

99
class MovementMetaWidget(CollapsibleWidgetContainer):
@@ -18,9 +18,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
1818
super().__init__()
1919

2020
self.add_widget(
21-
Loader(napari_viewer, parent=self),
21+
PosesLoader(napari_viewer, parent=self),
2222
collapsible=True,
23-
widget_title="Load data",
23+
widget_title="Load poses",
2424
)
2525

2626
self.loader = self.collapsible_widgets[0]

movement/napari/convert.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Conversion functions from ``movement`` datasets to napari layers."""
2+
3+
import logging
4+
5+
import numpy as np
6+
import pandas as pd
7+
import xarray as xr
8+
9+
# get logger
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame:
14+
"""Construct a properties DataFrame from a ``movement`` dataset."""
15+
return pd.DataFrame(
16+
{
17+
"individual": ds.coords["individuals"].values,
18+
"keypoint": ds.coords["keypoints"].values,
19+
"time": ds.coords["time"].values,
20+
"confidence": ds["confidence"].values.flatten(),
21+
}
22+
)
23+
24+
25+
def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]:
26+
"""Convert poses dataset to napari Tracks array and properties.
27+
28+
Parameters
29+
----------
30+
ds : xr.Dataset
31+
``movement`` dataset containing pose tracks, confidence scores,
32+
and associated metadata.
33+
34+
Returns
35+
-------
36+
data : np.ndarray
37+
napari Tracks array with shape (N, 4),
38+
where N is n_keypoints * n_individuals * n_frames
39+
and the 4 columns are (track_id, frame_idx, y, x).
40+
properties : pd.DataFrame
41+
DataFrame with properties (individual, keypoint, time, confidence).
42+
43+
Notes
44+
-----
45+
A corresponding napari Points array can be derived from the Tracks array
46+
by taking its last 3 columns: (frame_idx, y, x). See the documentation
47+
on the napari Tracks [1]_ and Points [2]_ layers.
48+
49+
References
50+
----------
51+
.. [1] https://napari.org/stable/howtos/layers/tracks.html
52+
.. [2] https://napari.org/stable/howtos/layers/points.html
53+
54+
"""
55+
n_frames = ds.sizes["time"]
56+
n_individuals = ds.sizes["individuals"]
57+
n_keypoints = ds.sizes["keypoints"]
58+
n_tracks = n_individuals * n_keypoints
59+
# Construct the napari Tracks array
60+
# Reorder axes to (individuals, keypoints, frames, xy)
61+
yx_cols = np.transpose(ds.position.values, (1, 2, 0, 3)).reshape(-1, 2)[
62+
:, [1, 0] # swap x and y columns
63+
]
64+
# Each keypoint of each individual is a separate track
65+
track_id_col = np.repeat(np.arange(n_tracks), n_frames).reshape(-1, 1)
66+
time_col = np.tile(np.arange(n_frames), (n_tracks)).reshape(-1, 1)
67+
data = np.hstack((track_id_col, time_col, yx_cols))
68+
# Construct the properties DataFrame
69+
# Stack 3 dimensions into a new single dimension named "tracks"
70+
ds_ = ds.stack(tracks=("individuals", "keypoints", "time"))
71+
properties = _construct_properties_dataframe(ds_)
72+
73+
return data, properties

movement/napari/layer_styles.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Dataclasses containing layer styles for napari."""
2+
3+
from dataclasses import dataclass, field
4+
5+
import numpy as np
6+
import pandas as pd
7+
from napari.utils.colormaps import ensure_colormap
8+
9+
DEFAULT_COLORMAP = "turbo"
10+
11+
12+
@dataclass
13+
class LayerStyle:
14+
"""Base class for napari layer styles."""
15+
16+
name: str
17+
properties: pd.DataFrame
18+
visible: bool = True
19+
blending: str = "translucent"
20+
21+
def as_kwargs(self) -> dict:
22+
"""Return the style properties as a dictionary of kwargs."""
23+
return self.__dict__
24+
25+
26+
@dataclass
27+
class PointsStyle(LayerStyle):
28+
"""Style properties for a napari Points layer."""
29+
30+
symbol: str = "disc"
31+
size: int = 10
32+
border_width: int = 0
33+
face_color: str | None = None
34+
face_color_cycle: list[tuple] | None = None
35+
face_colormap: str = DEFAULT_COLORMAP
36+
text: dict = field(default_factory=lambda: {"visible": False})
37+
38+
def set_color_by(self, prop: str, cmap: str | None = None) -> None:
39+
"""Set the face_color to a column in the properties DataFrame.
40+
41+
Parameters
42+
----------
43+
prop : str
44+
The column name in the properties DataFrame to color by.
45+
cmap : str, optional
46+
The name of the colormap to use, otherwise use the face_colormap.
47+
48+
"""
49+
if cmap is None:
50+
cmap = self.face_colormap
51+
self.face_color = prop
52+
self.text["string"] = prop
53+
n_colors = len(self.properties[prop].unique())
54+
self.face_color_cycle = _sample_colormap(n_colors, cmap)
55+
56+
57+
def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
58+
"""Sample n equally-spaced colors from a napari colormap.
59+
60+
This includes the endpoints of the colormap.
61+
"""
62+
cmap = ensure_colormap(cmap_name)
63+
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
64+
return [tuple(cmap.colors[i]) for i in samples]

pyproject.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,8 @@ entry-points."napari.manifest".movement = "movement.napari:napari.yaml"
4747

4848
[project.optional-dependencies]
4949
napari = [
50-
"napari[all]>=0.4.19",
51-
# the rest will be replaced by brainglobe-utils[qt]>=0.6 after release
52-
"brainglobe-atlasapi>=2.0.7",
53-
"brainglobe-utils>=0.5",
54-
"qtpy",
55-
"superqt",
50+
"napari[all]>=0.5.0",
51+
"brainglobe-utils[qt]>=0.6" # needed for collapsible widgets
5652
]
5753
dev = [
5854
"pytest",

0 commit comments

Comments
 (0)