Skip to content

Commit 0542fe4

Browse files
committed
Add dict option for max_order to allow finer control of which cross terms to include
1 parent aa46f3a commit 0542fe4

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

piff/basis_interp.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,13 @@ class BasisPolynomial(BasisInterp):
463463
[default: ('u','v')]
464464
:param max_order: The maximum total order to use for cross terms between keys.
465465
[default: None, which uses the maximum value of any individual key's order]
466+
If this is an integer, it applies to all pairs, but you may also specify
467+
a dict mapping pairs of keys to an integer. E.g. {('u','v'):3,
468+
('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms
469+
between these pairs. Furthermore, any pairs for which you want to skip
470+
cross terms (max=0) may be omitted from the dict.
466471
:param solver: Which solver to use. Solvers available are "scipy", "qr", "jax",
467-
"cpp". See above for details.
472+
"cpp". See above for details. [default: 'scipy']
468473
:param logger: A logger object for logging debug info. [default: None]
469474
"""
470475
_type_name = 'BasisPolynomial'
@@ -515,12 +520,13 @@ def __init__(
515520
logger.warning("JAX not installed. Reverting to numpy/scipy.")
516521
self.solver = "scipy"
517522

518-
if self._max_order<0 or np.any(np.array(self._orders) < 0):
523+
if np.any(np.array(self._orders) < 0):
519524
# Exception if we have any requests for negative orders
520525
raise ValueError('Negative polynomial order specified')
521526

522527
self.kwargs = {
523528
'order' : order,
529+
'max_order' : max_order,
524530
'keys' : keys,
525531
'solver': solver,
526532
}
@@ -529,9 +535,38 @@ def __init__(
529535
# Start with 1d arrays giving orders in all dimensions
530536
ord_ranges = [np.arange(order+1,dtype=int) for order in self._orders]
531537
# Nifty trick to produce n-dim array holding total order
532-
#sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19
533538
sumorder = np.sum(np.meshgrid(*ord_ranges, indexing='ij'), axis=0)
534-
self._mask = sumorder <= self._max_order
539+
540+
if isinstance(self._max_order, dict):
541+
# This code is not particularly efficient. Hopefully it doesn't matter.
542+
# Basically set a maxorder for each element in sumorder based on whether it is
543+
# a) a power of a single key. Use the order for that key.
544+
# b) a cross-product of multiple keys. Use it only if it is in the max_order dict.
545+
max_orders = np.zeros_like(sumorder)
546+
547+
def get_indices(arr, pre=()):
548+
# Get the index tuples of the given multi-dimensional array.
549+
if not isinstance(arr, np.ndarray):
550+
yield pre
551+
else:
552+
for i in range(len(arr)):
553+
yield from get_indices(arr[i], pre + (i,))
554+
for index in get_indices(sumorder):
555+
for k, order in enumerate(self._orders):
556+
if index[k] > 0 and all(index[j] == 0 for j in range(len(index)) if j != k):
557+
max_orders[index] = order
558+
for keys, order in self._max_order.items():
559+
kk = [keys.index(key) for key in keys]
560+
ok = True
561+
for k in range(len(self._orders)):
562+
if index[k] > 0 and k not in kk: ok = False
563+
if index[k] == 0 and k in kk: ok = False
564+
if ok:
565+
max_orders[index] = order
566+
else:
567+
max_orders = self._max_order
568+
569+
self._mask = sumorder <= max_orders
535570

536571
def getProperties(self, star):
537572
return np.array([star.data[k] for k in self._keys], dtype=float)
@@ -558,7 +593,6 @@ def basis(self, star):
558593
p[1:] = vals[i]
559594
pows1d.append(np.cumprod(p))
560595
# Use trick to produce outer product of all these powers
561-
#pows2d = np.prod(np.ix_(*pows1d))
562596
pows2d = np.prod(np.meshgrid(*pows1d, indexing='ij'), axis=0)
563597
# Return linear array of terms making total power constraint
564598
return pows2d[self._mask]
@@ -595,4 +629,3 @@ def _finish_read(self, reader):
595629
data = reader.read_table('solution')
596630
assert data is not None
597631
self.q = data['q'][0]
598-

tests/test_pixel.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,6 +1728,15 @@ def test_color():
17281728
piff.piffify(config)
17291729
psf = piff.read(psf_file)
17301730

1731+
# Show that the basis includes cross terms:
1732+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1733+
assert psf.interp._max_order == 2
1734+
s = psf.stars[0]
1735+
np.testing.assert_allclose(
1736+
psf.interp.basis(s),
1737+
[ 1, s['color'], s['v'], s['v']*s['color'], s['v']**2,
1738+
s['u'], s['u']*s['color'], s['u']*s['v'], s['u']**2 ])
1739+
17311740
for s in psf.stars:
17321741
orig_stamp = s.image
17331742
weight = s.weight
@@ -1742,6 +1751,35 @@ def test_color():
17421751
# Anyway, I think the fit is working, just this test doesn't
17431752
# seem quite the right thing.
17441753

1754+
# Repeat without the cross-terms between color and u/v.
1755+
config['psf']['interp']['max_order'] = { ('u','v') : 2 }
1756+
piff.piffify(config)
1757+
psf = piff.read(psf_file)
1758+
1759+
# Show that the basis now doesn't include cross terms:
1760+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1761+
assert psf.interp._max_order == { ('u','v') : 2 }
1762+
s = psf.stars[0]
1763+
print(s['u'], s['v'], s['color'])
1764+
print('basis = ',psf.interp.basis(s))
1765+
np.testing.assert_allclose(
1766+
psf.interp.basis(s),
1767+
[ 1, s['color'], s['v'], s['v']**2, s['u'], s['u']*s['v'], s['u']**2 ])
1768+
1769+
# Still works just as well without the cross terms.
1770+
for s in psf.stars:
1771+
orig_stamp = s.image
1772+
weight = s.weight
1773+
offset = s.center_to_offset(s.fit.center)
1774+
image = psf.draw(x=s['x'], y=s['y'], color=s['color'],
1775+
stamp_size=32, flux=s.fit.flux, offset=offset)
1776+
resid = image - orig_stamp
1777+
chisq = np.sum(resid.array**2 * weight.array)
1778+
dof = np.sum(weight.array != 0)
1779+
print('color = ',s['color'],'chisq = ',chisq,'dof = ',dof)
1780+
assert chisq < dof * 1.5
1781+
1782+
17451783
@timer
17461784
def test_convert_func():
17471785
"""Test PixelGrid fitting with a non-trivial convert_func

0 commit comments

Comments
 (0)