From ecbda4eead0874267bec30dba5ef9acaff81dfd8 Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Thu, 26 Sep 2024 12:23:59 -0700 Subject: [PATCH] Improve forward raytrace to better find all images (#264) * starting to make recursive triangle forward raytrace * updated tutorial for inverting the lens equation * hide plotting cell * set sampling to uniform for naive method * remove unconverged points in naive example * fix invert lens demo to go all the way to convergence. * fix demagification value --- .../source/tutorials/InvertLensEquation.ipynb | 477 +++++++++++++++++- src/caustics/lenses/base.py | 66 ++- src/caustics/lenses/func/__init__.py | 14 + src/caustics/lenses/func/base.py | 205 ++++++-- src/caustics/utils.py | 3 +- 5 files changed, 704 insertions(+), 61 deletions(-) diff --git a/docs/source/tutorials/InvertLensEquation.ipynb b/docs/source/tutorials/InvertLensEquation.ipynb index bd5ea670..edd6dca9 100644 --- a/docs/source/tutorials/InvertLensEquation.ipynb +++ b/docs/source/tutorials/InvertLensEquation.ipynb @@ -7,7 +7,7 @@ "source": [ "# Inverting the Lens Equation\n", "\n", - "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a little luck to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" + "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a iterative procedures to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" ] }, { @@ -23,6 +23,8 @@ "\n", "import torch\n", "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Polygon\n", + "from matplotlib.collections import PatchCollection\n", "import numpy as np\n", "\n", "import caustics" @@ -64,6 +66,14 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "2163ed78", + "metadata": {}, + "source": [ + "Here we run the forward raytracing for our particular lens model. In caustics we provide a convenient `forward_raytrace` function which can be called for any lens model. Internally, this constructs a number of triangles in the image plane, raytraces them to the source plane and identifies which ones contain the desired source plane position. Iteratively subdividing the triangles eventually converges on image plane positions which map to the desired source plane position. See further down for more detail." + ] + }, { "cell_type": "code", "execution_count": null, @@ -82,6 +92,14 @@ "bx, by = lens.raytrace(x, y, z_s)" ] }, + { + "cell_type": "markdown", + "id": "462b2e8f", + "metadata": {}, + "source": [ + "When we raytrace the coordinates we get out from `forward_raytrace` it is not too surprising that they all give source plane positions very close to the desired source plane position. Here we plot them so you can see:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -119,6 +137,45 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "6641535a", + "metadata": {}, + "source": [ + "It is also often not necessary to model the central demagnified region since it is so faint (approximately a 100,000 times fainter in this case) that it doesn't contribute measurably to the flux of an image. We can very easily check the magnification of every point and remove the unnecessary one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5d20df7", + "metadata": {}, + "outputs": [], + "source": [ + "m = lens.magnification(x, y, z_s)\n", + "print(m.detach().cpu().tolist())\n", + "N_m = torch.argsort(m)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "# Get the path from the matplotlib contour plot of the critical line\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "\n", + "plt.scatter(x[N_m[1:]], y[N_m[1:]], color=\"b\", label=\"magnified\")\n", + "plt.scatter(x[N_m[0]], y[N_m[0]], color=\"r\", label=\"de-magnified\")\n", + "plt.axis(\"off\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "id": "0163d5f7", @@ -126,7 +183,7 @@ "source": [ "## Lets take a look\n", "\n", - "Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image." + "Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image. Note however that the regions can overlap, which they do very slightly in this case." ] }, { @@ -219,10 +276,426 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "7a4c8fd5", + "metadata": {}, + "source": [ + "## How forward_raytrace works\n", + "\n", + "All forward raytracing methods are imperfect as they involve iterative solutions which require enough resolution to pick out all the relevant image plane positions. To start, lets consider a more naive algorithm, simply placing random points in the image plane, then running a root-finding algorithm to get the source plane positions to line up." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e06ac42", + "metadata": {}, + "outputs": [], + "source": [ + "Ninit = 100\n", + "x_init = (torch.rand(Ninit) - 0.5) * fov\n", + "y_init = (torch.rand(Ninit) - 0.5) * fov\n", + "\n", + "\n", + "def raytrace(x, y):\n", + " return lens.raytrace(x, y, z_s)\n", + "\n", + "\n", + "final = caustics.lenses.func.forward_raytrace_rootfind(\n", + " x_init, y_init, sp_x, sp_y, raytrace\n", + ")\n", + "x_final, y_final = final[..., 0], final[..., 1]\n", + "\n", + "# Pick only points that converged\n", + "bx_final, by_final = raytrace(x_final, y_final)\n", + "R = torch.sqrt((sp_x - bx_final) ** 2 + (sp_y - by_final) ** 2)\n", + "x_final = x_final[R < 1e-3]\n", + "y_final = y_final[R < 1e-3]" + ] + }, + { + "cell_type": "markdown", + "id": "2abb217e", + "metadata": {}, + "source": [ + "Here we easily find the four magnified images, but the central demagnified image is (often) not found by this method since a point has to get lucky enough to start very close to the correct position in order for the gradient based root finder to work." + ] + }, { "cell_type": "code", "execution_count": null, "id": "d2fe0e6f", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "colors = [\"tab:red\", \"tab:blue\", \"tab:green\", \"tab:orange\", \"tab:purple\"]\n", + "for c in colors:\n", + " if x_final.shape[0] == 0:\n", + " break\n", + " R = ((x_final[0] - x_final) ** 2 + (y_final[0] - y_final) ** 2).sqrt()\n", + " ax.scatter(x_init[R < 0.1], y_init[R < 0.1], color=c)\n", + " ax.scatter(x_final[0], y_final[0], color=\"k\", s=200, marker=\"*\")\n", + " ax.scatter(x_final[0], y_final[0], color=c, s=100, marker=\"*\")\n", + " x_init = x_init[R >= 0.1]\n", + " y_init = y_init[R >= 0.1]\n", + " x_final = x_final[R >= 0.1]\n", + " y_final = y_final[R >= 0.1]\n", + "ax.axes.set_axis_off()\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b4c5d47b", + "metadata": {}, + "source": [ + "Let's now look at a more clever algorithm. We will map triangles in the image plane to triangles in the source plane, we may then explore recursively, any triangles which enclose the desired source point. Due to the non-linearity of the gravitational lensing transformation, we will also search the neighbor of any triangle that seems to have found an image position. First we highlight in green, any triangles which contain the source point, then expand to all their neighbors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5677ef22", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "n = 10\n", + "s = torch.stack((sp_x, sp_y))\n", + "# Construct a tiling of the image plane (squares at this point)\n", + "X, Y = torch.meshgrid(\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " indexing=\"ij\",\n", + ")\n", + "E1 = torch.stack((X, Y), dim=-1)\n", + "# build the upper and lower triangles within the squares of the grid\n", + "E1 = torch.cat(\n", + " (\n", + " torch.stack((E1[:-1, :-1], E1[:-1, 1:], E1[1:, 1:]), dim=-2),\n", + " torch.stack((E1[:-1, :-1], E1[1:, :-1], E1[1:, 1:]), dim=-2),\n", + " ),\n", + " dim=0,\n", + ").reshape(-1, 3, 2)\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E1[..., 0], E1[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate1 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E1, locate1):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=1,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E2 = E1[locate1]\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "E2 = caustics.lenses.func.remove_triangle_duplicates(E2)\n", + "# Upsample the triangles\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_upsample)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E2[..., 0], E2[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate2 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E2, locate2):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "61fdd482", + "metadata": {}, + "source": [ + "The process repeats until the triangles have converged to a very small area, at which point we then run a root finding algorithm to get the final points. The central region is a very unstable optimum, so we need to use the triangle method for several iterations before we can run the root finder to get the exact optimal point." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ef54c41", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "# Get all the neighbors and upsample the triangles\n", + "E3 = E2[locate2]\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "E3 = caustics.lenses.func.remove_triangle_duplicates(E3)\n", + "# Upsample the triangles\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_upsample)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E3[..., 0], E3[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate3 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E3, locate3):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E4 = E3[locate3]\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "E4 = caustics.lenses.func.remove_triangle_duplicates(E4)\n", + "# Upsample the triangles\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_upsample)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E4[..., 0], E4[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate4 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E4, locate4):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E5 = E4[locate4]\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "E5 = caustics.lenses.func.remove_triangle_duplicates(E5)\n", + "# Upsample the triangles\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_upsample)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E5[..., 0], E5[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate5 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E5, locate5):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E6 = E5[locate5]\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "E6 = caustics.lenses.func.remove_triangle_duplicates(E6)\n", + "# Upsample the triangles\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_upsample)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E6[..., 0], E6[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate6 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E6, locate6):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "\n", + "# Run the root finding algorithm\n", + "E7 = E6[locate6].sum(dim=1) / 3\n", + "E7 = caustics.lenses.func.forward_raytrace_rootfind(\n", + " E7[..., 0], E7[..., 1], s[0], s[1], raytrace\n", + ")\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"k\", s=100, marker=\"*\")\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"tab:green\", s=50, marker=\"*\")\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e421807b", "metadata": {}, "outputs": [], "source": [] diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index 11673e3b..ee819f2b 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -130,13 +130,16 @@ def forward_raytrace( z_s: Tensor, *args, params: Optional["Packed"] = None, - epsilon=1e-2, - n_init=100, - fov=5.0, + epsilon: float = 1e-3, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + fov: float = 5.0, + divisions: int = 100, **kwargs, ) -> tuple[Tensor, Tensor]: """ - Perform a forward ray-tracing operation which maps from the source plane to the image plane. + Perform a forward ray-tracing operation which maps from the source plane + to the image plane. Parameters ---------- @@ -159,18 +162,20 @@ def forward_raytrace( Dynamic parameter container for the lens model. Defaults to None. epsilon: Tensor - maximum distance between two images (arcsec) before they are considered the same image. + maximum distance between two images (arcsec) before they are + considered the same image. *Unit: arcsec* - n_init: int - number of random initialization points used to try and find image plane points. - fov: float the field of view in which the initial random samples are taken. *Unit: arcsec* + divisions: int + the number of divisions of the fov on each axis when constructing + the grid to perform in the triangle search. + Returns ------- x_component: Tensor @@ -183,18 +188,41 @@ def forward_raytrace( *Unit: arcsec* """ - - # TODO make FOV more general so that it doesn't have to be centered on zero,zero - if fov is None: - raise ValueError("fov must be given to generate initial guesses") - + raytrace = partial(self.raytrace, params=params, z_s=z_s) + if x0 is None: + x0 = torch.zeros((), device=bx.device, dtype=bx.dtype) + if y0 is None: + y0 = torch.zeros((), device=by.device, dtype=by.dtype) + # X = torch.stack((x0, y0)).repeat(4, 1) + # X[0] -= fov / 2 + # X[1][0] -= fov / 2 + # X[1][1] += fov / 2 + # X[2][0] += fov / 2 + # X[2][1] -= fov / 2 + # X[3] += fov / 2 + + # Sx, Sy = raytrace(X[..., 0], X[..., 1]) + # S = torch.stack((Sx, Sy)).T + # res1, ap1 = func.triangle_search( + # torch.stack((bx, by)), + # X[:3], + # S[:3], + # raytrace, + # epsilon, + # torch.zeros((0, 2)), + # ) + # res2, ap2 = func.triangle_search( + # torch.stack((bx, by)), + # X[1:], + # S[1:], + # raytrace, + # epsilon, + # torch.zeros((0, 2)), + # ) + # res = torch.cat((res1, res2), dim=0) + # return res[:, 0], res[:, 1], torch.cat((ap1, ap2), dim=0) return func.forward_raytrace( - bx, - by, - partial(self.raytrace, params=params, z_s=z_s), - epsilon, - n_init, - fov, + torch.stack((bx, by)), raytrace, x0, y0, fov, divisions, epsilon ) diff --git a/src/caustics/lenses/func/__init__.py b/src/caustics/lenses/func/__init__.py index f9bfc03c..f46be3df 100644 --- a/src/caustics/lenses/func/__init__.py +++ b/src/caustics/lenses/func/__init__.py @@ -1,5 +1,12 @@ from .base import ( forward_raytrace, + triangle_contains, + triangle_area, + triangle_neighbors, + triangle_upsample, + triangle_equals, + remove_triangle_duplicates, + forward_raytrace_rootfind, physical_from_reduced_deflection_angle, reduced_from_physical_deflection_angle, time_delay_arcsec2_to_days, @@ -60,6 +67,13 @@ __all__ = ( "forward_raytrace", + "triangle_contains", + "triangle_area", + "triangle_neighbors", + "triangle_upsample", + "triangle_equals", + "remove_triangle_duplicates", + "forward_raytrace_rootfind", "physical_from_reduced_deflection_angle", "reduced_from_physical_deflection_angle", "time_delay_arcsec2_to_days", diff --git a/src/caustics/lenses/func/base.py b/src/caustics/lenses/func/base.py index e86a579b..9108aed7 100644 --- a/src/caustics/lenses/func/base.py +++ b/src/caustics/lenses/func/base.py @@ -4,38 +4,136 @@ from ...constants import arcsec_to_rad, c_Mpc_s, days_to_seconds -def forward_raytrace(bx, by, raytrace, epsilon, n_init, fov): +def triangle_contains(p, v): """ - Perform a forward ray-tracing operation which maps from the source plane to the image plane. + determine if point v is inside triangle p. Where p is a (3,2) tensor, and v + is a (2,) tensor. + """ + p01 = p[1] - p[0] + p02 = p[2] - p[0] + dp0p02 = p[0][0] * p02[1] - p[0][1] * p02[0] + dp0p01 = p[0][0] * p01[1] - p[0][1] * p01[0] + dp01p02 = p01[0] * p02[1] - p01[1] * p02[0] + dvp02 = v[0] * p02[1] - v[1] * p02[0] + dvp01 = v[0] * p01[1] - v[1] * p01[0] + a = (dvp02 - dp0p02) / dp01p02 + b = -(dvp01 - dp0p01) / dp01p02 + return (a >= 0) & (b >= 0) & (a + b <= 1) + + +def triangle_area(p): + """ + Determine the area of triangle p where p is a (3,2) tensor. + """ + return ( + 0.5 + * ( + p[0][0] * (p[1][1] - p[2][1]) + + p[1][0] * (p[2][1] - p[0][1]) + + p[2][0] * (p[0][1] - p[1][1]) + ).abs() + ) - Parameters - ---------- - bx: Tensor - Tensor of x coordinate in the source plane. - *Unit: arcsec* +def triangle_neighbors(p): + """ + Build a set of neighbors for triangle p where p is a (3,2) tensor. The + neighbors all have the same shape as p, but are various translations and + reflections of p that share a common edge or vertex. + """ + p01 = p[1] - p[0] + p02 = p[2] - p[0] + p12 = p[2] - p[1] + pref = -(p - p[0]) + p[0] + return torch.stack( + ( + p, + p + p01, + p - p01, + p + p02, + p - p02, + p + p12, + p - p12, + pref, + pref + p01, + pref + 2 * p01, + pref + p02, + pref + 2 * p02, + pref + p01 + p02, + ), + dim=0, + ) - by: Tensor - Tensor of y coordinate in the source plane. - *Unit: arcsec* +def triangle_upsample(p): + """ + Upsample triangle p where p is a (3,2) tensor. The upsampled triangles are + all triangles internal to p built by taking the midpoints of the edges of p. + """ + p01 = (p[1] + p[0]) / 2 + p02 = (p[2] + p[0]) / 2 + p12 = (p[2] + p[1]) / 2 + return torch.stack( + ( + torch.stack((p[0], p01, p02), dim=0), + torch.stack((p01, p[1], p12), dim=0), + torch.stack((p02, p12, p[2]), dim=0), + torch.stack((p01, p12, p02), dim=0), + ), + dim=0, + ) - raytrace: function - function that takes in the x and y coordinates in the image plane and returns the x and y coordinates in the source plane. - epsilon: Tensor - maximum distance between two images (arcsec) before they are considered the same image. +def triangle_equals(p1, p2): + """ + Determine if two triangles are equal. Where p1 and p2 are (3,2) tensors. + """ + return torch.all((p1 - p2).abs() < 1e-6) + + +def remove_triangle_duplicates(p): + unique_triangles = torch.zeros((0, 3, 2)) + B = p.shape[0] + batch_triangle_equals = torch.vmap(triangle_equals, in_dims=(None, 0)) + for i in range(B): + # Compare current triangle with all triangles in the unique list + if i == 0 or not batch_triangle_equals(p[i], unique_triangles).any(): + unique_triangles = torch.cat((unique_triangles, p[i].unsqueeze(0)), dim=0) + + return unique_triangles + + +def forward_raytrace_rootfind(ix, iy, bx, by, raytrace): + """ + Perform a forward ray-tracing operation which maps from the source plane to + the image plane. + + Parameters + ---------- + ix: Tensor + Tensor of x coordinate in the image plane. This initializes the + ray-tracing optimization. Should have shape (B, 2). *Unit: arcsec* - n_init: int - number of random initialization points used to try and find image plane points. + iy: Tensor + Tensor of y coordinate in the image plane. This initializes the + ray-tracing optimization. Should have shape (B, 2). + + bx: Tensor + Tensor of x coordinate in the source plane. Should be a scalar. - fov: float - the field of view in which the initial random samples are taken. + *Unit: arcsec* + + by: Tensor + Tensor of y coordinate in the source plane. Should be a scalar. *Unit: arcsec* + raytrace: function + function that takes in the x and y coordinates in the image plane and + returns the x and y coordinates in the source plane. + Returns ------- x_component: Tensor @@ -48,37 +146,66 @@ def forward_raytrace(bx, by, raytrace, epsilon, n_init, fov): *Unit: arcsec* """ - bxy = torch.stack((bx, by)).repeat(n_init, 1) # has shape (n_init, Dout:2) - - # Random starting points in image plane - guesses = ( - torch.as_tensor(fov, dtype=bx.dtype) - * (torch.rand(n_init, 2, dtype=bx.dtype) - 0.5) - ).to( - device=bxy.device - ) # Has shape (n_init, Din:2) - + ixy = torch.stack((ix, iy), dim=1) # has shape (B, Din:2) + bxy = torch.stack((bx, by)).repeat(ix.shape[0], 1) # has shape (B, Dout:2) # Optimize guesses in image plane x, l, c = batch_lm( # noqa: E741 Unused `l` variable - guesses, + ixy, bxy, lambda *a, **k: torch.stack( raytrace(a[0][..., 0], a[0][..., 1], *a[1:], **k), dim=-1 ), ) + return x - # Clip points that didn't converge - x = x[c < 1e-2 * epsilon**2] - # Cluster results into n-images - res = [] - while len(x) > 0: - res.append(x[0]) - d = torch.linalg.norm(x - x[0], dim=-1) - x = x[d > epsilon] +def forward_raytrace(s, raytrace, x0, y0, fov, n, epsilon): - res = torch.stack(res, dim=0) - return res[..., 0], res[..., 1] + # Construct a tiling of the image plane (squares at this point) + X, Y = torch.meshgrid( + torch.linspace(x0 - fov / 2, x0 + fov / 2, n), + torch.linspace(y0 - fov / 2, y0 + fov / 2, n), + indexing="ij", + ) + E = torch.stack((X, Y), dim=-1) + # build the upper and lower triangles within the squares of the grid + E = torch.cat( + ( + torch.stack((E[:-1, :-1], E[:-1, 1:], E[1:, 1:]), dim=-2), + torch.stack((E[:-1, :-1], E[1:, :-1], E[1:, 1:]), dim=-2), + ), + dim=0, + ).reshape(-1, 3, 2) + + i = 0 + while True: + + # Expand the search to neighboring triangles + if i > 0: # no need for neighbors in the first iteration + E = torch.vmap(triangle_neighbors)(E) + E = E.reshape(-1, 3, 2) + E = remove_triangle_duplicates(E) + # Upsample the triangles + E = torch.vmap(triangle_upsample)(E) + E = E.reshape(-1, 3, 2) + + S = raytrace(E[..., 0], E[..., 1]) + S = torch.stack(S, dim=-1) + + # Identify triangles that contain the source plane point + locate = torch.vmap(triangle_contains, in_dims=(0, None))(S, s) + E = E[locate] + i += 1 + + if triangle_area(E[0]) > epsilon**2: + # Rootfind the source plane point in the triangle + Emid = E.sum(dim=1) / 3 + Emid = forward_raytrace_rootfind( + Emid[..., 0], Emid[..., 1], s[0], s[1], raytrace + ) + if torch.all(torch.vmap(triangle_contains)(E, Emid)): + break + return Emid[..., 0], Emid[..., 1] def physical_from_reduced_deflection_angle(ax, ay, d_s, d_ls): diff --git a/src/caustics/utils.py b/src/caustics/utils.py index ccbe24d5..f6399979 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1037,7 +1037,8 @@ def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon, L_min, L_max): chi2_new = (dYnew @ Cinv @ dYnew).sum(-1) # Test - rho = (chi2 - chi2_new) / torch.abs(h @ (L * torch.dot(torch.diag(hess), h) + grad)) # fmt: skip + expected_improvement = torch.dot(h, hess @ h) + 2 * torch.dot(h, grad) + rho = (chi2 - chi2_new) / torch.abs(expected_improvement) # fmt: skip # Update X = torch.where(rho >= epsilon, X + h, X)