41
41
warnings .filterwarnings ("ignore" , message = "divide by zero encountered in log" )
42
42
43
43
MAXIT = 10000 # maximum number of iterations in self-cal minimizer (passed as the 'maxiter' option)
44
+ NHIST = 50 # number of steps to store for the hessian approximation (passed as 'maxcor' when minimizer_method is 'L-BFGS-B')
45
+ MAXLS = 40 # maximum number of line search steps (passed as 'maxls' when minimizer_method is 'L-BFGS-B')
46
+ STOP = 1e-6 # convergence criterion (passed as both 'ftol' and 'gtol' when minimizer_method is 'L-BFGS-B')
44
47
45
48
###################################################################################################
46
49
# Self-Calibration
@@ -52,7 +55,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
52
55
ttype = 'direct' , fft_pad_factor = 2 , caltable = False ,
53
56
debias = True , apply_dterms = False ,
54
57
copy_closure_tables = True ,
55
- processes = - 1 , show_solution = False , msgtype = 'bar' ):
58
+ processes = - 1 , show_solution = False , msgtype = 'bar' ,
59
+ use_grad = False ):
56
60
"""Self-calibrate a dataset to an image.
57
61
58
62
Args:
@@ -83,12 +87,14 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
83
87
apply_dterms (bool): if True, apply dterms (in obs.tarr) to clean data before calibrating
84
88
show_solution (bool): if True, display the solution as it is calculated
85
89
msgtype (str): type of progress message to be printed, default is 'bar'
86
-
90
+ use_grad (bool): if True, use gradients in minimizer
91
+
87
92
Returns:
88
93
(Obsdata): the calibrated observation, if caltable==False
89
94
(Caltable): the derived calibration table, if caltable==True
90
95
"""
91
-
96
+ if use_grad and (method == 'phase' or method == 'amp' ):
97
+ raise Exception ("errfunc_grad in self_cal only works with method=='both'!" )
92
98
if pol not in ['I' , 'Q' , 'U' , 'V' , 'RR' , 'LL' ]:
93
99
raise Exception ("Can only self-calibrate to I, Q, U, V, RR, or LL images!" )
94
100
if pol in ['I' , 'Q' , 'U' , 'V' ]:
@@ -148,7 +154,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
148
154
obs .polrep , pol ,
149
155
method , minimizer_method ,
150
156
show_solution , pad_amp , gain_tol ,
151
- caltable , debias , msgtype
157
+ debias , caltable , msgtype ,
158
+ use_grad
152
159
] for i in range (len (scans ))]))
153
160
154
161
else : # run on a single core
@@ -157,8 +164,10 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
157
164
scans_cal [i ] = self_cal_scan (scans [i ], im , V_scan = V_scans [i ], sites = sites ,
158
165
polrep = obs .polrep , pol = pol ,
159
166
method = method , minimizer_method = minimizer_method ,
160
- show_solution = show_solution , debias = debias ,
161
- pad_amp = pad_amp , gain_tol = gain_tol , caltable = caltable )
167
+ show_solution = show_solution ,
168
+ pad_amp = pad_amp , gain_tol = gain_tol ,
169
+ debias = debias , caltable = caltable ,
170
+ use_grad = use_grad )
162
171
163
172
tstop = time .time ()
164
173
print ("\n self_cal time: %f s" % (tstop - tstart ))
@@ -201,7 +210,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
201
210
202
211
def self_cal_scan (scan , im , V_scan = [], sites = [], polrep = 'stokes' , pol = 'I' , method = "both" ,
203
212
minimizer_method = 'BFGS' , show_solution = False ,
204
- pad_amp = 0. , gain_tol = .2 , debias = True , caltable = False ):
213
+ pad_amp = 0. , gain_tol = .2 , debias = True , caltable = False ,
214
+ use_grad = False ):
205
215
"""Self-calibrate a scan to an image.
206
216
207
217
Args:
@@ -224,12 +234,16 @@ def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', metho
224
234
debias (bool): If True, debias the amplitudes
225
235
caltable (bool): if True, returns a Caltable instead of an Obsdata
226
236
show_solution (bool): if True, display the solution as it is calculated
227
-
237
+ use_grad (bool): if True, use gradients in minimizer
238
+
228
239
Returns:
229
240
(Obsdata): the calibrated observation, if caltable==False
230
241
(Caltable): the derived calibration table, if caltable==True
231
242
"""
232
-
243
+
244
+ if use_grad and (method == 'phase' or method == 'amp' ):
245
+ raise Exception ("errfunc_grad in self_cal only works with method=='both'!" )
246
+
233
247
if len (sites ) == 0 :
234
248
print ("No stations specified in self cal: defaulting to calibrating all !" )
235
249
sites = list (set (scan ['t1' ]).union (set (scan ['t2' ])))
@@ -286,46 +300,25 @@ def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', metho
286
300
287
301
# error function
288
302
def errfunc (gpar ):
289
- # all the forward site gains (complex)
290
- g = gpar .astype (np .float64 ).view (dtype = np .complex128 )
303
+ return errfunc_full (gpar , vis , V_scan , sigma_inv , gain_tol , sites , g1_keys , g2_keys , method )
291
304
292
- if method == "phase" :
293
- g = g / np .abs (g )
294
- if method == "amp" :
295
- g = np .abs (np .real (g ))
296
-
297
- # append the default values to g for missing gains
298
- g = np .append (g , 1. )
299
- g1 = g [g1_keys ]
300
- g2 = g [g2_keys ]
301
-
302
- # build site specific tolerance parameters
303
- tol0 = np .array ([gain_tol .get (s , gain_tol ['default' ])[0 ] for s in sites ])
304
- tol1 = np .array ([gain_tol .get (s , gain_tol ['default' ])[1 ] for s in sites ])
305
-
306
- if method == 'amp' :
307
- verr = np .abs (vis ) - g1 * g2 .conj () * np .abs (V_scan )
308
- else :
309
- verr = vis - g1 * g2 .conj () * V_scan
310
-
311
- nan_mask = [not np .isnan (v ) for v in verr ]
312
- verr = verr [nan_mask ]
313
-
314
- # goodness-of-fit for gains
315
- chisq = np .sum ((verr .real * sigma_inv [nan_mask ])** 2 ) + \
316
- np .sum ((verr .imag * sigma_inv [nan_mask ])** 2 )
317
-
318
- # prior on the gains
319
- # don't count the last (default missing site) gain dummy value
320
- chisq_g = np .sum (np .log (np .abs (g [:- 1 ]))** 2 /
321
- ((np .abs (g [:- 1 ]) > 1 ) * tol0 + (np .abs (g [:- 1 ]) <= 1 ) * tol1 )** 2 )
322
-
323
- return chisq + chisq_g
305
+ def errfunc_grad (gpar ):
306
+ return errfunc_grad_full (gpar , vis , V_scan , sigma_inv , gain_tol , sites , g1_keys , g2_keys , method )
324
307
325
308
# use gradient descent to find the gains
326
- optdict = {'maxiter' : MAXIT } # minimizer params
327
- res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict )
328
-
309
+ # minimizer params
310
+ if minimizer_method == 'L-BFGS-B' :
311
+ optdict = {'maxiter' : MAXIT ,
312
+ 'ftol' : STOP , 'gtol' : STOP ,
313
+ 'maxcor' : NHIST , 'maxls' : MAXLS }
314
+ else :
315
+ optdict = {'maxiter' : MAXIT }
316
+
317
+ if use_grad :
318
+ res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict , jac = errfunc_grad )
319
+ else :
320
+ res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict )
321
+
329
322
# save the solution
330
323
g_fit = res .x .view (np .complex128 )
331
324
@@ -397,7 +390,7 @@ def get_selfcal_scan_cal(args):
397
390
398
391
399
392
def get_selfcal_scan_cal2 (i , n , scan , im , V_scan , sites , polrep , pol , method , minimizer_method ,
400
- show_solution , pad_amp , gain_tol , caltable , debias , msgtype ):
393
+ show_solution , pad_amp , gain_tol , debias , caltable , msgtype , use_grad ):
401
394
if n > 1 :
402
395
global counter
403
396
counter .increment ()
@@ -406,4 +399,140 @@ def get_selfcal_scan_cal2(i, n, scan, im, V_scan, sites, polrep, pol, method, mi
406
399
return self_cal_scan (scan , im , V_scan = V_scan , sites = sites , polrep = polrep , pol = pol ,
407
400
method = method , minimizer_method = minimizer_method ,
408
401
show_solution = show_solution ,
409
- pad_amp = pad_amp , gain_tol = gain_tol , caltable = caltable , debias = debias )
402
+ pad_amp = pad_amp , gain_tol = gain_tol , debias = debias , caltable = caltable ,
403
+ use_grad = use_grad )
404
+
405
+ # error function
406
def errfunc_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
    """Return the total chi^2 (data term plus gain prior) for trial site gains.

    Args:
        gpar (ndarray): real array of gain parameters; it is reinterpreted as
            complex128, so consecutive pairs are (real, imag) of one site gain
        vis (ndarray): measured visibilities
        v_scan (ndarray): model visibilities, same length as vis
        sigma_inv (ndarray): weights multiplying the residuals (same length as vis)
        gain_tol (dict): per-site tolerances; gain_tol[site] (or gain_tol['default']
            when the site is absent) is indexed with [0] for |g| > 1 and [1] for |g| <= 1
        sites (list): site names, one per complex gain in gpar
        g1_keys, g2_keys: integer indices into the gain array for the first and
            second station of each visibility; the last element (index -1) of the
            gain array is a unit gain appended for sites that are not solved for
        method (str): "phase" normalises gains to unit modulus, "amp" uses
            |Re(g)| and compares amplitudes only; any other value fits complex gains

    Returns:
        (float): chi^2 of the visibilities plus the log-amplitude prior on the gains
    """
    # all the forward site gains (complex)
    g = gpar.astype(np.float64).view(dtype=np.complex128)

    if method == "phase":
        g = g / np.abs(g)
    if method == "amp":
        g = np.abs(np.real(g))

    # append the default value to g for missing gains
    g = np.append(g, 1.)
    g1 = g[g1_keys]
    g2 = g[g2_keys]

    # build site specific tolerance parameters
    default_tol = gain_tol['default']
    tol0 = np.array([gain_tol.get(s, default_tol)[0] for s in sites])
    tol1 = np.array([gain_tol.get(s, default_tol)[1] for s in sites])

    if method == 'amp':
        verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
    else:
        verr = vis - g1 * g2.conj() * v_scan

    # drop residuals that are NaN (vectorised instead of a per-element Python loop)
    good = ~np.isnan(verr)
    verr = verr[good]
    weights = sigma_inv[good]

    # goodness-of-fit for gains
    chisq = np.sum((verr.real * weights)**2) + np.sum((verr.imag * weights)**2)

    # prior on the gains
    # don't count the last (default missing site) gain dummy value
    gabs = np.abs(g[:-1])
    tolsq = ((gabs > 1) * tol0 + (gabs <= 1) * tol1)**2
    chisq_g = np.sum(np.log(gabs)**2 / tolsq)

    # total chi^2
    return chisq + chisq_g
444
+
445
def errfunc_grad_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
    """Return the gradient of the self-calibration chi^2 with respect to gpar.

    The objective is the data term sum(|vis - g1 * conj(g2) * v_scan|^2 * sigma_inv^2)
    plus the prior sum(log(|g|)^2 / tol^2) over the solved site gains.

    Args:
        gpar (ndarray): real array of gain parameters; it is reinterpreted as
            complex128, so consecutive pairs are (real, imag) of one site gain
        vis (ndarray): measured visibilities
        v_scan (ndarray): model visibilities, same length as vis
        sigma_inv (ndarray): weights multiplying the residuals (same length as vis)
        gain_tol (dict): per-site tolerances; gain_tol[site] (or gain_tol['default']
            when the site is absent) is indexed with [0] for |g| > 1 and [1] for |g| <= 1
        sites (list): site names, one per complex gain in gpar
        g1_keys, g2_keys: integer indices into the gain array for the first and
            second station of each visibility; the last element (index -1) of the
            gain array is a unit gain appended for sites that are not solved for
        method (str): only complex-gain fitting is supported

    Returns:
        (ndarray): real array of len(gpar) with the derivatives interleaved as
            (d/d real, d/d imag) for each site gain

    Raises:
        Exception: if method is 'phase' or 'amp'
    """
    # does not work for method=='phase' or method=='amp'
    if method == 'phase' or method == 'amp':
        raise Exception("errfunc_grad in self_cal only works with method=='both'!")

    # all the forward site gains (complex); len(gpar) must be even
    g = gpar.astype(np.float64).view(dtype=np.complex128)
    nsites = len(gpar) // 2
    gr = np.real(g)
    gi = np.imag(g)

    # build site specific tolerance parameters
    default_tol = gain_tol['default']
    tol0 = np.array([gain_tol.get(s, default_tol)[0] for s in sites])
    tol1 = np.array([gain_tol.get(s, default_tol)[1] for s in sites])

    # append the default value to g for missing gains
    g = np.append(g, 1.)
    g1 = g[g1_keys]
    g2 = g[g2_keys]

    v_scan_sq = v_scan * v_scan.conj()
    g1sq = g1 * g1.conj()
    g2sq = g2 * g2.conj()

    ###################################
    # data term chi^2 derivative
    ###################################

    # per-visibility derivatives with respect to the real/imaginary parts of the
    # first (g1) and second (g2) station gain
    dchisq_dg1r = (-g2.conj() * vis.conj() * v_scan - g2 * vis * v_scan.conj()
                   + 2 * np.real(g1) * g2sq * v_scan_sq)
    dchisq_dg1i = (-1j * g2.conj() * vis.conj() * v_scan + 1j * g2 * vis * v_scan.conj()
                   + 2 * np.imag(g1) * g2sq * v_scan_sq)
    dchisq_dg2r = (-g1 * vis.conj() * v_scan - g1.conj() * vis * v_scan.conj()
                   + 2 * np.real(g2) * g1sq * v_scan_sq)
    dchisq_dg2i = (1j * g1 * vis.conj() * v_scan - 1j * g1.conj() * vis * v_scan.conj()
                   + 2 * np.imag(g2) * g1sq * v_scan_sq)

    # same masking as in the error function: visibilities with a NaN residual
    # contribute nothing, but array lengths are preserved
    verr = vis - g1 * g2.conj() * v_scan
    nan_mask = np.isnan(verr)
    weight = sigma_inv**2

    # Each expression above is a sum of complex-conjugate pairs plus a real term,
    # so only the real part carries information. Taking it explicitly avoids the
    # implicit complex->real cast when the per-site sums are stored.
    dchisq_dg1r = np.where(nan_mask, 0.0, np.real(dchisq_dg1r * weight))
    dchisq_dg1i = np.where(nan_mask, 0.0, np.real(dchisq_dg1i * weight))
    dchisq_dg2r = np.where(nan_mask, 0.0, np.real(dchisq_dg2r * weight))
    dchisq_dg2i = np.where(nan_mask, 0.0, np.real(dchisq_dg2i * weight))

    # accumulate per-visibility terms onto the sites they belong to in one pass;
    # keys outside 0..nsites-1 (e.g. -1 for the dummy unit gain) are not solved for
    keys1 = np.asarray(g1_keys, dtype=np.intp)
    keys2 = np.asarray(g2_keys, dtype=np.intp)
    sel1 = (keys1 >= 0) & (keys1 < nsites)
    sel2 = (keys2 >= 0) & (keys2 < nsites)

    dchisq_dgr = (np.bincount(keys1[sel1], weights=dchisq_dg1r[sel1], minlength=nsites)
                  + np.bincount(keys2[sel2], weights=dchisq_dg2r[sel2], minlength=nsites))
    dchisq_dgi = (np.bincount(keys1[sel1], weights=dchisq_dg1i[sel1], minlength=nsites)
                  + np.bincount(keys2[sel2], weights=dchisq_dg2i[sel2], minlength=nsites))

    ###################################
    # prior term chi^2 derivative
    ###################################

    # NOTE this derivative doesn't account for possible sharp change in tol at g=1
    gabs = np.abs(g[:-1])  # don't count default missing site dummy value
    gsq = gabs**2
    tolsq = ((gabs > 1) * tol0 + (gabs <= 1) * tol1)**2

    dchisqg_dgr = gr * np.log(gsq) / gsq / tolsq
    dchisqg_dgi = gi * np.log(gsq) / gsq / tolsq

    # interleave the total derivatives as (real, imag) pairs
    dchisqtot_dgpar = np.zeros(len(gpar))
    dchisqtot_dgpar[0::2] = dchisq_dgr + dchisqg_dgr
    dchisqtot_dgpar[1::2] = dchisq_dgi + dchisqg_dgi

    return dchisqtot_dgpar
536
+
537
+
538
+
0 commit comments