Skip to content

Commit

Permalink
updated tutorial for inverting the lens equation
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro committed Sep 26, 2024
1 parent fac7eb7 commit e61dbb7
Show file tree
Hide file tree
Showing 4 changed files with 537 additions and 121 deletions.
337 changes: 335 additions & 2 deletions docs/source/tutorials/InvertLensEquation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# Inverting the Lens Equation\n",
"\n",
"The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a little luck to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!"
"The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a iterative procedures to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!"
]
},
{
Expand All @@ -23,6 +23,8 @@
"\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Polygon\n",
"from matplotlib.collections import PatchCollection\n",
"import numpy as np\n",
"\n",
"import caustics"
Expand Down Expand Up @@ -64,6 +66,14 @@
")"
]
},
{
"cell_type": "markdown",
"id": "2163ed78",
"metadata": {},
"source": [
"Here we run the forward raytracing for our particular lens model. In caustics we provide a convenient `forward_raytrace` function which can be called for any lens model. Internally, this constructs a number of triangles in the image plane, raytraces them to the source plane and identifies which ones contain the desired source plane position. Iteratively subdividing the triangles eventually converges on image plane positions which map to the desired source plane position. See further down for more detail."
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -82,6 +92,14 @@
"bx, by = lens.raytrace(x, y, z_s)"
]
},
{
"cell_type": "markdown",
"id": "462b2e8f",
"metadata": {},
"source": [
"When we raytrace the coordinates we get out from `forward_raytrace` it is not too surprising that they all give source plane positions very close to the desired source plane position. Here we plot them so you can see:"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -119,14 +137,53 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "6641535a",
"metadata": {},
"source": [
"It is also often not necessary to model the central demagnified region since it is so faint (approximately a million times fainter in this case) that it doesn't contribute measurably to the flux of an image. We can very easily check the magnification of every point and remove the unnecessary one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5d20df7",
"metadata": {},
"outputs": [],
"source": [
"m = lens.magnification(x, y, z_s)\n",
"print(m.detach().cpu().tolist())\n",
"N_m = torch.argsort(m)\n",
"\n",
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"# Get the path from the matplotlib contour plot of the critical line\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"\n",
"plt.scatter(x[N_m[1:]], y[N_m[1:]], color=\"b\", label=\"magnified\")\n",
"plt.scatter(x[N_m[0]], y[N_m[0]], color=\"r\", label=\"de-magnified\")\n",
"plt.axis(\"off\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "0163d5f7",
"metadata": {},
"source": [
"## Lets take a look\n",
"\n",
"Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image."
"Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image. Note however that the regions can overlap, which they do very slightly in this case."
]
},
{
Expand Down Expand Up @@ -219,10 +276,286 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7a4c8fd5",
"metadata": {},
"source": [
"## How forward_raytrace works\n",
"\n",
"All forward raytracing methods are imperfect as they involve iterative solutions which require enough resolution to pick out all the relevant image plane positions. To start, lets consider a more naive algorithm, simply placing random points in the image plane, then running a root-finding algorithm to get the source plane positions to line up."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e06ac42",
"metadata": {},
"outputs": [],
"source": [
"Ninit = 100\n",
"x_init = torch.randn(Ninit)\n",
"y_init = torch.randn(Ninit)\n",
"\n",
"\n",
"def raytrace(x, y):\n",
" return lens.raytrace(x, y, z_s)\n",
"\n",
"\n",
"final = caustics.lenses.func.forward_raytrace_rootfind(\n",
" x_init, y_init, sp_x, sp_y, raytrace\n",
")\n",
"x_final, y_final = final[..., 0], final[..., 1]"
]
},
{
"cell_type": "markdown",
"id": "2abb217e",
"metadata": {},
"source": [
"Here we easily find the four magnified images, but the central demagnified image is (often) not found by this method since a point has to get lucky enough to start very close to the correct position in order for the gradient based root finder to work."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2fe0e6f",
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"colors = [\"tab:red\", \"tab:blue\", \"tab:green\", \"tab:orange\", \"tab:purple\"]\n",
"for c in colors:\n",
" if x_final.shape[0] == 0:\n",
" break\n",
" R = ((x_final[0] - x_final) ** 2 + (y_final[0] - y_final) ** 2).sqrt()\n",
" ax.scatter(x_init[R < 0.1], y_init[R < 0.1], color=c)\n",
" ax.scatter(x_final[0], y_final[0], color=\"k\", s=200, marker=\"*\")\n",
" ax.scatter(x_final[0], y_final[0], color=c, s=100, marker=\"*\")\n",
" x_init = x_init[R >= 0.1]\n",
" y_init = y_init[R >= 0.1]\n",
" x_final = x_final[R >= 0.1]\n",
" y_final = y_final[R >= 0.1]\n",
"ax.axes.set_axis_off()\n",
"ax.set_xlim([-fov / 1.9, fov / 1.9])\n",
"ax.set_ylim([-fov / 1.9, fov / 1.9])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b4c5d47b",
"metadata": {},
"source": [
"Let's now look at a more clever algorithm. We will map triangles in the image plane to triangles in the source plane, we may then explore recursively, any triangles which enclose the desired source point. Due to the non-linearity of the gravitational lensing transformation, we will also search the neighbor of any triangle that seems to have found an image position. First we highlight in green, any triangles which contain the source point, then expand to all their neighbors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5677ef22",
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"n = 10\n",
"s = torch.stack((sp_x, sp_y))\n",
"# Construct a tiling of the image plane (squares at this point)\n",
"X, Y = torch.meshgrid(\n",
" torch.linspace(-fov / 2, fov / 2, n),\n",
" torch.linspace(-fov / 2, fov / 2, n),\n",
" indexing=\"ij\",\n",
")\n",
"E1 = torch.stack((X, Y), dim=-1)\n",
"# build the upper and lower triangles within the squares of the grid\n",
"E1 = torch.cat(\n",
" (\n",
" torch.stack((E1[:-1, :-1], E1[:-1, 1:], E1[1:, 1:]), dim=-2),\n",
" torch.stack((E1[:-1, :-1], E1[1:, :-1], E1[1:, 1:]), dim=-2),\n",
" ),\n",
" dim=0,\n",
").reshape(-1, 3, 2)\n",
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"S = raytrace(E1[..., 0], E1[..., 1])\n",
"S = torch.stack(S, dim=-1)\n",
"\n",
"# Identify triangles that contain the source plane point\n",
"locate1 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n",
"patches = []\n",
"for e, loc in zip(E1, locate1):\n",
" patches.append(\n",
" Polygon(\n",
" e,\n",
" fill=loc,\n",
" alpha=0.4 if loc else 1,\n",
" color=\"tab:green\" if loc else \"k\",\n",
" linewidth=1,\n",
" )\n",
" )\n",
"p = PatchCollection(patches, match_original=True)\n",
"ax.add_collection(p)\n",
"ax.set_xlim([-fov / 1.9, fov / 1.9])\n",
"ax.set_ylim([-fov / 1.9, fov / 1.9])\n",
"ax.set_axis_off()\n",
"plt.show()\n",
"\n",
"# Get all the neighbors and upsample the triangles\n",
"E2 = E1[locate1]\n",
"E2 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E2)\n",
"E2 = E2.reshape(-1, 3, 2)\n",
"E2 = caustics.lenses.func.remove_triangle_duplicates(E2)\n",
"# Upsample the triangles\n",
"E2 = torch.vmap(caustics.lenses.func.triangle_upsample)(E2)\n",
"E2 = E2.reshape(-1, 3, 2)\n",
"\n",
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"S = raytrace(E2[..., 0], E2[..., 1])\n",
"S = torch.stack(S, dim=-1)\n",
"\n",
"# Identify triangles that contain the source plane point\n",
"locate2 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n",
"patches = []\n",
"for e, loc in zip(E2, locate2):\n",
" patches.append(\n",
" Polygon(\n",
" e,\n",
" fill=loc,\n",
" alpha=0.4 if loc else 1,\n",
" color=\"tab:green\" if loc else \"k\",\n",
" linewidth=0.5,\n",
" )\n",
" )\n",
"p = PatchCollection(patches, match_original=True)\n",
"ax.add_collection(p)\n",
"ax.set_xlim([-fov / 1.9, fov / 1.9])\n",
"ax.set_ylim([-fov / 1.9, fov / 1.9])\n",
"ax.set_axis_off()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "61fdd482",
"metadata": {},
"source": [
"The process repeats until the triangles have converged to a very small area, at which point we then run a root finding algorithm to get the final points."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ef54c41",
"metadata": {},
"outputs": [],
"source": [
"# Get all the neighbors and upsample the triangles\n",
"E3 = E2[locate2]\n",
"E3 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E3)\n",
"E3 = E3.reshape(-1, 3, 2)\n",
"E3 = caustics.lenses.func.remove_triangle_duplicates(E3)\n",
"# Upsample the triangles\n",
"E3 = torch.vmap(caustics.lenses.func.triangle_upsample)(E3)\n",
"E3 = E3.reshape(-1, 3, 2)\n",
"\n",
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"S = raytrace(E3[..., 0], E3[..., 1])\n",
"S = torch.stack(S, dim=-1)\n",
"\n",
"# Identify triangles that contain the source plane point\n",
"locate3 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n",
"patches = []\n",
"for e, loc in zip(E3, locate3):\n",
" patches.append(\n",
" Polygon(\n",
" e,\n",
" fill=loc,\n",
" alpha=0.4 if loc else 1,\n",
" color=\"tab:green\" if loc else \"k\",\n",
" linewidth=0.5,\n",
" )\n",
" )\n",
"p = PatchCollection(patches, match_original=True)\n",
"ax.add_collection(p)\n",
"ax.set_xlim([-fov / 1.9, fov / 1.9])\n",
"ax.set_ylim([-fov / 1.9, fov / 1.9])\n",
"ax.set_axis_off()\n",
"plt.show()\n",
"\n",
"# Run the root finding algorithm\n",
"E4 = E3[locate3].sum(dim=1) / 3\n",
"E4 = caustics.lenses.func.forward_raytrace_rootfind(\n",
" E4[..., 0], E4[..., 1], s[0], s[1], raytrace\n",
")\n",
"fig, ax = plt.subplots()\n",
"CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n",
"for path in paths:\n",
" # Collect the path into a discrete set of points\n",
" x1 = torch.tensor(list(float(vs[0]) for vs in path))\n",
" x2 = torch.tensor(list(float(vs[1]) for vs in path))\n",
" # raytrace the points to the source plane\n",
" y1, y2 = lens.raytrace(x1, x2, z_s)\n",
"\n",
" # Plot the caustic\n",
" ax.plot(y1, y2, color=\"r\", zorder=1)\n",
"ax.scatter(E4[..., 0], E4[..., 1], color=\"k\", s=100, marker=\"*\")\n",
"ax.scatter(E4[..., 0], E4[..., 1], color=\"tab:green\", s=50, marker=\"*\")\n",
"ax.set_xlim([-fov / 1.9, fov / 1.9])\n",
"ax.set_ylim([-fov / 1.9, fov / 1.9])\n",
"ax.set_axis_off()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e421807b",
"metadata": {},
"outputs": [],
"source": []
Expand Down
Loading

0 comments on commit e61dbb7

Please sign in to comment.