Skip to content

Commit 954f857

Browse files
committed
Add dict option for max_order to allow finer control of which cross terms to include
1 parent ae43f03 commit 954f857

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

piff/basis_interp.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,13 @@ class BasisPolynomial(BasisInterp):
466466
[default: ('u','v')]
467467
:param max_order: The maximum total order to use for cross terms between keys.
468468
[default: None, which uses the maximum value of any individual key's order]
469+
If this is an integer, it applies to all pairs, but you may also specify
470+
a dict mapping pairs of keys to an integer. E.g. {('u','v'):3,
471+
('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms
472+
between these pairs. Furthermore, any pairs for which you want to skip
473+
cross terms (max=0) may be omitted from the dict.
469474
:param solver: Which solver to use. Solvers available are "scipy", "qr", "jax",
470-
"cpp". See above for details.
475+
"cpp". See above for details. [default: 'scipy']
471476
:param logger: A logger object for logging debug info. [default: None]
472477
"""
473478
_type_name = 'BasisPolynomial'
@@ -518,12 +523,13 @@ def __init__(
518523
logger.warning("JAX not installed. Reverting to numpy/scipy.")
519524
self.solver = "scipy"
520525

521-
if self._max_order<0 or np.any(np.array(self._orders) < 0):
526+
if np.any(np.array(self._orders) < 0):
522527
# Exception if we have any requests for negative orders
523528
raise ValueError('Negative polynomial order specified')
524529

525530
self.kwargs = {
526531
'order' : order,
532+
'max_order' : max_order,
527533
'keys' : keys,
528534
'solver': solver,
529535
}
@@ -532,9 +538,38 @@ def __init__(
532538
# Start with 1d arrays giving orders in all dimensions
533539
ord_ranges = [np.arange(order+1,dtype=int) for order in self._orders]
534540
# Nifty trick to produce n-dim array holding total order
535-
#sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19
536541
sumorder = np.sum(np.meshgrid(*ord_ranges, indexing='ij'), axis=0)
537-
self._mask = sumorder <= self._max_order
542+
543+
if isinstance(self._max_order, dict):
544+
# This code is not particularly efficient. Hopefully it doesn't matter.
545+
# Basically set a maxorder for each element in sumorder based on whether it is
546+
# a) a power of a single key. Use the order for that key.
547+
# b) a cross-product of multiple keys. Use it only if it is in the max_order dict.
548+
max_orders = np.zeros_like(sumorder)
549+
550+
def get_indices(arr, pre=()):
551+
# Get the index tuples of the given multi-dimensional array.
552+
if not isinstance(arr, np.ndarray):
553+
yield pre
554+
else:
555+
for i in range(len(arr)):
556+
yield from get_indices(arr[i], pre + (i,))
557+
for index in get_indices(sumorder):
558+
for k, order in enumerate(self._orders):
559+
if index[k] > 0 and all(index[j] == 0 for j in range(len(index)) if j != k):
560+
max_orders[index] = order
561+
for keys, order in self._max_order.items():
562+
kk = [keys.index(key) for key in keys]
563+
ok = True
564+
for k in range(len(self._orders)):
565+
if index[k] > 0 and k not in kk: ok = False
566+
if index[k] == 0 and k in kk: ok = False
567+
if ok:
568+
max_orders[index] = order
569+
else:
570+
max_orders = self._max_order
571+
572+
self._mask = sumorder <= max_orders
538573

539574
def getProperties(self, star):
540575
return np.array([star.data[k] for k in self._keys], dtype=float)
@@ -561,7 +596,6 @@ def basis(self, star):
561596
p[1:] = vals[i]
562597
pows1d.append(np.cumprod(p))
563598
# Use trick to produce outer product of all these powers
564-
#pows2d = np.prod(np.ix_(*pows1d))
565599
pows2d = np.prod(np.meshgrid(*pows1d, indexing='ij'), axis=0)
566600
# Return linear array of terms making total power constraint
567601
return pows2d[self._mask]
@@ -598,4 +632,3 @@ def _finish_read(self, reader):
598632
data = reader.read_table('solution')
599633
assert data is not None
600634
self.q = data['q'][0]
601-

tests/test_pixel.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def test_basis_interp():
389389
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-1,0])
390390
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-4,-1])
391391
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=-2)
392-
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[3,3], max_order=-1)
393392

394393

395394
@timer
@@ -1728,6 +1727,15 @@ def test_color():
17281727
piff.piffify(config)
17291728
psf = piff.read(psf_file)
17301729

1730+
# Show that the basis includes cross terms:
1731+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1732+
assert psf.interp._max_order == 2
1733+
s = psf.stars[0]
1734+
np.testing.assert_allclose(
1735+
psf.interp.basis(s),
1736+
[ 1, s['color'], s['v'], s['v']*s['color'], s['v']**2,
1737+
s['u'], s['u']*s['color'], s['u']*s['v'], s['u']**2 ])
1738+
17311739
for s in psf.stars:
17321740
orig_stamp = s.image
17331741
weight = s.weight
@@ -1742,6 +1750,35 @@ def test_color():
17421750
# Anyway, I think the fit is working, just this test doesn't
17431751
# seem quite the right thing.
17441752

1753+
# Repeat without the cross-terms between color and u/v.
1754+
config['psf']['interp']['max_order'] = { ('u','v') : 2 }
1755+
piff.piffify(config)
1756+
psf = piff.read(psf_file)
1757+
1758+
# Show that the basis now doesn't include cross terms:
1759+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1760+
assert psf.interp._max_order == { ('u','v') : 2 }
1761+
s = psf.stars[0]
1762+
print(s['u'], s['v'], s['color'])
1763+
print('basis = ',psf.interp.basis(s))
1764+
np.testing.assert_allclose(
1765+
psf.interp.basis(s),
1766+
[ 1, s['color'], s['v'], s['v']**2, s['u'], s['u']*s['v'], s['u']**2 ])
1767+
1768+
# Still works just as well without the cross terms.
1769+
for s in psf.stars:
1770+
orig_stamp = s.image
1771+
weight = s.weight
1772+
offset = s.center_to_offset(s.fit.center)
1773+
image = psf.draw(x=s['x'], y=s['y'], color=s['color'],
1774+
stamp_size=32, flux=s.fit.flux, offset=offset)
1775+
resid = image - orig_stamp
1776+
chisq = np.sum(resid.array**2 * weight.array)
1777+
dof = np.sum(weight.array != 0)
1778+
print('color = ',s['color'],'chisq = ',chisq,'dof = ',dof)
1779+
assert chisq < dof * 1.5
1780+
1781+
17451782
@timer
17461783
def test_convert_func():
17471784
"""Test PixelGrid fitting with a non-trivial convert_func

0 commit comments

Comments
 (0)