Skip to content

Commit 490798d

Browse files
committed
Final changes implement negative labelling for negative sources.
Combine label-> extremum_val dictionaries, combine label pixel-arrays.
1 parent 42961a3 commit 490798d

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

src/fastimgproto/sourcefind/image.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class IslandParams(object):
3030
xbar (float): Barycentric centre in x-pixel index
3131
ybar(float): Barycentric centre in y-pixel index
3232
"""
33-
parent = attrib()
34-
label_idx = attrib()
33+
parent = attrib(cmp=False)
34+
label_idx = attrib(cmp=False)
3535
sign = attrib(validator=_positive_negative_sign_validator)
3636
extremum_val = attrib()
3737
extremum_x_idx = attrib(default=None)
@@ -49,7 +49,7 @@ def calculate_params(self):
4949

5050
)
5151
self.extremum_y_idx, self.extremum_x_idx = _extremum_pixel_index(
52-
self.data)
52+
self.data, self.sign)
5353
sum = self.sign * np.ma.sum(self.data)
5454
self.xbar = np.ma.sum(self.parent.xgrid * self.sign * self.data) / sum
5555
self.ybar = np.ma.sum(self.parent.ygrid * self.sign * self.data) / sum
@@ -59,11 +59,15 @@ def _label_mask(labels_map, label_num):
5959
return ~(labels_map == label_num)
6060

6161

62-
def _extremum_pixel_index(masked_image):
62+
def _extremum_pixel_index(masked_image,sign):
6363
"""
64-
Returns max pixel index in np array ordering, i.e. (y_max, x_max)
64+
Returns max/min pixel index in np array ordering, i.e. (y_max, x_max)
6565
"""
66-
return np.unravel_index(np.ma.argmax(masked_image),
66+
if sign==1:
67+
extremum_func = np.ma.argmax
68+
elif sign==-1:
69+
extremum_func = np.ma.argmin
70+
return np.unravel_index(extremum_func(masked_image),
6771
masked_image.shape)
6872

6973

@@ -98,15 +102,11 @@ def __init__(self, data, detection_n_sigma, analysis_n_sigma,
98102

99103
# Label connected regions
100104

105+
self.label_map, label_extrema = self._label_detection_islands(1)
101106
if find_negative_sources:
102-
pos_label_map, pos_label_extrema = self._label_detection_islands(1)
103107
neg_label_map, neg_label_extrema = self._label_detection_islands(-1)
104-
self.label_map = self._combine_label_maps(pos_label_map,
105-
neg_label_map)
106-
label_extrema = self._combine_label_extrema(pos_label_extrema,
107-
neg_label_extrema)
108-
else:
109-
self.label_map, label_extrema = self._label_detection_islands(1)
108+
self.label_map += neg_label_map
109+
label_extrema.update(neg_label_extrema)
110110

111111
self.islands = []
112112
for l_idx, l_extremum in label_extrema.items():
@@ -166,6 +166,13 @@ def _label_detection_islands(self, sign):
166166
valid_label_extrema[label] = ex_val
167167
else:
168168
label_map[label_map == label] = 0.
169+
170+
if sign == -1:
171+
# If extracting negative sources, flip the sign of the indices
172+
valid_label_extrema = {-1 * k: valid_label_extrema[k]
173+
for k in valid_label_extrema}
174+
# ... and the corresponding label map:
175+
label_map = -1 * label_map
169176
return label_map, valid_label_extrema
170177

171178

tests/test_sourcefind/test_detection.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1-
import pytest
1+
from __future__ import print_function
2+
3+
import numpy as np
4+
import scipy.ndimage as ndimage
25

36
from fastimgproto.fixtures.image import (
47
evaluate_model_on_pixel_grid,
58
gaussian_point_source,
69
uncorrelated_gaussian_noise_background
710
)
811
from fastimgproto.sourcefind.image import (SourceFindImage, _estimate_rms)
9-
import numpy as np
1012

1113
ydim = 128
1214
xdim = 64
1315
rms = 1.0
1416
bright_src = gaussian_point_source(x_centre=48.24, y_centre=52.66,
1517
amplitude=10.0)
1618
faint_src = gaussian_point_source(x_centre=32, y_centre=64, amplitude=3.5)
17-
negative_src = gaussian_point_source(x_centre=24.31, y_centre=32.157,
19+
negative_src = gaussian_point_source(x_centre=24.31, y_centre=28.157,
1820
amplitude=-10.0)
1921

2022

@@ -65,7 +67,8 @@ def test_basic_source_detection():
6567
find_negative_sources=False)
6668
assert len(sf.islands) == 2
6769

68-
@pytest.mark.xfail()
70+
71+
6972
def test_negative_source_detection():
7073
"""
7174
Also need to detect 'negative' sources, i.e. where a source in the
@@ -83,8 +86,9 @@ def test_negative_source_detection():
8386
rms_est=rms)
8487
assert len(sf.islands) == 1
8588
found_src = sf.islands[0]
86-
# print(bright_src)
87-
# print(src)
89+
print()
90+
print(negative_src)
91+
print(found_src)
8892
assert np.abs(found_src.extremum_x_idx - negative_src.x_mean) < 0.5
8993
assert np.abs(found_src.extremum_y_idx - negative_src.y_mean) < 0.5
9094
assert np.abs(found_src.xbar - negative_src.x_mean) < 0.1
@@ -95,3 +99,8 @@ def test_negative_source_detection():
9599
analysis_n_sigma=3,
96100
rms_est=rms)
97101
assert len(sf.islands) == 2
102+
positive_islands = [i for i in sf.islands if i.sign==1]
103+
negative_islands = [i for i in sf.islands if i.sign==-1]
104+
assert len(positive_islands) ==1
105+
assert len(negative_islands) ==1
106+
assert negative_islands[0]== found_src

0 commit comments

Comments
 (0)