Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1101,12 +1101,11 @@
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
"\n",
" - `lax.cond` _will be differentiable soon_\n",
" - `lax.while_loop` __non-differentiable__*\n",
" - `lax.fori_loop` __non-differentiable__*\n",
" - `lax.cond` _differentiable_\n",
" - `lax.while_loop` __fwd-mode-differentiable__\n",
" - `lax.fori_loop` __fwd-mode-differentiable__\n",
" - `lax.scan` _differentiable_\n",
"\n",
"*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._"
"\n"
]
},
{
Expand Down Expand Up @@ -1289,9 +1288,9 @@
"\\textrm{if} & ❌ & ✔ \\\\\n",
"\\textrm{for} & ✔* & ✔\\\\\n",
"\\textrm{while} & ✔* & ✔\\\\\n",
"\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n",
"\\textrm{lax.while_loop} & ✔ & \\\\\n",
"\\textrm{lax.fori_loop} & ✔ & \\\\\n",
"\\textrm{lax.cond} & ✔ & \\\\\n",
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
"\\hline\n",
"\\end{array}\n",
Expand Down