@@ -598,21 +598,32 @@ def addupdate_transpose(cts_in, ref, x, *idx, **params):
598598
599599## get/swap/addupdate partial_eval_custom rules
600600
601- def _state_partial_eval_custom (prim , saveable , unks_in , inst_in , eqn ):
602- if any (unks_in ):
603- res = [v for v , inst in zip (eqn .invars , inst_in ) if not inst ]
604- return None , eqn , [True ] * len (eqn .outvars ), [True ] * len (eqn .outvars ), res
605- elif saveable (prim , * [var .aval for var in eqn .invars ], ** eqn .params ):
606- return eqn , None , [False ] * len (eqn .outvars ), [False ] * len (eqn .outvars ), []
607- res = [v for v , inst in zip (eqn .invars , inst_in ) if not inst ]
608- return eqn , eqn , [False ] * len (eqn .outvars ), [True ] * len (eqn .outvars ), res
609-
610- pe .partial_eval_jaxpr_custom_rules [get_p ] = partial (_state_partial_eval_custom ,
611- get_p )
612- pe .partial_eval_jaxpr_custom_rules [swap_p ] = partial (_state_partial_eval_custom ,
613- swap_p )
614- pe .partial_eval_jaxpr_custom_rules [addupdate_p ] = partial (
615- _state_partial_eval_custom , addupdate_p )
601+ def _array_ref_partial_eval_custom (saveable , unks_in , inst_in , eqn ):
602+ del saveable # ignored, always full remat array_ref on known input
603+ unk , = unks_in
604+ inst , = inst_in
605+ invar , = eqn .invars
606+ res = [invar ] if not inst else []
607+ if unk :
608+ return None , eqn , [True ], [True ], res # tangent operation
609+ else :
610+ return eqn , eqn , [False ], [True ], res # full remat
611+ pe .partial_eval_jaxpr_custom_rules [core .array_ref_p ] = _array_ref_partial_eval_custom
612+
613+ def _state_partial_eval_custom (saveable , unks_in , inst_in , eqn ):
614+ del saveable # ignored, always full remat state ops on known inputs
615+ ref_unk , * _ = unks_in
616+ ref_inst , * inst_in = inst_in
617+ _ , * val_vars = eqn .invars
618+ assert ref_inst
619+ res = [v for v , inst in zip (val_vars , inst_in ) if not inst ]
620+ if ref_unk :
621+ return None , eqn , [True ], [True ], res # tangent operation
622+ else :
623+ return eqn , eqn , [False ], [True ], res # full remat
624+ pe .partial_eval_jaxpr_custom_rules [get_p ] = _state_partial_eval_custom
625+ pe .partial_eval_jaxpr_custom_rules [swap_p ] = _state_partial_eval_custom
626+ pe .partial_eval_jaxpr_custom_rules [addupdate_p ] = _state_partial_eval_custom
616627
617628## get/swap/addupdate batching rules
618629
0 commit comments