Skip to content

Commit 28e802c

Browse files
authored
Fix Gotchas notebook regarding control flow differentiation. (jax-ml#2194)
1 parent f6bd0a7 commit 28e802c

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

docs/notebooks/Common_Gotchas_in_JAX.ipynb

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,12 +1101,11 @@
11011101
"\n",
11021102
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
11031103
"\n",
1104-
" - `lax.cond` _will be differentiable soon_\n",
1105-
" - `lax.while_loop` __non-differentiable__*\n",
1106-
" - `lax.fori_loop` __non-differentiable__*\n",
1104+
" - `lax.cond` _differentiable_\n",
1105+
" - `lax.while_loop` __fwd-mode-differentiable__\n",
1106+
" - `lax.fori_loop` __fwd-mode-differentiable__\n",
11071107
" - `lax.scan` _differentiable_\n",
1108-
"\n",
1109-
"*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._"
1108+
"\n"
11101109
]
11111110
},
11121111
{
@@ -1289,9 +1288,9 @@
12891288
"\\textrm{if} & ❌ & ✔ \\\\\n",
12901289
"\\textrm{for} & ✔* & ✔\\\\\n",
12911290
"\\textrm{while} & ✔* & ✔\\\\\n",
1292-
"\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n",
1293-
"\\textrm{lax.while_loop} & ✔ & \\\\\n",
1294-
"\\textrm{lax.fori_loop} & ✔ & \\\\\n",
1291+
"\\textrm{lax.cond} & ✔ & \\\\\n",
1292+
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
1293+
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
12951294
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
12961295
"\\hline\n",
12971296
"\\end{array}\n",

0 commit comments

Comments
 (0)