@@ -466,8 +466,13 @@ class BasisPolynomial(BasisInterp):
466466 [default: ('u','v')]
467467 :param max_order: The maximum total order to use for cross terms between keys.
468468 [default: None, which uses the maximum value of any individual key's order]
469+ If this is an integer, it applies to all pairs, but you may also specify
470+ a dict mapping pairs of keys to an integer. E.g. {('u','v'):3,
471+ ('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms
472+ between these pairs. Furthermore, any pairs for which you want to skip
473+ cross terms (max=0) may be omitted from the dict.
469474 :param solver: Which solver to use. Solvers available are "scipy", "qr", "jax",
470- "cpp". See above for details.
475+ "cpp". See above for details. [default: 'scipy']
471476 :param logger: A logger object for logging debug info. [default: None]
472477 """
473478 _type_name = 'BasisPolynomial'
@@ -518,12 +523,13 @@ def __init__(
518523 logger .warning ("JAX not installed. Reverting to numpy/scipy." )
519524 self .solver = "scipy"
520525
521- if self . _max_order < 0 or np .any (np .array (self ._orders ) < 0 ):
526+ if np .any (np .array (self ._orders ) < 0 ):
522527 # Exception if we have any requests for negative orders
523528 raise ValueError ('Negative polynomial order specified' )
524529
525530 self .kwargs = {
526531 'order' : order ,
532+ 'max_order' : max_order ,
527533 'keys' : keys ,
528534 'solver' : solver ,
529535 }
@@ -532,9 +538,38 @@ def __init__(
532538 # Start with 1d arrays giving orders in all dimensions
533539 ord_ranges = [np .arange (order + 1 ,dtype = int ) for order in self ._orders ]
534540 # Nifty trick to produce n-dim array holding total order
535- #sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19
536541 sumorder = np .sum (np .meshgrid (* ord_ranges , indexing = 'ij' ), axis = 0 )
537- self ._mask = sumorder <= self ._max_order
542+
543+ if isinstance (self ._max_order , dict ):
544+ # This code is not particularly efficient. Hopefully it doesn't matter.
545+ # Basically set a maxorder for each element in sumorder based on whether it is
546+ # a) a power of a single key. Use the order for that key.
547+ # b) a cross-product of multiple keys. Use it only if it is in the max_order dict.
548+ max_orders = np .zeros_like (sumorder )
549+
550+ def get_indices (arr , pre = ()):
551+ # Get the index tuples of the given multi-dimensional array.
552+ if not isinstance (arr , np .ndarray ):
553+ yield pre
554+ else :
555+ for i in range (len (arr )):
556+ yield from get_indices (arr [i ], pre + (i ,))
557+ for index in get_indices (sumorder ):
558+ for k , order in enumerate (self ._orders ):
559+ if index [k ] > 0 and all (index [j ] == 0 for j in range (len (index )) if j != k ):
560+ max_orders [index ] = order
561+ for keys , order in self ._max_order .items ():
562+ kk = [keys .index (key ) for key in keys ]
563+ ok = True
564+ for k in range (len (self ._orders )):
565+ if index [k ] > 0 and k not in kk : ok = False
566+ if index [k ] == 0 and k in kk : ok = False
567+ if ok :
568+ max_orders [index ] = order
569+ else :
570+ max_orders = self ._max_order
571+
572+ self ._mask = sumorder <= max_orders
538573
539574 def getProperties (self , star ):
540575 return np .array ([star .data [k ] for k in self ._keys ], dtype = float )
@@ -561,7 +596,6 @@ def basis(self, star):
561596 p [1 :] = vals [i ]
562597 pows1d .append (np .cumprod (p ))
563598 # Use trick to produce outer product of all these powers
564- #pows2d = np.prod(np.ix_(*pows1d))
565599 pows2d = np .prod (np .meshgrid (* pows1d , indexing = 'ij' ), axis = 0 )
566600 # Return linear array of terms making total power constraint
567601 return pows2d [self ._mask ]
@@ -598,4 +632,3 @@ def _finish_read(self, reader):
598632 data = reader .read_table ('solution' )
599633 assert data is not None
600634 self .q = data ['q' ][0 ]
601-
0 commit comments