Skip to content

Commit 1f44ca6

Browse files
committed
fix box bug in scan transpose
1 parent 64e6f93 commit 1f44ca6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
876876
jaxpr_trans = pe.move_outvars_to_back(
877877
jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out)))
878878
num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
879-
in attrs_tracked if kind is pe.ReadWrite)
879+
in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr))
880880
linear_trans = ([False] * num_ires + [False] * num_attr_carry +
881881
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
882882
[False] * num_eres)
@@ -885,6 +885,8 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
885885
transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres
886886
transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry
887887

888+
assert len(transpose_inputs) == len(linear_trans), breakpoint()
889+
888890
if not _split_transpose:
889891
outs = scan_p.bind(
890892
*transpose_inputs,

0 commit comments

Comments
 (0)