|
16 | 16 |
|
17 | 17 | from dataclasses import dataclass
|
18 | 18 | import itertools as it
|
| 19 | +from functools import partial |
19 | 20 |
|
20 | 21 | from absl.testing import absltest
|
21 | 22 | from absl.testing import parameterized
|
@@ -1325,6 +1326,50 @@ def f(lst1, lst2):
|
1325 | 1326 | with self.assertRaisesRegex(ValueError, "a List instance can't be passed"):
|
1326 | 1327 | f(b, b)
|
1327 | 1328 |
|
| 1329 | + def test_scan_transpose_regression(self): |
| 1330 | + @partial(jax.custom_vjp, nondiff_argnums=(0,)) |
| 1331 | + def g2(box, x, i): |
| 1332 | + return x |
| 1333 | + |
| 1334 | + def g2_fwd(box, x, i): |
| 1335 | + return x, i |
| 1336 | + |
| 1337 | + def g2_bwd(box, i, g): |
| 1338 | + boxg = g.reshape((1, *g.shape)) |
| 1339 | + box.set(jax.lax.dynamic_update_slice(box.get(), boxg, (i, 0, 0))) |
| 1340 | + return g, None |
| 1341 | + |
| 1342 | + g2.defvjp(g2_fwd, g2_bwd) |
| 1343 | + |
| 1344 | + def block_box(x, w, box, index): |
| 1345 | + x = x @ w |
| 1346 | + x = g2(box, x, index) |
| 1347 | + return x |
| 1348 | + |
| 1349 | + def fwd_box(ws, x, box, stack): |
| 1350 | + def internal_scan(carry, w): |
| 1351 | + x, index = carry |
| 1352 | + x = block_box(x, w, box, index) |
| 1353 | + return (x, index+1), None |
| 1354 | + |
| 1355 | + (x, _), _ = jax.lax.scan(internal_scan, (x, 0), ws) |
| 1356 | + return x |
| 1357 | + |
| 1358 | + D = 16 |
| 1359 | + L = 4 |
| 1360 | + B = 2 |
| 1361 | + x = jnp.ones((B, D)) |
| 1362 | + y = jnp.ones((B, D)) |
| 1363 | + ws = jnp.ones((L, D, D)) |
| 1364 | + grad_box = Box(jnp.zeros((L, B, D))) |
| 1365 | + grad_list = {'block_out': List()} |
| 1366 | + |
| 1367 | + def jax_loss(ws, inputs, targets, box_or_list): |
| 1368 | + preds = fwd_box(ws, inputs, box_or_list, stack=stack) |
| 1369 | + return jnp.square(preds - targets).mean() |
| 1370 | + |
| 1371 | + attr_grads = jax.grad(jax_loss)(ws, x, y, grad_box) # don't crash |
| 1372 | + |
1328 | 1373 |
|
1329 | 1374 | if __name__ == '__main__':
|
1330 | 1375 | absltest.main(testLoader=jtu.JaxTestLoader())
|
0 commit comments