Skip to content

Generalise functionality for plotting from feature tables #167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move key setting/getting
  • Loading branch information
dstansby committed Jun 15, 2023
commit 75e79b6610be3f5d46eba54130e61f006bf9bf2f
30 changes: 21 additions & 9 deletions src/napari_matplotlib/features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional

import napari
import napari.layers
Expand All @@ -12,12 +12,6 @@ class FeaturesMixin(NapariMPLWidget):
"""
Mixin to help with widgets that plot data from a features table stored
on a single layer.

Notes
-----
This currently only works for widgets that plot two quatities against each other
e.g., scatter plots. It is intended to be generalised in the future for widgets
that plot one quantity e.g., histograms.
"""

n_layers_input = Interval(1, 1)
Expand All @@ -30,19 +24,37 @@ class FeaturesMixin(NapariMPLWidget):
napari.layers.Vectors,
)

def __init__(self) -> None:
def __init__(self, *, ndim: int) -> None:
assert ndim in [1, 2]
self.dims = ["x", "y"][:ndim]
# Set up selection boxes
self.layout().addLayout(QVBoxLayout())

self._selectors: Dict[str, QComboBox] = {}
for dim in ["x", "y"]:
for dim in self.dims:
self._selectors[dim] = QComboBox()
# Re-draw when combo boxes are updated
self._selectors[dim].currentTextChanged.connect(self._draw)

self.layout().addWidget(QLabel(f"{dim}-axis:"))
self.layout().addWidget(self._selectors[dim])

def get_key(self, dim: str) -> Optional[str]:
"""
Get key for a given dimension.
"""
if self._selectors[dim].count() == 0:
return None
else:
return self._selectors[dim].currentText()

def set_key(self, dim: str, value: str) -> None:
"""
Set key for a given dimension.
"""
self._selectors[dim].setCurrentText(value)
self._draw()

def _get_valid_axis_keys(self) -> List[str]:
"""
Get the valid axis keys from the layer FeatureTable.
Expand Down
46 changes: 8 additions & 38 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple

import napari
import numpy.typing as npt
Expand Down Expand Up @@ -97,39 +97,9 @@ def __init__(
parent: Optional[QWidget] = None,
):
ScatterBaseWidget.__init__(self, napari_viewer, parent=parent)
FeaturesMixin.__init__(self)
FeaturesMixin.__init__(self, ndim=2)
self._update_layers(None)

@property
def x_axis_key(self) -> Union[str, None]:
"""
Key for the x-axis data.
"""
if self._selectors["x"].count() == 0:
return None
else:
return self._selectors["x"].currentText()

@x_axis_key.setter
def x_axis_key(self, key: str) -> None:
self._selectors["x"].setCurrentText(key)
self._draw()

@property
def y_axis_key(self) -> Union[str, None]:
"""
Key for the y-axis data.
"""
if self._selectors["y"].count() == 0:
return None
else:
return self._selectors["y"].currentText()

@y_axis_key.setter
def y_axis_key(self, key: str) -> None:
self._selectors["y"].setCurrentText(key)
self._draw()

def _ready_to_scatter(self) -> bool:
"""
Return True if selected layer has a feature table we can scatter with,
Expand All @@ -143,8 +113,8 @@ def _ready_to_scatter(self) -> bool:
return (
feature_table is not None
and len(feature_table) > 0
and self.x_axis_key in valid_keys
and self.y_axis_key in valid_keys
and self.get_key("x") in valid_keys
and self.get_key("y") in valid_keys
)

def draw(self) -> None:
Expand Down Expand Up @@ -173,11 +143,11 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
feature_table = self.layers[0].features

x = feature_table[self.x_axis_key]
y = feature_table[self.y_axis_key]
x = feature_table[self.get_key("x")]
y = feature_table[self.get_key("y")]

x_axis_name = str(self.x_axis_key)
y_axis_name = str(self.y_axis_key)
x_axis_name = str(self.get_key("x"))
y_axis_name = str(self.get_key("y"))

return x, y, x_axis_name, y_axis_name

Expand Down
8 changes: 4 additions & 4 deletions src/napari_matplotlib/tests/scatter/test_scatter_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def test_features_scatter_widget_2D(

# Select points data and chosen features
viewer.layers.selection.add(viewer.layers[0]) # images need to be selected
widget.x_axis_key = "feature_0"
widget.y_axis_key = "feature_1"
widget.set_key("x", "feature_0")
widget.set_key("y", "feature_1")

fig = widget.figure

Expand Down Expand Up @@ -64,9 +64,9 @@ def test_features_scatter_get_data(make_napari_viewer):
viewer.layers.selection = [labels_layer]

x_column = "feature_0"
scatter_widget.x_axis_key = x_column
y_column = "feature_2"
scatter_widget.y_axis_key = y_column
scatter_widget.set_key("x", x_column)
scatter_widget.set_key("y", y_column)

x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
np.testing.assert_allclose(x, feature_table[x_column])
Expand Down