@@ -136,10 +136,10 @@ def process_call(self, call_primitive, f, tracers, params):
136
136
lifted_jaxpr = convert_constvars_jaxpr (jaxpr )
137
137
out_tracers = [JaxprTracer (self , PartialVal ((out_pv , out_pv_const )), None )
138
138
for out_pv , out_pv_const in zip (out_pvs , out_pv_consts )]
139
+ new_params = dict (params , call_jaxpr = lifted_jaxpr )
139
140
# The `jaxpr` already contains the env_vars at start of invars
140
141
eqn = new_eqn_recipe (tuple (it .chain (const_tracers , env_tracers , tracers )),
141
- out_tracers , call_primitive , params ,
142
- subjaxpr = lifted_jaxpr )
142
+ out_tracers , call_primitive , new_params )
143
143
for t in out_tracers :
144
144
t .recipe = eqn
145
145
return out_tracers
@@ -162,10 +162,10 @@ def process_map(self, map_primitive, f, tracers, params):
162
162
new_params = dict (params ,
163
163
mapped_invars = tuple ([True ] * len (const_tracers ) +
164
164
[False ] * len (env_tracers ) +
165
- [True ] * len (tracers )))
165
+ [True ] * len (tracers )),
166
+ call_jaxpr = lifted_jaxpr )
166
167
eqn = new_eqn_recipe (tuple (it .chain (const_tracers , env_tracers , tracers )),
167
- out_tracers , map_primitive , new_params ,
168
- subjaxpr = lifted_jaxpr )
168
+ out_tracers , map_primitive , new_params )
169
169
for t in out_tracers :
170
170
t .recipe = eqn
171
171
return out_tracers
@@ -187,10 +187,10 @@ def todo(x):
187
187
lifted_jaxpr = convert_constvars_jaxpr (jaxpr )
188
188
out_tracers = [JaxprTracer (trace , PartialVal ((out_pv , out_pv_const )), None )
189
189
for out_pv , out_pv_const in zip (out_pvs , out_pv_consts )]
190
+ new_params = dict (params , call_jaxpr = lifted_jaxpr )
190
191
# The `jaxpr` already contains the env_vars at start of invars
191
192
eqn = new_eqn_recipe (tuple (it .chain (const_tracers , env_tracers )),
192
- out_tracers , call_primitive , params ,
193
- subjaxpr = lifted_jaxpr )
193
+ out_tracers , call_primitive , new_params )
194
194
for t in out_tracers :
195
195
t .recipe = eqn
196
196
return out_tracers
@@ -215,11 +215,11 @@ def todo(x):
215
215
for out_pv , out_pv_const in zip (out_pvs , out_pv_consts )]
216
216
new_params = dict (params ,
217
217
mapped_invars = tuple ([True ] * len (const_tracers ) +
218
- [False ] * len (env )))
218
+ [False ] * len (env )),
219
+ call_jaxpr = lifted_jaxpr )
219
220
env_tracers = map (trace .full_raise , env )
220
221
eqn = new_eqn_recipe (it .chain (const_tracers , env_tracers ),
221
- out_tracers , map_primitive , new_params ,
222
- subjaxpr = lifted_jaxpr )
222
+ out_tracers , map_primitive , new_params )
223
223
for t in out_tracers :
224
224
t .recipe = eqn
225
225
return out_tracers
@@ -383,38 +383,31 @@ def instantiate_const_at(trace, instantiate, tracer):
383
383
ConstVar = namedtuple ('ConstVar' , ['val' ])
384
384
LambdaBinding = namedtuple ('LambdaBinding' , [])
385
385
JaxprEqnRecipe = namedtuple ('JaxprEqnRecipe' ,
386
- ['eqn_id' , 'invars' , 'outvars' , 'primitive' ,
387
- 'bound_subjaxpr' , 'params' ])
386
+ ['eqn_id' , 'invars' , 'outvars' , 'primitive' , 'params' ])
388
387
389
-
390
- def new_eqn_recipe (invars , outvars , primitive , params ,
391
- subjaxpr = None ):
388
+ def new_eqn_recipe (invars , outvars , primitive , params ):
392
389
"""Constructs a new JaxEqnRecipe.
393
390
394
391
Params:
395
392
invars: the tracers for the primitive inputs.
396
393
outvars: the tracers for the primitive outputs.
397
394
primitive: the primitive.
398
395
params: the primitive params
399
- subjaxpr: (optional) a sub-Jaxpr, used only for `xla_call` or `xla_pmap`.
400
- If present, then `subjaxpr.invars` correspond to `invars.
401
396
"""
402
- if subjaxpr is not None :
403
- assert len (subjaxpr .constvars ) == 0
404
- assert len (subjaxpr .invars ) == len (tuple (invars ))
405
- bound_subjaxpr = subjaxpr
406
- else :
407
- bound_subjaxpr = None
408
-
397
+ if primitive .call_primitive :
398
+ # TODO(necula): move these checks to core.check_jaxpr, and call it
399
+ # in more places.
400
+ assert "call_jaxpr" in params
409
401
return JaxprEqnRecipe (object (), tuple (invars ), map (ref , outvars ), primitive ,
410
- bound_subjaxpr , params )
402
+ params )
403
+
411
404
412
405
def recipe_to_eqn (unused_var , getvar , recipe ):
413
- _ , in_tracers , out_tracer_refs , primitive , bound_subjaxpr , params = recipe
406
+ _ , in_tracers , out_tracer_refs , primitive , params = recipe
414
407
out_tracers = [t_ref () for t_ref in out_tracer_refs ]
415
408
invars = [getvar (t ) for t in in_tracers ]
416
409
outvars = [unused_var () if t is None else getvar (t ) for t in out_tracers ]
417
- return new_jaxpr_eqn (invars , outvars , primitive , bound_subjaxpr , params )
410
+ return new_jaxpr_eqn (invars , outvars , primitive , params )
418
411
419
412
def tracers_to_jaxpr (in_tracers , out_tracers ):
420
413
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -520,6 +513,7 @@ def _split_aval(unknown, aval):
520
513
521
514
522
515
remat_call_p = core .Primitive ('remat_call' )
516
+ remat_call_p .call_primitive = True
523
517
remat_call = partial (core .call_bind , remat_call_p )
524
518
remat_call_p .def_custom_bind (remat_call )
525
519
remat_call_p .def_impl (core .call_impl )
@@ -593,10 +587,9 @@ def _remat_partial_eval(trace, f, tracers, params):
593
587
const_tracers = map (trace .new_instantiated_const , consts )
594
588
lifted_jaxpr = convert_constvars_jaxpr (typed_jaxpr .jaxpr )
595
589
out_tracers = [JaxprTracer (trace , out_pval , None ) for out_pval in out_pvals ]
590
+ new_params = dict (params , call_jaxpr = lifted_jaxpr )
596
591
eqn = new_eqn_recipe (tuple (it .chain (const_tracers , instantiated_tracers )),
597
- out_tracers , remat_call_p ,
598
- params ,
599
- subjaxpr = lifted_jaxpr )
592
+ out_tracers , remat_call_p , new_params )
600
593
for t in out_tracers : t .recipe = eqn
601
594
return out_tracers
602
595
call_partial_eval_rules [remat_call_p ] = _remat_partial_eval
0 commit comments