11from typing import Any , List , Optional , Tuple
22
3- import matplotlib .colors as mcolor
43import napari
54import numpy .typing as npt
65from magicgui import magicgui
@@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget):
1716 Base class for widgets that scatter two datasets against each other.
1817 """
1918
20- # opacity value for the markers
21- _marker_alpha = 0.5
22-
23- # flag set to True if histogram should be used
24- # for plotting large points
25- _histogram_for_large_data = True
26-
2719 # if the number of points is greater than this value,
28- # the scatter is plotted as a 2dhist
20+ # the scatter is plotted as a 2D histogram
2921 _threshold_to_switch_to_histogram = 500
3022
3123 def __init__ (self , napari_viewer : napari .viewer .Viewer ):
@@ -44,40 +36,32 @@ def draw(self) -> None:
4436 """
4537 Scatter the currently selected layers.
4638 """
47- data , x_axis_name , y_axis_name = self ._get_data ()
48-
49- if len (data ) == 0 :
50- # don't plot if there isn't data
51- return
39+ x , y , x_axis_name , y_axis_name = self ._get_data ()
5240
53- if self ._histogram_for_large_data and (
54- data [0 ].size > self ._threshold_to_switch_to_histogram
55- ):
41+ if x .size > self ._threshold_to_switch_to_histogram :
5642 self .axes .hist2d (
57- data [ 0 ] .ravel (),
58- data [ 1 ] .ravel (),
43+ x .ravel (),
44+ y .ravel (),
5945 bins = 100 ,
60- norm = mcolor .LogNorm (),
6146 )
6247 else :
63- self .axes .scatter (data [ 0 ], data [ 1 ] , alpha = self . _marker_alpha )
48+ self .axes .scatter (x , y , alpha = 0.5 )
6449
6550 self .axes .set_xlabel (x_axis_name )
6651 self .axes .set_ylabel (y_axis_name )
6752
68- def _get_data (self ) -> Tuple [List [npt .NDArray [Any ]], str , str ]:
69- """Get the plot data.
53+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt .NDArray [Any ], str , str ]:
54+ """
55+ Get the plot data.
7056
7157 This must be implemented on the subclass.
7258
7359 Returns
7460 -------
75- data : np.ndarray
76- The list containing the scatter plot data.
77- x_axis_name : str
78- The label to display on the x axis
79- y_axis_name: str
80- The label to display on the y axis
61+ x, y : np.ndarray
62+ x and y values of plot data.
63+ x_axis_name, y_axis_name : str
64+ Label to display on the x/y axis
8165 """
8266 raise NotImplementedError
8367
@@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget):
9377 n_layers_input = Interval (2 , 2 )
9478 input_layer_types = (napari .layers .Image ,)
9579
96- def _get_data (self ) -> Tuple [List [ npt .NDArray [Any ]], str , str ]:
80+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt . NDArray [ Any ], str , str ]:
9781 """
9882 Get the plot data.
9983
@@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
10690 y_axis_name: str
10791 The title to display on the y axis
10892 """
109- data = [layer .data [self .current_z ] for layer in self .layers ]
93+ x = self .layers [0 ].data [self .current_z ]
94+ y = self .layers [1 ].data [self .current_z ]
11095 x_axis_name = self .layers [0 ].name
11196 y_axis_name = self .layers [1 ].name
11297
113- return data , x_axis_name , y_axis_name
98+ return x , y , x_axis_name , y_axis_name
11499
115100
116101class FeaturesScatterWidget (ScatterBaseWidget ):
@@ -191,9 +176,33 @@ def _get_valid_axis_keys(
191176 else :
192177 return self .layers [0 ].features .keys ()
193178
194- def _get_data (self ) -> Tuple [ List [ npt . NDArray [ Any ]], str , str ] :
179+ def _ready_to_scatter (self ) -> bool :
195180 """
196- Get the plot data.
181+ Return True if selected layer has a feature table we can scatter with,
182+ and the two columns to be scatterd have been selected.
183+ """
184+ if not hasattr (self .layers [0 ], "features" ):
185+ return False
186+
187+ feature_table = self .layers [0 ].features
188+ return (
189+ feature_table is not None
190+ and len (feature_table ) > 0
191+ and self .x_axis_key is not None
192+ and self .y_axis_key is not None
193+ )
194+
195+ def draw (self ) -> None :
196+ """
197+ Scatter two features from the currently selected layer.
198+ """
199+ if self ._ready_to_scatter ():
200+ super ().draw ()
201+
202+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt .NDArray [Any ], str , str ]:
203+ """
204+ Get the plot data from the ``features`` attribute of the first
205+ selected layer.
197206
198207 Returns
199208 -------
@@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
207216 The title to display on the y axis. Returns
208217 an empty string if nothing to plot.
209218 """
210- if not hasattr (self .layers [0 ], "features" ):
211- # if the selected layer doesn't have a featuretable,
212- # skip draw
213- return [], "" , ""
214-
215219 feature_table = self .layers [0 ].features
216220
217- if (
218- (len (feature_table ) == 0 )
219- or (self .x_axis_key is None )
220- or (self .y_axis_key is None )
221- ):
222- return [], "" , ""
223-
224- data_x = feature_table [self .x_axis_key ]
225- data_y = feature_table [self .y_axis_key ]
226- data = [data_x , data_y ]
221+ x = feature_table [self .x_axis_key ]
222+ y = feature_table [self .y_axis_key ]
227223
228- x_axis_name = self .x_axis_key . replace ( "_" , " " )
229- y_axis_name = self .y_axis_key . replace ( "_" , " " )
224+ x_axis_name = str ( self .x_axis_key )
225+ y_axis_name = str ( self .y_axis_key )
230226
231- return data , x_axis_name , y_axis_name
227+ return x , y , x_axis_name , y_axis_name
232228
233229 def _on_update_layers (self ) -> None :
234230 """
0 commit comments