@@ -3299,6 +3299,59 @@ def body_fun(c, _):
3299
3299
outs_ref = body_fun (body_fun (init_vals , [x [0 ] for x in xs ])[0 ], [x [1 ] for x in xs ])[0 ]
3300
3300
self .assertAllClose (outs , outs_ref , check_dtypes = False )
3301
3301
3302
+ @parameterized .parameters (itertools .product (range (3 ), repeat = 4 ))
3303
+ @jtu .run_on_devices ("cpu" )
3304
+ def test_scan_forwarding_correctness (
3305
+ self ,
3306
+ seed ,
3307
+ num_body_consts ,
3308
+ num_const_fwds ,
3309
+ num_input_fwds ):
3310
+
3311
+ num_carry = num_const_fwds + 4
3312
+ num_xs = num_input_fwds + 2
3313
+ num_ys = num_xs + 1
3314
+
3315
+ rng = np .random .RandomState (seed )
3316
+ carry_perm = rng .permutation (num_carry )
3317
+ carry_iperm = np .argsort (carry_perm )
3318
+
3319
+ xs_perm = rng .permutation (num_xs )
3320
+ ys_perm = rng .permutation (num_ys )
3321
+ f = np .arange (num_xs )
3322
+ f = [f [i ] if idx < num_input_fwds else None for idx , i in enumerate (xs_perm )]
3323
+ f += [None ]
3324
+ in_fwd = [f [i ] for i in ys_perm ]
3325
+
3326
+ body_consts = [rng .randn (3 ) for _ in range (num_body_consts )]
3327
+ init_vals = list (rng .uniform (size = num_carry ))
3328
+
3329
+ def body_fun (c , x ):
3330
+ c = [c [i ] for i in carry_iperm ]
3331
+ carry_fwds , carry_dont_fwd = split_list (c , [num_const_fwds ])
3332
+ carry_dont_fwd = [jnp .sin (x ) * sum (jnp .sum (c ) for c in body_consts )
3333
+ for x in carry_dont_fwd ]
3334
+ new_c_perm = [* carry_fwds , * carry_dont_fwd ]
3335
+ new_c = [new_c_perm [i ] for i in carry_perm ]
3336
+
3337
+ x = [x [i ] for i in xs_perm ]
3338
+ x_fwd , x_dont_fwd = split_list (x , [num_input_fwds ])
3339
+ x_dont_fwd = [jnp .cos (x ) * sum (jnp .sum (c ) for c in body_consts )
3340
+ for x in x_dont_fwd ]
3341
+ y = [* x_fwd , * x_dont_fwd , 0 ]
3342
+ y = [y [i ] for i in ys_perm ]
3343
+
3344
+ return new_c , y
3345
+
3346
+ xs = list (rng .uniform (size = (num_xs , 2 )))
3347
+ final , outs = jax .lax .scan (body_fun , init_vals , xs )
3348
+ for f , y in zip (in_fwd , outs ):
3349
+ if f is not None :
3350
+ self .assertAllClose (y , xs [f ])
3351
+
3352
+ final_ref = body_fun (body_fun (init_vals , [x [0 ] for x in xs ])[0 ], [x [1 ] for x in xs ])[0 ]
3353
+ self .assertAllClose (final , final_ref , check_dtypes = False )
3354
+
3302
3355
def test_scan_diff_of_print (self ):
3303
3356
# ref: https://github.com/jax-ml/jax/issues/28738
3304
3357
def f (c , _ ):
@@ -3311,6 +3364,24 @@ def g(x):
3311
3364
eqn_jaxpr = jaxpr .eqns [0 ].params ["jaxpr" ]
3312
3365
self .assertIn ("debug_callback" , [e .primitive .name for e in eqn_jaxpr .eqns ])
3313
3366
3367
+ def test_scan_input_to_output_forwarding (self ):
3368
+ def f (c , x ):
3369
+ return c + 1 , x
3370
+ def g (x ):
3371
+ return jax .lax .scan (f , 0 , x )
3372
+ jaxpr = jax .make_jaxpr (g )(jnp .arange (3. ))
3373
+ self .assertLen (jaxpr .eqns [0 ].params ["jaxpr" ].jaxpr .outvars , 1 )
3374
+
3375
+ def test_scan_only_forwarding (self ):
3376
+ def f (_ , x ):
3377
+ return None , x
3378
+ def g (x ):
3379
+ return jax .lax .scan (f , None , x )
3380
+ x = jnp .arange (3 )
3381
+ jaxpr = jax .make_jaxpr (g )(x )
3382
+ self .assertLen (jaxpr .eqns , 0 )
3383
+ self .assertArraysEqual (g (x )[1 ], x )
3384
+
3314
3385
3315
3386
if __name__ == '__main__' :
3316
3387
absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments