@@ -463,8 +463,13 @@ class BasisPolynomial(BasisInterp):
463463 [default: ('u','v')]
464464 :param max_order: The maximum total order to use for cross terms between keys.
465465 [default: None, which uses the maximum value of any individual key's order]
466+ If this is an integer, it applies to all pairs, but you may also specify
467+ a dict mapping pairs of keys to an integer. E.g. {('u','v'):3,
468+ ('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms
469+ between these pairs. Furthermore, any pairs for which you want to skip
470+ cross terms (max=0) may be omitted from the dict.
466471 :param solver: Which solver to use. Solvers available are "scipy", "qr", "jax",
467- "cpp". See above for details.
472+ "cpp". See above for details. [default: 'scipy']
468473 :param logger: A logger object for logging debug info. [default: None]
469474 """
470475 _type_name = 'BasisPolynomial'
@@ -515,12 +520,13 @@ def __init__(
515520 logger .warning ("JAX not installed. Reverting to numpy/scipy." )
516521 self .solver = "scipy"
517522
518- if self . _max_order < 0 or np .any (np .array (self ._orders ) < 0 ):
523+ if np .any (np .array (self ._orders ) < 0 ):
519524 # Exception if we have any requests for negative orders
520525 raise ValueError ('Negative polynomial order specified' )
521526
522527 self .kwargs = {
523528 'order' : order ,
529+ 'max_order' : max_order ,
524530 'keys' : keys ,
525531 'solver' : solver ,
526532 }
@@ -529,9 +535,38 @@ def __init__(
529535 # Start with 1d arrays giving orders in all dimensions
530536 ord_ranges = [np .arange (order + 1 ,dtype = int ) for order in self ._orders ]
531537 # Nifty trick to produce n-dim array holding total order
532- #sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19
533538 sumorder = np .sum (np .meshgrid (* ord_ranges , indexing = 'ij' ), axis = 0 )
534- self ._mask = sumorder <= self ._max_order
539+
540+ if isinstance (self ._max_order , dict ):
541+ # This code is not particularly efficient. Hopefully it doesn't matter.
542+ # Basically set a maxorder for each element in sumorder based on whether it is
543+ # a) a power of a single key. Use the order for that key.
544+ # b) a cross-product of multiple keys. Use it only if it is in the max_order dict.
545+ max_orders = np .zeros_like (sumorder )
546+
547+ def get_indices (arr , pre = ()):
548+ # Get the index tuples of the given multi-dimensional array.
549+ if not isinstance (arr , np .ndarray ):
550+ yield pre
551+ else :
552+ for i in range (len (arr )):
553+ yield from get_indices (arr [i ], pre + (i ,))
554+ for index in get_indices (sumorder ):
555+ for k , order in enumerate (self ._orders ):
556+ if index [k ] > 0 and all (index [j ] == 0 for j in range (len (index )) if j != k ):
557+ max_orders [index ] = order
558+ for keys , order in self ._max_order .items ():
559+ kk = [keys .index (key ) for key in keys ]
560+ ok = True
561+ for k in range (len (self ._orders )):
562+ if index [k ] > 0 and k not in kk : ok = False
563+ if index [k ] == 0 and k in kk : ok = False
564+ if ok :
565+ max_orders [index ] = order
566+ else :
567+ max_orders = self ._max_order
568+
569+ self ._mask = sumorder <= max_orders
535570
536571 def getProperties (self , star ):
537572 return np .array ([star .data [k ] for k in self ._keys ], dtype = float )
@@ -558,7 +593,6 @@ def basis(self, star):
558593 p [1 :] = vals [i ]
559594 pows1d .append (np .cumprod (p ))
560595 # Use trick to produce outer product of all these powers
561- #pows2d = np.prod(np.ix_(*pows1d))
562596 pows2d = np .prod (np .meshgrid (* pows1d , indexing = 'ij' ), axis = 0 )
563597 # Return linear array of terms making total power constraint
564598 return pows2d [self ._mask ]
@@ -595,4 +629,3 @@ def _finish_read(self, reader):
595629 data = reader .read_table ('solution' )
596630 assert data is not None
597631 self .q = data ['q' ][0 ]
598-
0 commit comments