Skip to content

Commit 2ea71b3

Browse files
committed
fix box bug in scan transpose
1 parent 64e6f93 commit 2ea71b3

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -876,14 +876,15 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
876876
jaxpr_trans = pe.move_outvars_to_back(
877877
jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out)))
878878
num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
879-
in attrs_tracked if kind is pe.ReadWrite)
879+
in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr))
880880
linear_trans = ([False] * num_ires + [False] * num_attr_carry +
881881
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
882882
[False] * num_eres)
883883
in_state = _get_states(attrs_tracked)
884884

885885
transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres
886886
transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry
887+
assert len(transpose_inputs) == len(linear_trans), breakpoint()
887888

888889
if not _split_transpose:
889890
outs = scan_p.bind(

tests/attrs_test.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
from dataclasses import dataclass
1818
import itertools as it
19+
from functools import partial
1920

2021
from absl.testing import absltest
2122
from absl.testing import parameterized
@@ -1325,6 +1326,50 @@ def f(lst1, lst2):
13251326
with self.assertRaisesRegex(ValueError, "a List instance can't be passed"):
13261327
f(b, b)
13271328

1329+
def test_scan_transpose_regression(self):
  """Differentiating a scan whose custom-VJP backward rule writes to a Box.

  The loss runs a `jax.lax.scan` over a stack of weight matrices. Each step
  applies `g2`, a `jax.custom_vjp` function that is the identity on `x` in
  the forward pass and, in its backward rule, writes the incoming cotangent
  into the closed-over Box at the current step index. The test only checks
  that `jax.grad` of this loss completes without raising.
  """
  @partial(jax.custom_vjp, nondiff_argnums=(0,))
  def g2(box, x, i):
    # Forward value is just `x`; `box` is only touched in `g2_bwd`.
    return x

  def g2_fwd(box, x, i):
    # The step index is the only residual the backward rule needs.
    return x, i

  def g2_bwd(box, i, g):
    # Add a leading axis of size 1 so the cotangent can be written as one
    # slice of the box contents at offset `i` along the first axis.
    boxg = g.reshape((1, *g.shape))
    box.set(jax.lax.dynamic_update_slice(box.get(), boxg, (i, 0, 0)))
    # One cotangent per differentiable argument: `g` for `x`, None for `i`.
    return g, None

  g2.defvjp(g2_fwd, g2_bwd)

  def block_box(x, w, box, index):
    x = x @ w
    x = g2(box, x, index)
    return x

  # NOTE: the original signature had an extra `stack` parameter that was
  # never read, and the only call site passed `stack=stack` with no `stack`
  # bound in this test. The parameter is dropped so the call cannot fail on
  # name lookup before reaching the scan.
  def fwd_box(ws, x, box):
    def internal_scan(carry, w):
      x, index = carry
      x = block_box(x, w, box, index)
      return (x, index + 1), None

    (x, _), _ = jax.lax.scan(internal_scan, (x, 0), ws)
    return x

  D = 16
  L = 4
  B = 2
  x = jnp.ones((B, D))
  y = jnp.ones((B, D))
  ws = jnp.ones((L, D, D))
  # One (B, D) slot per scanned weight matrix; `g2_bwd` writes slot `i`.
  grad_box = Box(jnp.zeros((L, B, D)))

  def jax_loss(ws, inputs, targets, box_or_list):
    preds = fwd_box(ws, inputs, box_or_list)
    return jnp.square(preds - targets).mean()

  jax.grad(jax_loss)(ws, x, y, grad_box)  # don't crash
1372+
13281373

13291374
if __name__ == '__main__':
13301375
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)