|
1101 | 1101 | "\n",
|
1102 | 1102 | "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
|
1103 | 1103 | "\n",
|
1104 |
| - " - `lax.cond` _will be differentiable soon_\n", |
1105 |
| - " - `lax.while_loop` __non-differentiable__*\n", |
1106 |
| - " - `lax.fori_loop` __non-differentiable__*\n", |
| 1104 | + " - `lax.cond` _differentiable_\n", |
| 1105 | + " - `lax.while_loop` __fwd-mode-differentiable__\n", |
| 1106 | + " - `lax.fori_loop` __fwd-mode-differentiable__\n", |
1107 | 1107 | " - `lax.scan` _differentiable_\n",
|
1108 |
| - "\n", |
1109 |
| - "*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._" |
| 1108 | + "\n" |
1110 | 1109 | ]
|
1111 | 1110 | },
|
1112 | 1111 | {
|
|
1289 | 1288 | "\\textrm{if} & ❌ & ✔ \\\\\n",
|
1290 | 1289 | "\\textrm{for} & ✔* & ✔\\\\\n",
|
1291 | 1290 | "\\textrm{while} & ✔* & ✔\\\\\n",
|
1292 |
| - "\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n", |
1293 |
| - "\\textrm{lax.while_loop} & ✔ & ❌\\\\\n", |
1294 |
| - "\\textrm{lax.fori_loop} & ✔ & ❌\\\\\n", |
| 1291 | + "\\textrm{lax.cond} & ✔ & ✔\\\\\n", |
| 1292 | + "\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n", |
| 1293 | + "\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n", |
1295 | 1294 | "\\textrm{lax.scan} & ✔ & ✔\\\\\n",
|
1296 | 1295 | "\\hline\n",
|
1297 | 1296 | "\\end{array}\n",
|
|
0 commit comments