Skip to content

Commit bc817d4

Browse files
committed
Apply Hanno's truncation-correction factor for method of moments estimation.
1 parent eb36154 commit bc817d4

File tree

3 files changed

+186
-66
lines changed

3 files changed

+186
-66
lines changed

docs/standalone_notebooks/sourcefinder_debugging.ipynb

+33-5
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161
"metadata": {},
6262
"outputs": [],
6363
"source": [
64-
"test_source = Gaussian2dParams(x_centre=18.043469485644753, y_centre=34.32245772421613, amplitude=25.31192179601978,\n",
65-
" semimajor=0.9129890763866767, semiminor=0.6266945800960042, theta=-0.1905488784339695)\n",
64+
"test_source = Gaussian2dParams(x_centre=18.220534745584583, y_centre=34.58710129203467, amplitude=5.509118196421525, semimajor=2.382286097703295, semiminor=0.7718844595419742, theta=-1.4225499318644537)\n",
65+
"\n",
6666
"img = np.zeros(image_shape)\n",
6767
"add_gaussian2d_to_image(test_source, img)"
6868
]
@@ -92,7 +92,7 @@
9292
"metadata": {},
9393
"outputs": [],
9494
"source": [
95-
"sfimg.fit_gaussian_2d(sfimg.islands[0])"
95+
"sfimg.fit_gaussian_2d(sfimg.islands[0],verbose=2)"
9696
]
9797
},
9898
{
@@ -101,7 +101,25 @@
101101
"metadata": {},
102102
"outputs": [],
103103
"source": [
104-
"sfimg.islands[0].params"
104+
"pars = sfimg.islands[0].params"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"pars.moments_fit"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"metadata": {},
120+
"outputs": [],
121+
"source": [
122+
"pars.leastsq_fit"
105123
]
106124
},
107125
{
@@ -113,14 +131,24 @@
113131
"test_source"
114132
]
115133
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"metadata": {},
138+
"outputs": [],
139+
"source": [
140+
"leastsq_img = np.zeros(image_shape)\n",
141+
"add_gaussian2d_to_image(pars.leastsq_fit, leastsq_img)"
142+
]
143+
},
116144
{
117145
"cell_type": "code",
118146
"execution_count": null,
119147
"metadata": {},
120148
"outputs": [],
121149
"source": [
122150
"# fig = plt.figure(figsize=(8, 10))\n",
123-
"plt.imshow(img)\n",
151+
"plt.imshow(img- leastsq_img)\n",
124152
"plt.colorbar()\n",
125153
"plt.scatter(test_source.x_centre, test_source.y_centre)\n",
126154
"peak_idx =sfimg.islands[0].extremum.index\n",

docs/standalone_notebooks/test_source_fitting.ipynb

+46-6
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"\n",
6565
"base_x = 18\n",
6666
"base_y = 34\n",
67-
"n_sources = 200\n",
67+
"n_sources = 5\n",
6868
"positive_sources = generate_random_source_params(n_sources=n_sources,\n",
6969
" base_x=base_x,\n",
7070
" base_y=base_y,\n",
@@ -91,7 +91,7 @@
9191
"outputs": [],
9292
"source": [
9393
"n_islands = 0\n",
94-
"islands = []\n",
94+
"island_params = []\n",
9595
"fits = []\n",
9696
"\n",
9797
"start = datetime.datetime.now()\n",
@@ -109,7 +109,7 @@
109109
" n_islands += 1\n",
110110
" \n",
111111
" assert len(sfimg.islands) == 1\n",
112-
" islands.append(sfimg.islands[0])\n",
112+
" island_params.append(sfimg.islands[0].params)\n",
113113
" lsq_fit = sfimg.fit_gaussian_2d(sfimg.islands[0], verbose=1)\n",
114114
" fits.append(lsq_fit)\n",
115115
" else:\n",
@@ -178,6 +178,25 @@
178178
"import attr"
179179
]
180180
},
181+
{
182+
"cell_type": "code",
183+
"execution_count": null,
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"print(island_params[0])"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"metadata": {},
194+
"outputs": [],
195+
"source": [
196+
"num_evaluations = np.array([i.optimize_result.nfev for i in island_params])\n",
197+
"np.where(num_evaluations > 10)"
198+
]
199+
},
181200
{
182201
"cell_type": "code",
183202
"execution_count": null,
@@ -223,18 +242,39 @@
223242
"metadata": {},
224243
"outputs": [],
225244
"source": [
226-
"idx = 172\n",
245+
"for idx in range(len(positive_sources)):\n",
246+
" print(island_params[idx].moments_fit.comparable_params)\n",
247+
" print(positive_sources[idx].comparable_params)\n",
248+
" print(np.degrees(positive_sources[idx].theta), np.degrees(island_params[idx].moments_fit.theta))\n",
249+
" print()"
250+
]
251+
},
252+
{
253+
"cell_type": "code",
254+
"execution_count": null,
255+
"metadata": {},
256+
"outputs": [],
257+
"source": [
258+
"idx = 125\n",
227259
"f = fits[idx]\n",
228260
"# print(islands[idx].fit.comparable_params == approx(positive_sources[idx].comparable_params))\n",
229-
"i=islands[idx]\n",
261+
"i=island_params[idx]\n",
230262
"print(\"Peak\\n\", i.extremum)\n",
231-
"print(\"Moments\\n\", i.xbar,i.ybar)\n",
263+
"print(\"Moments\\n\", i.moments_fit)\n",
232264
"print(\"Fit\\n\", f)\n",
233265
"print(\"Truth\\n\", positive_sources[idx])\n",
234266
"print()\n",
235267
"print(f.comparable_params)\n",
268+
"print(i.moments_fit.comparable_params)\n",
236269
"print(positive_sources[idx].comparable_params)"
237270
]
271+
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": null,
275+
"metadata": {},
276+
"outputs": [],
277+
"source": []
238278
}
239279
],
240280
"metadata": {

src/fastimgproto/sourcefind/image.py

+107-55
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from attr import attrib, attrs
88
from scipy import ndimage
9-
from scipy.optimize import least_squares
9+
from scipy.optimize import OptimizeResult, least_squares
1010

1111
from fastimgproto.sourcefind.fit import (
1212
Gaussian2dParams,
@@ -102,11 +102,24 @@ class IslandParams(object):
102102
extremum = attrib(validator=attr.validators.instance_of(Pixel))
103103

104104
# Optional
105+
moments_fit = attrib(
106+
default=None,
107+
validator=attr.validators.optional(
108+
attr.validators.instance_of((Gaussian2dParams, bool))))
109+
105110
leastsq_fit = attrib(
106111
default=None,
107112
validator=attr.validators.optional(
108113
attr.validators.instance_of((Gaussian2dParams, bool))))
109114

115+
# Useful for debugging - store the full report on the least-squares fit.
116+
# Don't show it in the standard repr, though - too verbose!
117+
optimize_result = attrib(
118+
default=None, repr=False, cmp=False,
119+
validator=attr.validators.optional(
120+
attr.validators.instance_of(OptimizeResult)
121+
))
122+
110123

111124
@attrs
112125
class Island(object):
@@ -296,23 +309,75 @@ def _label_detection_islands(self, sign):
296309
def calculate_moments(self, island):
297310
"""
298311
Analyses an island to extract further parameters.
312+
313+
See Hanno Spreeuw's thesis for formulae (eqn 2.50 -- 2.54).
314+
(Will add a derivation notebook to repo if time allows).
299315
"""
300316
sign = island.sign
301-
sum = sign * np.ma.sum(island.data)
302-
island.xbar = np.ma.sum(self.xgrid * sign * island.data) / sum
303-
island.ybar = np.ma.sum(self.ygrid * sign * island.data) / sum
317+
# If working with a negative source, be sure to take a positive copy
318+
# (modulus) of the island data to get the moment calculations correct.
319+
abs_data = sign * island.data
320+
sum = abs_data.sum()
321+
y = self.ygrid
322+
x = self.xgrid
323+
x_bar = (x * abs_data).sum() / sum
324+
y_bar = (y * abs_data).sum() / sum
325+
xx_bar = (x * x * abs_data).sum() / sum - x_bar * x_bar
326+
yy_bar = (y * y * abs_data).sum() / sum - y_bar * y_bar
327+
xy_bar = (x * y * abs_data).sum() / sum - x_bar * y_bar
328+
329+
working1 = (xx_bar + yy_bar) / 2.0
330+
working2 = math.sqrt(((xx_bar - yy_bar) / 2) ** 2 + xy_bar ** 2)
331+
trunc_semimajor_sq = working1 + working2
332+
trunc_semiminor_sq = working1 - working2
333+
334+
# Semimajor / minor axes are under-estimated due to threholding
335+
# Hanno calculated the following correction factor (eqns 2.60,2.61):
336+
337+
pixel_threshold = self.analysis_n_sigma * self.rms_est
338+
# `cutoff_ratio` == 'C/T' in Hanno's formulae.
339+
# Always >1.0, else the source would not be detected.
340+
cutoff_ratio = sign * island.extremum.value / pixel_threshold
341+
axes_scale_factor = 1.0 - math.log(cutoff_ratio) / (cutoff_ratio - 1.)
342+
semimajor_est = math.sqrt(trunc_semimajor_sq / axes_scale_factor)
343+
semiminor_est = math.sqrt(trunc_semiminor_sq / axes_scale_factor)
344+
345+
# For theta, we differ from Hanno's algorithm - I think Hanno maybe made
346+
# an error,or possibly this is due to different parameter
347+
# bound-choices, not sure...
348+
theta_est = 0.5 * math.atan(2. * xy_bar / (xx_bar - yy_bar))
349+
350+
# Atan(theta) solutions are periodic - can add or subtract pi.
351+
# math.atan(theta) returns an angle in the range (-pi/2,pi/2) (matching the
352+
# sign of theta).
353+
# No problem since we're robust to rotations of pi / 180 degrees.
354+
# But atan(2theta) solutions are periodic in pi/2. This is an issue,
355+
# since we could have the wrong solution. To do so, we can just check
356+
# if we're in the correct quadrant - if it's the wrong solution,
357+
# it will be flipped by pi/2, then constrained to the (-pi/2,pi/2)
358+
# by an additional rotation of pi. So if needed we add *another* pi/2,
359+
# and let the Gaussian2dParams constructor take care of correcting bounds.
360+
# We expect the sign of theta to match the sign of the covariance:
361+
if theta_est * xy_bar < 0.:
362+
theta_est += math.pi / 2.0
363+
364+
moments_fits = Gaussian2dParams.from_unconstrained_parameters(
365+
x_centre=x_bar,
366+
y_centre=y_bar,
367+
amplitude=island.extremum.value,
368+
semimajor=semimajor_est,
369+
semiminor=semiminor_est,
370+
theta=theta_est
371+
)
372+
island.params.moments_fit = moments_fits
373+
return moments_fits
304374

305375
def fit_gaussian_2d(self, island, verbose=0):
306376
# x, y, x_centre, y_centre, amplitude, x_stddev, y_stddev, theta
307377
y_indices, x_indices = island.unmasked_pixel_indices
308378
fitting_data = island.data[y_indices, x_indices]
309379

310-
def island_residuals(x_centre,
311-
y_centre,
312-
amplitude,
313-
semimajor,
314-
semiminor,
315-
theta):
380+
def island_residuals(pars):
316381
"""
317382
A wrapped version of `gaussian2d` applied to this island's unmasked
318383
pixels, then subtracting the island values
@@ -325,6 +390,13 @@ def island_residuals(x_centre,
325390
326391
"""
327392

393+
(x_centre,
394+
y_centre,
395+
amplitude,
396+
semimajor,
397+
semiminor,
398+
theta) = pars
399+
328400
model_vals = gaussian2d(x_indices, y_indices,
329401
x_centre=x_centre,
330402
y_centre=y_centre,
@@ -336,57 +408,37 @@ def island_residuals(x_centre,
336408
assert model_vals.shape == fitting_data.shape
337409
return fitting_data - model_vals
338410

339-
def located_jacobian(pars):
340-
"""
341-
Wrapped version of `gaussian2d_jac` applied at these pixel positions.
342-
"""
343-
(x_centre,
344-
y_centre,
345-
amplitude,
346-
semimajor,
347-
semiminor,
348-
theta) = pars
349-
return gaussian2d_jac(x_indices, y_indices,
350-
x_centre=x_centre,
351-
y_centre=y_centre,
352-
amplitude=amplitude,
353-
x_stddev=semimajor,
354-
y_stddev=semiminor,
355-
theta=theta,
356-
)
357-
358-
def wrapped_island_residuals(pars):
359-
"""
360-
Wrapped version of `island_residuals` that takes a single argument
361-
362-
(a tuple of the varying parameters).
363-
364-
Args:
365-
pars (tuple):
366-
(x_centre, y_centre, amplitude, x_stddev, y_stddev, theta)
367-
368-
Returns:
369-
numpy.ndarray: vector of residuals
370-
371-
"""
372-
assert len(pars) == 6
373-
return island_residuals(*pars)
374-
375-
initial_params = Gaussian2dParams(x_centre=island.xbar,
376-
y_centre=island.ybar,
377-
amplitude=island.extremum.value,
378-
semimajor=1.,
379-
semiminor=1.,
380-
theta=0
381-
)
411+
# def located_jacobian(pars):
412+
# """
413+
# Wrapped version of `gaussian2d_jac` applied at these pixel positions.
414+
# """
415+
# (x_centre,
416+
# y_centre,
417+
# amplitude,
418+
# semimajor,
419+
# semiminor,
420+
# theta) = pars
421+
# return gaussian2d_jac(x_indices, y_indices,
422+
# x_centre=x_centre,
423+
# y_centre=y_centre,
424+
# amplitude=amplitude,
425+
# x_stddev=semimajor,
426+
# y_stddev=semiminor,
427+
# theta=theta,
428+
# )
429+
430+
431+
initial_params = island.params.moments_fit
382432

383433
# Using the jacobian mostly gives bad fits?
384-
lsq_result = least_squares(fun=wrapped_island_residuals,
434+
lsq_result = least_squares(fun=island_residuals,
385435
# jac=located_jacobian,
386436
x0=attr.astuple(initial_params),
387-
# method='dogbox',
437+
method='dogbox',
388438
verbose=verbose,
439+
# max_nfev=50,
389440
)
441+
island.params.optimize_result = lsq_result
390442
island.params.leastsq_fit = Gaussian2dParams.from_unconstrained_parameters(
391443
*tuple(lsq_result.x))
392444
return island.params.leastsq_fit

0 commit comments

Comments
 (0)