From 1de8195b6b64a1ec781048a6b5d8df5218c46a4e Mon Sep 17 00:00:00 2001 From: Lukas Heinrich Date: Wed, 15 Jul 2020 14:21:04 +0200 Subject: [PATCH] go --- Tutorial.ipynb | 248 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 246 insertions(+), 2 deletions(-) diff --git a/Tutorial.ipynb b/Tutorial.ipynb index ba9c02e..a09e85d 100644 --- a/Tutorial.ipynb +++ b/Tutorial.ipynb @@ -1318,7 +1318,45 @@ " repetitive calculations of the same values\n", "* (we will see that) arbitrary control flows are handles naturally\n", "* it's something that is easy for a comoputer do and for a progarmmer to imlpement\n", - "\n" + "\n", + "\n", + "\n", + "Some notes on pros and cons:\n", + "\n", + "**In the forward mode**:\n", + "\n", + "the signature of each operation basically extends \n", + " ```c++\n", + " float f(float x,float y,float z)\n", + " ```\n", + " to\n", + " ```c++\n", + " pair f(float x,float dx,float y,float dy, float z,float dz)\n", + " ```\n", + " * if you use composite types (\"dual numbers\") that hold both x,dx you can basically \n", + " keep the signature unchanged\n", + " ```c++\n", + " f(dual x, dual y, dual z)\n", + " ```\n", + " * together with operator overloading on these dual types e.g. `dual * dual` you can \n", + " essentially keep the source code unchanged\n", + " ```c++\n", + " float f(float x, float y): return x*y\n", + " ``` \n", + " ->\n", + " ```c++\n", + " dual f(dual x,dual y): return x*y\n", + " ```\n", + " \n", + "* That means it's very easy to implement. 
And memory efficient, no superfluous values are kept when they run out of scope.\n", + "* But forward mode is better suited for vector-valued functions of few parameters\n", + "\n", + "\n", + "**In the reverse mode**:\n", + "\n", + "* very efficient, but we need to keep track of order (need a \"tape\" of sorts)\n", + "* since we need to access all intermediate variables, we can run into memory bounds\n", + "* the procedure is a bit more complex than fwd: 1) run fwd, 2) zero grads, 3) run bwd\n" ] }, { @@ -1334,7 +1372,213 @@ "source": [ "Yes there are! And a lot of them in many languages. On the othe rhand, try finding CAS systems in each of those \n", "\n", - "\"A" + "\"A\n", + "\n", + "This is PyHEP, so let's focus on Python. Here, basically what you think of as \"Machine Learning frameworks\" are, at their core, autodiff libraries\n", + "\n", + "* TensorFlow\n", + "* PyTorch\n", + "* JAX" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's focus on JAX" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "metadata": {}, + "outputs": [], + "source": [ + "def f(x):\n", + " return x**2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`jax.numpy` is almost a drop-in replacement for `numpy`. I do `import jax.numpy as jnp` but if you're daring you could do `import jax.numpy as np`" + ] + }, + { + "cell_type": "code", + "execution_count": 191, + "metadata": {}, + "outputs": [], + "source": [ + "x = jnp.array([1,2,3])\n", + "y = jnp.array([2,3,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[3 5 7]\n", + "[ 2 6 12]\n", + "[0. 
0.6931472 1.0986123]\n", + "[ 7.389056 20.085537 54.598152]\n" + ] + } + ], + "source": [ + "print(x+y)\n", + "print(x*y)\n", + "print(jnp.log(x))\n", + "print(jnp.exp(y))" + ] + }, + { + "cell_type": "code", + "execution_count": 219, + "metadata": {}, + "outputs": [], + "source": [ + "def f(x):\n", + " return x**3" + ] + }, + { + "cell_type": "code", + "execution_count": 220, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "64.0\n", + "48.0\n", + "24.0\n", + "6.0\n", + "0.0\n" + ] + } + ], + "source": [ + "print(f(4.0))\n", + "print(jax.grad(f)(4.0)) #boom!\n", + "print(jax.grad(jax.grad(f))(4.0)) #boom!\n", + "print(jax.grad(jax.grad(jax.grad(f)))(4.0)) #boom!\n", + "print(jax.grad(jax.grad(jax.grad(jax.grad(f))))(4.0)) #boom!" + ] + }, + { + "cell_type": "code", + "execution_count": 221, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 221, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "xi = jnp.linspace(-5,5)\n", + "yi = f(xi)\n", + "\n", + "plt.plot(xi,yi)" + ] + }, + { + "cell_type": "code", + "execution_count": 222, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Gradient only defined for scalar-output functions. Output had shape: (50,).", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Code/pyhep2020-autodiff-tutorial/venv/lib/python3.7/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 412\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 413\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/pyhep2020-autodiff-tutorial/venv/lib/python3.7/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_output_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/pyhep2020-autodiff-tutorial/venv/lib/python3.7/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mShapedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 494\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 495\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"had shape: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 496\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"had abstract value {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (50,)." 
+ ] + } + ], + "source": [ + "jax.grad(f)(xi)" + ] + }, + { + "cell_type": "code", + "execution_count": 227, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 227, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "g1i = jax.vmap(jax.grad(f))(xi)\n", + "g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)\n", + "g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)\n", + "plt.plot(xi,yi, label = \"f\")\n", + "plt.plot(xi,g1i, label = \"f'\")\n", + "plt.plot(xi,g2i, label = \"f''\")\n", + "plt.plot(xi,g3i, label = \"f'''\")\n", + "plt.legend()" ] }, {