diff --git a/.devcontainer/cpu/devcontainer.json b/.devcontainer/cpu/devcontainer.json index 55c62524..c71bc569 100644 --- a/.devcontainer/cpu/devcontainer.json +++ b/.devcontainer/cpu/devcontainer.json @@ -6,8 +6,8 @@ "context": "../..", "dockerfile": "../Dockerfile", "args": { - "CLANG_VERSION": "" - } + "CLANG_VERSION": "", + }, }, // Use 'forwardPorts' to make a list of ports inside the container available locally. @@ -26,10 +26,10 @@ "ms-python.python", "ms-vsliveshare.vsliveshare", "DavidAnson.vscode-markdownlint", - "GitHub.copilot" - ] - } - } + "GitHub.copilot", + ], + }, + }, // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // "remoteUser": "root" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b27e127..7af59aee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,4 +71,4 @@ jobs: pytest -vvv --cov=${{ env.PROJECT_NAME }} --cov-report=xml --cov-report=term tests/ - name: Upload coverage reports to Codecov with GitHub Action - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d9514af..896eae87 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ exclude: | (?x)^( - tests/utils.py + tests/utils/ ) ci: @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/psf/black - rev: "23.12.1" + rev: "24.1.1" hooks: - id: black-jupyter @@ -50,20 +50,19 @@ repos: args: [--prose-wrap=always] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.9" + rev: "v0.2.1" hooks: - id: ruff args: ["--fix", "--show-fixes"] - # 2023-12-11: Not use mypy for now. - # - repo: https://github.com/pre-commit/mirrors-mypy - # rev: "v1.8.0" - # hooks: - # - id: mypy - # files: src|tests - # args: [] - # additional_dependencies: - # - pytest + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.8.0" + hooks: + - id: mypy + files: src + args: ["--ignore-missing-imports"] + additional_dependencies: + - pytest - repo: https://github.com/codespell-project/codespell rev: "v2.2.6" @@ -72,7 +71,7 @@ repos: args: ["--write-changes", "--ignore-words", ".codespell-whitelist"] - repo: https://github.com/kynan/nbstripout - rev: 0.6.1 + rev: 0.7.1 hooks: - id: nbstripout args: [--extra-keys=metadata.kernelspec metadata.language_info.version] diff --git a/README.md b/README.md index 245b8d7b..04a25641 100644 --- a/README.md +++ b/README.md @@ -32,10 +32,10 @@ import matplotlib.pyplot as plt import caustics import torch -cosmology = caustics.cosmology.FlatLambdaCDM() -sie = caustics.lenses.SIE(cosmology=cosmology, name="lens") -src = caustics.light.Sersic(name="source") -lnslt = caustics.light.Sersic(name="lenslight") +cosmology = caustics.FlatLambdaCDM() +sie = caustics.SIE(cosmology=cosmology, name="lens") +src = caustics.Sersic(name="source") +lnslt = caustics.Sersic(name="lenslight") x = torch.tensor([ # z_s z_l x0 y0 q phi b x0 y0 q phi n Re @@ -44,7 +44,7 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.sims.Lens_Source( +minisim = caustics.Lens_Source( lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 ) plt.imshow(minisim(x, quad_level=3), origin="lower") diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 4929d38b..998e1c82 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -61,4 +61,31 @@ Finalizing a Pull Request Once the PR is submitted, we will look through it and request 
any changes necessary before merging it into the main branch. You can make those changes just like any other edits on your fork. Then when you push them, they will be joined in to the PR automatically and any unit tests will run again. +Black Formatting Exceptions for Equations +----------------------------------------- + +In the **caustics** project, we utilize the Black code formatter to ensure consistent and readable code. However, there are instances where the automatic formatting performed by Black may not align with the desired formatting for equations within the code. + +To address this, we have introduced the use of ``# fmt: skip`` comments to exempt specific lines from Black formatting. This is particularly useful when dealing with equations that have a specific format or layout that should not be altered by the code formatter. + +How to Use ``# fmt: skip`` for Equations +---------------------------------------- + +To exempt a specific line of code, such as an equation, from Black formatting, add a ``# fmt: skip`` comment at the end of the line you want exempted. For example: + +.. code-block:: python + +    psi = (q**2 * (x**2 + self.s**2) + y**2).sqrt()  # fmt: skip + +In the above example, the ``# fmt: skip`` comment tells Black to leave that line, and therefore the layout of the equation, untouched. This allows developers to maintain control over the formatting of equations while still benefiting from the automatic formatting provided by Black for the rest of the codebase. + +Best Practices for Black Formatting Exceptions +---------------------------------------------- + +- Use ``# fmt: skip`` sparingly and only for lines where manual formatting is essential. +- Clearly document the reason for using ``# fmt: skip`` to provide context for future developers. + +By incorporating ``# fmt: skip`` comments for equations, we strike a balance between automated code formatting and the need for manual control over certain code elements. + + Once the PR has been merged, you may delete your fork if you aren't using it any more, or take on a new issue, it's up to you!
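A minimal sketch of the same pattern applied to a multi-line statement, in the style of the hand-aligned parameter tensor in this PR's README (the values here are placeholders, assuming only that ``torch`` is installed):

.. code-block:: python

    import torch

    # A hand-aligned parameter row; Black would reflow the spacing, so the
    # trailing "# fmt: skip" on the closing line exempts the whole statement.
    x = torch.tensor([
        #  z_s   z_l    x0    y0    q     phi
           1.5,  0.5,  -0.2,  0.0,  0.4,  1.05,
    ])  # fmt: skip

As in the README, the comment sits on the closing ``])`` line of the multi-line literal rather than on one of the interior rows.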
diff --git a/docs/source/intro.md b/docs/source/intro.md index 02b156b2..ba9c6a94 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -27,10 +27,10 @@ import matplotlib.pyplot as plt import caustics import torch -cosmology = caustics.cosmology.FlatLambdaCDM() -sie = caustics.lenses.SIE(cosmology=cosmology, name="lens") -src = caustics.light.Sersic(name="source") -lnslt = caustics.light.Sersic(name="lenslight") +cosmology = caustics.FlatLambdaCDM() +sie = caustics.SIE(cosmology=cosmology, name="lens") +src = caustics.Sersic(name="source") +lnslt = caustics.Sersic(name="lenslight") x = torch.tensor([ # z_s z_l x0 y0 q phi b x0 y0 q phi n Re @@ -39,7 +39,7 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.sims.Lens_Source( +minisim = caustics.Lens_Source( lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 ) plt.imshow(minisim(x, quad_level=3), origin="lower") diff --git a/docs/source/tutorials/BasicIntroduction.ipynb b/docs/source/tutorials/BasicIntroduction.ipynb index 2b24631d..1f90ff07 100644 --- a/docs/source/tutorials/BasicIntroduction.ipynb +++ b/docs/source/tutorials/BasicIntroduction.ipynb @@ -44,7 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "cosmology = caustics.cosmology.FlatLambdaCDM()" + "cosmology = caustics.FlatLambdaCDM()" ] }, { @@ -62,7 +62,7 @@ "metadata": {}, "outputs": [], "source": [ - "sie = caustics.lenses.SIE(cosmology=cosmology, name=\"lens\")" + "sie = caustics.SIE(cosmology=cosmology, name=\"lens\")" ] }, { @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "src = caustics.light.Sersic(name=\"source\")" + "src = caustics.Sersic(name=\"source\")" ] }, { @@ -98,7 +98,7 @@ "metadata": {}, "outputs": [], "source": [ - "lnslt = caustics.light.Sersic(name=\"lenslight\")" + "lnslt = caustics.Sersic(name=\"lenslight\")" ] }, { @@ -116,7 +116,7 @@ "metadata": {}, "outputs": [], "source": [ - "minisim = caustics.sims.Lens_Source(\n", + "minisim = caustics.Lens_Source(\n", " lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100\n", ")" ] diff --git a/docs/source/tutorials/InvertLensEquation.ipynb b/docs/source/tutorials/InvertLensEquation.ipynb index 55b850de..777ec174 100644 --- a/docs/source/tutorials/InvertLensEquation.ipynb +++ b/docs/source/tutorials/InvertLensEquation.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b7d08f1d", + "id": "0", "metadata": {}, "source": [ "# Inverting the Lens Equation\n", @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4027aaf9", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -36,13 +36,13 @@ { "cell_type": "code", "execution_count": null, - "id": "2118e1c1", + "id": "2", "metadata": {}, "outputs": [], "source": [ "# initialization stuff for an SIE lens\n", "\n", - "cosmology = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\")\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", "cosmology.to(dtype=torch.float32)\n", "n_pix = 100\n", "res = 0.05\n", @@ -56,7 +56,7 @@ ")\n", "z_l = torch.tensor(0.5, dtype=torch.float32)\n", "z_s = torch.tensor(1.5, dtype=torch.float32)\n", - "lens = caustics.lenses.SIE(\n", + "lens = caustics.SIE(\n", " cosmology=cosmology,\n", " name=\"sie\",\n", " z_l=z_l,\n", @@ -72,7 +72,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98e46aa1", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -90,7 +90,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb73147c", + "id": "4", "metadata": {}, "outputs": [], 
"source": [ @@ -122,7 +122,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e9f2d09f-6327-4d01-8d62-baa749ffc621", + "id": "5", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/LensZoo.ipynb b/docs/source/tutorials/LensZoo.ipynb index 8183ff9b..8bf0b44b 100644 --- a/docs/source/tutorials/LensZoo.ipynb +++ b/docs/source/tutorials/LensZoo.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e4054d16", + "id": "0", "metadata": {}, "source": [ "# A Menagerie of Lenses\n", @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "beeb58fa", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -34,14 +34,14 @@ { "cell_type": "code", "execution_count": null, - "id": "72cbde6d", + "id": "2", "metadata": {}, "outputs": [], "source": [ - "cosmology = caustics.cosmology.FlatLambdaCDM()\n", + "cosmology = caustics.FlatLambdaCDM()\n", "cosmology.to(dtype=torch.float32)\n", "z_s = torch.tensor(1.0)\n", - "base_sersic = caustics.light.Sersic(\n", + "base_sersic = caustics.Sersic(\n", " x0=0.1,\n", " y0=0.1,\n", " q=0.6,\n", @@ -70,7 +70,7 @@ }, { "cell_type": "markdown", - "id": "aa183e3d", + "id": "3", "metadata": {}, "source": [ "## Point (Point)\n", @@ -81,17 +81,17 @@ { "cell_type": "code", "execution_count": null, - "id": "4b4b2faa", + "id": "4", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.Point(\n", + "lens = caustics.Point(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", " th_ein=1.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -111,7 +111,7 @@ }, { "cell_type": "markdown", - "id": "2ca96b29", + "id": "5", "metadata": {}, "source": [ "## Singular Isothermal Sphere (SIS)\n", @@ -122,17 +122,17 @@ { "cell_type": "code", "execution_count": null, - "id": "6d563d58", + "id": "6", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.SIS(\n", + "lens = caustics.SIS(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", " th_ein=1.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "e09f874d", + "id": "7", "metadata": {}, "source": [ "## Singular Isothermal Ellipsoid (SIE)\n", @@ -166,11 +166,11 @@ { "cell_type": "code", "execution_count": null, - "id": "63c78c9a", + "id": "8", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.SIE(\n", + "lens = caustics.SIE(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", @@ -178,7 +178,7 @@ " phi=np.pi / 2,\n", " b=1.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "794060c3", + "id": "9", "metadata": {}, "source": [ "## Elliptical Power Law (EPL)\n", @@ -212,11 +212,11 @@ { "cell_type": "code", "execution_count": null, - "id": "557bad9f", + "id": "10", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.EPL(\n", + "lens = caustics.EPL(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", @@ -225,7 +225,7 @@ " b=1.0,\n", " t=0.5,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "1aa4a0b4", + "id": 
"11", "metadata": {}, "source": [ "## Navarro Frenk White profile (NFW)\n", @@ -261,18 +261,18 @@ { "cell_type": "code", "execution_count": null, - "id": "bb209aba", + "id": "12", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.NFW(\n", + "lens = caustics.NFW(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", " m=1e13,\n", " c=20.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "d96d3992", + "id": "13", "metadata": {}, "source": [ "## Truncated NFW profile (TNFW)\n", @@ -310,11 +310,11 @@ { "cell_type": "code", "execution_count": null, - "id": "dd5805a5", + "id": "14", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.TNFW(\n", + "lens = caustics.TNFW(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", @@ -322,7 +322,7 @@ " scale_radius=1.0,\n", " tau=3.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -345,7 +345,7 @@ }, { "cell_type": "markdown", - "id": "f6f19942", + "id": "15", "metadata": {}, "source": [ "## Pseudo Jaffe (PseudoJaffe)\n", @@ -360,11 +360,11 @@ { "cell_type": "code", "execution_count": null, - "id": "0c8985d8", + "id": "16", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.PseudoJaffe(\n", + "lens = caustics.PseudoJaffe(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", @@ -372,7 +372,7 @@ " core_radius=5e-1,\n", " scale_radius=15.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -395,7 +395,7 @@ }, { "cell_type": "markdown", - "id": "b232c8b0", + "id": "17", "metadata": {}, "source": [ "## External Shear (ExternalShear)\n", @@ -406,18 +406,18 @@ { "cell_type": "code", "execution_count": null, - "id": "e1d7b927", + "id": "18", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.ExternalShear(\n", + "lens = caustics.ExternalShear(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", " gamma_1=1.0,\n", " gamma_2=-1.0,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -438,7 +438,7 @@ }, { "cell_type": "markdown", - "id": "688d0796", + "id": "19", "metadata": {}, "source": [ "## Mass Sheet (MassSheet)\n", @@ -449,17 +449,17 @@ { "cell_type": "code", "execution_count": null, - "id": "cdfba784", + "id": "20", "metadata": {}, "outputs": [], "source": [ - "lens = caustics.lenses.MassSheet(\n", + "lens = caustics.MassSheet(\n", " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", " surface_density=1.5,\n", ")\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens,\n", " source=base_sersic,\n", " pixelscale=res,\n", @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d6c44b6c", + "id": "21", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/MultiplaneDemo.ipynb b/docs/source/tutorials/MultiplaneDemo.ipynb index 3e3275ed..e96aa09a 100644 --- a/docs/source/tutorials/MultiplaneDemo.ipynb +++ b/docs/source/tutorials/MultiplaneDemo.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2a324070-a163-4b79-8e77-819da73083f3", + "id": "0", "metadata": {}, "source": [ "# Multiplane Lensing\n", @@ -15,7 
+15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45b6a8b4", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -35,12 +35,12 @@ { "cell_type": "code", "execution_count": null, - "id": "ab43e042", + "id": "2", "metadata": {}, "outputs": [], "source": [ "# initialization stuff for lenses\n", - "cosmology = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\")\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", "cosmology.to(dtype=torch.float32)\n", "n_pix = 100\n", "res = 0.5\n", @@ -58,7 +58,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ea49d25d", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ "\n", " for _ in range(N_lenses):\n", " lenses.append(\n", - " caustics.lenses.SIE(\n", + " caustics.SIE(\n", " cosmology=cosmology,\n", " z_l=z_p,\n", " x0=np.random.uniform(-fov / 4.0, fov / 4.0),\n", @@ -85,17 +85,17 @@ " )\n", "\n", " planes.append(\n", - " caustics.lenses.SinglePlane(\n", + " caustics.SinglePlane(\n", " z_l=z_p, cosmology=cosmology, lenses=lenses, name=f\"plane_{p}\"\n", " )\n", " )\n", "\n", - "lens = caustics.lenses.Multiplane(name=\"multiplane\", cosmology=cosmology, lenses=planes)" + "lens = caustics.Multiplane(name=\"multiplane\", cosmology=cosmology, lenses=planes)" ] }, { "cell_type": "markdown", - "id": "4aa429c8", + "id": "4", "metadata": {}, "source": [ "## Effective Reduced Deflection Angles" @@ -104,7 +104,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f2e0a341", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +128,7 @@ }, { "cell_type": "markdown", - "id": "c7ad98c6", + "id": "6", "metadata": {}, "source": [ "## Critical Lines" @@ -137,7 +137,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2e23c3e", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7cd1f948", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "markdown", - "id": "94144e25", + "id": "9", "metadata": {}, "source": [ "## Effective Convergence" @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8a84fde", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +213,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708338df", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -227,7 +227,7 @@ }, { "cell_type": "markdown", - "id": "d6933714-dbce-4e5a-a92a-cd1a7b4df54e", + "id": "12", "metadata": {}, "source": [ "## Time Delay" @@ -236,7 +236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a4c6c28-2dd4-4f96-a4d4-9e5d44f3aa77", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -251,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28cbf540-c3f4-4cfc-8fb8-d7dad1eee028", + "id": "14", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/Parameters.ipynb b/docs/source/tutorials/Parameters.ipynb index a4823871..ddf2fcda 100644 --- a/docs/source/tutorials/Parameters.ipynb +++ b/docs/source/tutorials/Parameters.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "6527560d-1978-4187-8b40-50867258cb62", + "id": "0", "metadata": {}, "source": [ "# Parameters\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36b860f4-5675-48a3-a867-f4e1ddaf47ce", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -33,7 +33,7 @@ }, { "cell_type": "markdown", - "id": "15337194-64a3-4048-9564-1ff1295bf283", + "id": 
"2", "metadata": {}, "source": [ "## Setting static/dynamic parameters\n", @@ -44,31 +44,31 @@ { "cell_type": "code", "execution_count": null, - "id": "c9ae5563-6e55-4af3-a20a-cacfd257a2e8", + "id": "3", "metadata": {}, "outputs": [], "source": [ "# Flat cosmology with all dynamic parameters\n", - "cosmo = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\", h0=None, Om0=None)\n", + "cosmo = caustics.FlatLambdaCDM(name=\"cosmo\", h0=None, Om0=None)\n", "\n", "# SIE lens with q and b as static parameters\n", - "lens = caustics.lenses.SIE(cosmology=cosmo, q=0.4, b=1.0)\n", + "lens = caustics.SIE(cosmology=cosmo, q=0.4, b=1.0)\n", "\n", "# Sersic with all dynamic parameters except the sersic index, effective radius, and effective brightness\n", - "source = caustics.light.Sersic(name=\"source\", n=2.0, Re=1.0, Ie=1.0)\n", + "source = caustics.Sersic(name=\"source\", n=2.0, Re=1.0, Ie=1.0)\n", "\n", "# Sersic with all dynamic parameters except the x position, position angle, and effective radius\n", - "lens_light = caustics.light.Sersic(name=\"lenslight\", x0=0.0, phi=1.3, Re=1.0)\n", + "lens_light = caustics.Sersic(name=\"lenslight\", x0=0.0, phi=1.3, Re=1.0)\n", "\n", "# A simulator which captures all these parameters into a single DAG\n", - "sim = caustics.sims.Lens_Source(\n", + "sim = caustics.Lens_Source(\n", " lens=lens, source=source, lens_light=lens_light, pixelscale=0.05, pixels_x=100\n", ")" ] }, { "cell_type": "markdown", - "id": "28dea7af-b5fe-468a-bdbf-66fceba75945", + "id": "4", "metadata": {}, "source": [ "We can have the simulator print a graph of the DAG from it's perspective. Note that the white boxes are dynamic parameters while the grey boxes are static parameters" @@ -77,7 +77,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0041971d-0f17-4ee4-9b6a-81bea8c51369", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "298aa458-e518-424d-af9d-1ce45d55e4f4", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -99,7 +99,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08e6a094-5630-45cd-ab6b-1dbd89b3ebad", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6bbccd-ae91-444f-a8ac-ec31b1562d78", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +177,7 @@ }, { "cell_type": "markdown", - "id": "f3a1a9db-d825-47e8-96f1-c556f1dbe32c", + "id": "9", "metadata": {}, "source": [ "## Manual Inputs\n", @@ -188,7 +188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b77bf0e-49dd-424c-8268-cb654297a896", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -221,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "064832a0-e59b-4e91-8438-edf1aeff471a", + "id": "11", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/Playground.ipynb b/docs/source/tutorials/Playground.ipynb index de66dc36..f4e49419 100644 --- a/docs/source/tutorials/Playground.ipynb +++ b/docs/source/tutorials/Playground.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "1fafae89", + "id": "0", "metadata": {}, "source": [ "# Lensing playground\n", @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0d85608", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -33,7 +33,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b1ac7bf7", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -49,14 +49,14 
@@ ")\n", "z_l = torch.tensor(0.5, dtype=torch.float32)\n", "z_s = torch.tensor(1.5, dtype=torch.float32)\n", - "cosmology = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\")\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", "cosmology.to(dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": null, - "id": "964a76a6", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -64,7 +64,7 @@ "\n", "\n", "def plot_lens_metrics(thx0, thy0, q, phi, b):\n", - " lens = caustics.lenses.SIE(\n", + " lens = caustics.SIE(\n", " cosmology=cosmology,\n", " z_l=z_l,\n", " x0=thx0,\n", @@ -107,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3b04c973", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -124,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dce1edef", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -143,7 +143,7 @@ " Re_src,\n", " Ie_src,\n", "):\n", - " lens = caustics.lenses.SIE(\n", + " lens = caustics.SIE(\n", " cosmology,\n", " z_l,\n", " x0=x0_lens,\n", @@ -152,7 +152,7 @@ " phi=phi_lens,\n", " b=b_lens,\n", " )\n", - " source = caustics.light.Sersic(\n", + " source = caustics.Sersic(\n", " x0=x0_src,\n", " y0=y0_src,\n", " q=q_src,\n", @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c03161b9", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -208,7 +208,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35336e43", + "id": "7", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/Simulators.ipynb b/docs/source/tutorials/Simulators.ipynb index 91e9f2bd..1174a4a1 100644 --- a/docs/source/tutorials/Simulators.ipynb +++ b/docs/source/tutorials/Simulators.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f3ed209f", + "id": "0", "metadata": {}, "source": [ "# Now you're thinking with Simulators\n", @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89275b65", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -32,7 +32,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88a9200a", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -49,21 +49,21 @@ ")\n", "z_l = torch.tensor(0.5, dtype=torch.float32)\n", "z_s = torch.tensor(1.5, dtype=torch.float32)\n", - "cosmology = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\")\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", "cosmology.to(dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": null, - "id": "1665abeb", + "id": "3", "metadata": {}, "outputs": [], "source": [ "# demo simulator with sersic source, SIE lens. then sample some examples. 
demo the model graph\n", "\n", "\n", - "class Simple_Sim(caustics.sims.Simulator):\n", + "class Simple_Sim(caustics.Simulator):\n", " def __init__(\n", " self,\n", " lens,\n", @@ -95,12 +95,12 @@ { "cell_type": "code", "execution_count": null, - "id": "0babaead", + "id": "4", "metadata": {}, "outputs": [], "source": [ - "sie = caustics.lenses.SIE(cosmology, name=\"sie\")\n", - "src = caustics.light.Sersic(name=\"src\")\n", + "sie = caustics.SIE(cosmology, name=\"sie\")\n", + "src = caustics.Sersic(name=\"src\")\n", "\n", "sim = Simple_Sim(sie, src, torch.tensor(0.8))\n", "\n", @@ -110,7 +110,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e672be73", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -138,16 +138,16 @@ { "cell_type": "code", "execution_count": null, - "id": "fe16052c", + "id": "6", "metadata": {}, "outputs": [], "source": [ - "sie = caustics.lenses.SIE(cosmology, name=\"sie\")\n", + "sie = caustics.SIE(cosmology, name=\"sie\")\n", "hdu = fits.open(\n", " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=250&layer=ls-dr9&pixscale=0.262&bands=r\"\n", ")\n", "image_data = np.array(hdu[0].data, dtype=np.float64)\n", - "src = caustics.light.Pixelated(\n", + "src = caustics.Pixelated(\n", " name=\"ESO479_G1\", image=torch.tensor(image_data, dtype=torch.float32)\n", ")\n", "\n", @@ -159,7 +159,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b9921f68", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -187,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "e16ca698", + "id": "8", "metadata": {}, "source": [ "## Setting static/dynamic parameters\n", @@ -202,18 +202,18 @@ { "cell_type": "code", "execution_count": null, - "id": "91d3cf26", + "id": "9", "metadata": {}, "outputs": [], "source": [ - "sief = caustics.lenses.SIE(\n", + "sief = caustics.SIE(\n", " name=\"sie\",\n", " cosmology=cosmology,\n", " z_l=torch.tensor(0.5),\n", " x0=torch.tensor(0.0),\n", " y0=torch.tensor(0.0),\n", ")\n", - "srcf = caustics.light.Sersic(name=\"src\", n=torch.tensor(2.0))\n", + "srcf = caustics.Sersic(name=\"src\", n=torch.tensor(2.0))\n", "\n", "simf = Simple_Sim(sief, srcf, z_s=torch.tensor(0.8))\n", "\n", @@ -223,7 +223,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b029159", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e244b74", + "id": "11", "metadata": {}, "outputs": [], "source": [] @@ -255,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fb6619ec-9057-4daa-916b-9384b74a5e29", + "id": "12", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/VisualizeCaustics.ipynb b/docs/source/tutorials/VisualizeCaustics.ipynb index 206fb6a1..51345450 100644 --- a/docs/source/tutorials/VisualizeCaustics.ipynb +++ b/docs/source/tutorials/VisualizeCaustics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "3bb2cd70-e98d-4e8d-b195-64d7e5de6e13", + "id": "0", "metadata": {}, "source": [ "# Visualize Caustics\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f716feef", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -36,15 +36,15 @@ { "cell_type": "code", "execution_count": null, - "id": "bdede2df", + "id": "2", "metadata": {}, "outputs": [], "source": [ "# initialization stuff for an SIE lens\n", "\n", - "cosmology = caustics.cosmology.FlatLambdaCDM(name=\"cosmo\")\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", 
"cosmology.to(dtype=torch.float32)\n", - "sie = caustics.lenses.SIE(cosmology, name=\"sie\")\n", + "sie = caustics.SIE(cosmology, name=\"sie\")\n", "n_pix = 100\n", "res = 0.05\n", "upsample_factor = 2\n", @@ -72,7 +72,7 @@ }, { "cell_type": "markdown", - "id": "38dd09e3", + "id": "3", "metadata": {}, "source": [ "## Critical Lines\n", @@ -83,7 +83,7 @@ { "cell_type": "code", "execution_count": null, - "id": "487b3030", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -106,7 +106,7 @@ }, { "cell_type": "markdown", - "id": "3b099cfd", + "id": "5", "metadata": {}, "source": [ "## Caustics\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c1f1177", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -139,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e77da2e", + "id": "7", "metadata": {}, "outputs": [], "source": [] diff --git a/pyproject.toml b/pyproject.toml index 75379562..59cc742e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,10 @@ Issues = "https://github.com/Ciela-Institute/caustics/issues" [project.optional-dependencies] dev = [ - "lenstronomy==1.11.1" + "lenstronomy==1.11.1", + "pytest>=8.0,<9", + "pytest-cov>=4.1,<5", + "pre-commit>=3.6,<4" ] [tool.hatch.metadata.hooks.requirements_txt] @@ -61,3 +64,6 @@ local_scheme = "no-local-version" [tool.ruff] # Same as Black. line-length = 100 + +[tool.pytest.ini_options] +norecursedirs = "tests/utils" diff --git a/requirements.txt b/requirements.txt index 85e4c715..6d29e11e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ astropy>=5.2.1,<6.0.0 graphviz==0.20.1 h5py>=3.8.0 levmarq_torch==0.0.1 +numba>=0.58.1,<0.59.0 numpy>=1.23.5 +safetensors>=0.4.1 scipy>=1.8.0 torch>=2.0.0 diff --git a/src/caustics/__init__.py b/src/caustics/__init__.py index 6af251f4..e43a8a04 100644 --- a/src/caustics/__init__.py +++ b/src/caustics/__init__.py @@ -1,23 +1,65 @@ from ._version import version as VERSION # noqa -from . import constants, lenses, cosmology, packed, parametrized, light, utils, sims +from .cosmology import ( + Cosmology, + FlatLambdaCDM, + h0_default, + critical_density_0_default, + Om0_default, +) +from .lenses import ( + ThinLens, + ThickLens, + EPL, + ExternalShear, + PixelatedConvergence, + Multiplane, + NFW, + Point, + PseudoJaffe, + SIE, + SIS, + SinglePlane, + MassSheet, + TNFW, +) +from .light import Source, Pixelated, Sersic # PROBESDataset conflicts with .data +from .data import HDF5Dataset, IllustrisKappaDataset, PROBESDataset +from . 
import utils +from .sims import Lens_Source, Simulator from .tests import test -# from .demo import * - __version__ = VERSION __author__ = "Ciela" __all__ = [ - # Modules - "constants", - "lenses", - "cosmology", - "packed", - "parametrized", - "light", + "Cosmology", + "FlatLambdaCDM", + "h0_default", + "critical_density_0_default", + "Om0_default", + "ThinLens", + "ThickLens", + "EPL", + "ExternalShear", + "PixelatedConvergence", + "Multiplane", + "NFW", + "Point", + "PseudoJaffe", + "SIE", + "SIS", + "SinglePlane", + "MassSheet", + "TNFW", + "Source", + "Pixelated", + "Sersic", + "HDF5Dataset", + "IllustrisKappaDataset", + "PROBESDataset", "utils", - "sims", - # Functions + "Lens_Source", + "Simulator", "test", ] diff --git a/src/caustics/cosmology/FlatLambdaCDM.py b/src/caustics/cosmology/FlatLambdaCDM.py new file mode 100644 index 00000000..ffc21a68 --- /dev/null +++ b/src/caustics/cosmology/FlatLambdaCDM.py @@ -0,0 +1,201 @@ +# mypy: disable-error-code="operator" +from typing import Optional + +import torch +from torch import Tensor + +from astropy.cosmology import default_cosmology +from scipy.special import hyp2f1 + +from ..utils import interp1d +from ..parametrized import unpack +from ..packed import Packed +from ..constants import c_Mpc_s, km_to_Mpc +from .base import ( + Cosmology, +) + +_h0_default = float(default_cosmology.get().h) +_critical_density_0_default = float( + default_cosmology.get().critical_density(0).to("solMass/Mpc^3").value +) +_Om0_default = float(default_cosmology.get().Om0) + +# Set up interpolator to speed up comoving distance calculations in Lambda-CDM +# cosmologies. Construct with float64 precision. +_comoving_distance_helper_x_grid = 10 ** torch.linspace(-3, 1, 500, dtype=torch.float64) +_comoving_distance_helper_y_grid = torch.as_tensor( + _comoving_distance_helper_x_grid + * hyp2f1(1 / 3, 1 / 2, 4 / 3, -(_comoving_distance_helper_x_grid**3)), + dtype=torch.float64, +) + +h0_default = torch.tensor(_h0_default) +critical_density_0_default = torch.tensor(_critical_density_0_default) +Om0_default = torch.tensor(_Om0_default) + + +class FlatLambdaCDM(Cosmology): + """ + Subclass of Cosmology representing a Flat Lambda Cold Dark Matter (LCDM) + cosmology with no radiation. + """ + + def __init__( + self, + h0: Optional[Tensor] = h0_default, + critical_density_0: Optional[Tensor] = critical_density_0_default, + Om0: Optional[Tensor] = Om0_default, + name: Optional[str] = None, + ): + """ + Initialize a new instance of the FlatLambdaCDM class. + + Parameters + ---------- + name: str + Name of the cosmology. + h0: Optional[Tensor] + Hubble constant over 100. Default is h0_default. + critical_density_0: (Optional[Tensor]) + Critical density at z=0. Default is critical_density_0_default. + Om0: Optional[Tensor] + Matter density parameter at z=0. Default is Om0_default. 
+ """ + super().__init__(name) + + self.add_param("h0", h0) + self.add_param("critical_density_0", critical_density_0) + self.add_param("Om0", Om0) + + self._comoving_distance_helper_x_grid = _comoving_distance_helper_x_grid.to( + dtype=torch.float32 + ) + self._comoving_distance_helper_y_grid = _comoving_distance_helper_y_grid.to( + dtype=torch.float32 + ) + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + super().to(device, dtype) + self._comoving_distance_helper_y_grid = ( + self._comoving_distance_helper_y_grid.to(device, dtype) + ) + self._comoving_distance_helper_x_grid = ( + self._comoving_distance_helper_x_grid.to(device, dtype) + ) + + def hubble_distance(self, h0): + """ + Calculate the Hubble distance. + + Parameters + ---------- + h0: Tensor + Hubble constant. + + Returns + ------- + Tensor + Hubble distance. + """ + return c_Mpc_s / (100 * km_to_Mpc) / h0 + + @unpack + def critical_density( + self, + z: Tensor, + *args, + params: Optional["Packed"] = None, + h0: Optional[Tensor] = None, + critical_density_0: Optional[Tensor] = None, + Om0: Optional[Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Calculate the critical density at redshift z. + + Parameters + ---------- + z: Tensor + Redshift. + params: (Packed, optional) + Dynamic parameter container for the computation. + + Returns + ------- + torch.Tensor + Critical density at redshift z. + """ + Ode0 = 1 - Om0 + return critical_density_0 * (Om0 * (1 + z) ** 3 + Ode0) # fmt: skip + + @unpack + def _comoving_distance_helper( + self, x: Tensor, *args, params: Optional["Packed"] = None, **kwargs + ) -> Tensor: + """ + Helper method for computing comoving distances. + + Parameters + ---------- + x: Tensor + Input tensor. + + Returns + ------- + Tensor + Computed comoving distances. + """ + return interp1d( + self._comoving_distance_helper_x_grid, + self._comoving_distance_helper_y_grid, + torch.atleast_1d(x), + ).reshape(x.shape) + + @unpack + def comoving_distance( + self, + z: Tensor, + *args, + params: Optional["Packed"] = None, + h0: Optional[Tensor] = None, + critical_density_0: Optional[Tensor] = None, + Om0: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + """ + Calculate the comoving distance to redshift z. + + Parameters + ---------- + z: Tensor + Redshift. + params: (Packed, optional) + Dynamic parameter container for the computation. + + Returns + ------- + Tensor + Comoving distance to redshift z. 
+ """ + Ode0 = 1 - Om0 + ratio = (Om0 / Ode0) ** (1 / 3) + DH = self.hubble_distance(h0) + DC1z = self._comoving_distance_helper((1 + z) * ratio, params) + DC = self._comoving_distance_helper(ratio, params) + return DH * (DC1z - DC) / (Om0 ** (1 / 3) * Ode0 ** (1 / 6)) # fmt: skip + + @unpack + def transverse_comoving_distance( + self, + z: Tensor, + *args, + params: Optional["Packed"] = None, + h0: Optional[Tensor] = None, + critical_density_0: Optional[Tensor] = None, + Om0: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + return self.comoving_distance(z, params, **kwargs) diff --git a/src/caustics/cosmology/__init__.py b/src/caustics/cosmology/__init__.py new file mode 100644 index 00000000..c5146d0d --- /dev/null +++ b/src/caustics/cosmology/__init__.py @@ -0,0 +1,15 @@ +from .base import Cosmology +from .FlatLambdaCDM import ( + FlatLambdaCDM, + h0_default, + critical_density_0_default, + Om0_default, +) + +__all__ = [ + "Cosmology", + "FlatLambdaCDM", + "h0_default", + "critical_density_0_default", + "Om0_default", +] diff --git a/src/caustics/cosmology.py b/src/caustics/cosmology/base.py similarity index 54% rename from src/caustics/cosmology.py rename to src/caustics/cosmology/base.py index 51cee50f..31528f1d 100644 --- a/src/caustics/cosmology.py +++ b/src/caustics/cosmology/base.py @@ -1,39 +1,13 @@ +# mypy: disable-error-code="operator" from abc import abstractmethod from math import pi from typing import Optional -import torch -from astropy.cosmology import default_cosmology -from scipy.special import hyp2f1 from torch import Tensor -from .utils import interp1d -from .constants import G_over_c2, c_Mpc_s, km_to_Mpc -from .parametrized import Parametrized, unpack -from .packed import Packed - -__all__ = ( - "h0_default", - "critical_density_0_default", - "Om0_default", - "Cosmology", - "FlatLambdaCDM", -) - -h0_default = float(default_cosmology.get().h) -critical_density_0_default = float( - default_cosmology.get().critical_density(0).to("solMass/Mpc^3").value -) -Om0_default = float(default_cosmology.get().Om0) - -# Set up interpolator to speed up comoving distance calculations in Lambda-CDM -# cosmologies. Construct with float64 precision. -_comoving_distance_helper_x_grid = 10 ** torch.linspace(-3, 1, 500, dtype=torch.float64) -_comoving_distance_helper_y_grid = torch.as_tensor( - _comoving_distance_helper_x_grid - * hyp2f1(1 / 3, 1 / 2, 4 / 3, -(_comoving_distance_helper_x_grid**3)), - dtype=torch.float64, -) +from ..constants import G_over_c2 +from ..parametrized import Parametrized, unpack +from ..packed import Packed class Cosmology(Parametrized): @@ -56,7 +30,7 @@ class Cosmology(Parametrized): Name of the cosmological model. """ - def __init__(self, name: str = None): + def __init__(self, name: Optional[str] = None): """ Initialize the Cosmology. @@ -283,169 +257,3 @@ def critical_surface_density( d_s = self.angular_diameter_distance(z_s, params) d_ls = self.angular_diameter_distance_z1z2(z_l, z_s, params) return d_s / (4 * pi * G_over_c2 * d_l * d_ls) # fmt: skip - - -class FlatLambdaCDM(Cosmology): - """ - Subclass of Cosmology representing a Flat Lambda Cold Dark Matter (LCDM) - cosmology with no radiation. - """ - - def __init__( - self, - h0: Optional[Tensor] = torch.tensor(h0_default), - critical_density_0: Optional[Tensor] = torch.tensor(critical_density_0_default), - Om0: Optional[Tensor] = torch.tensor(Om0_default), - name: str = None, - ): - """ - Initialize a new instance of the FlatLambdaCDM class. 
- - Parameters - ---------- - name: str - Name of the cosmology. - h0: Optional[Tensor] - Hubble constant over 100. Default is h0_default. - critical_density_0: (Optional[Tensor]) - Critical density at z=0. Default is critical_density_0_default. - Om0: Optional[Tensor] - Matter density parameter at z=0. Default is Om0_default. - """ - super().__init__(name) - - self.add_param("h0", h0) - self.add_param("critical_density_0", critical_density_0) - self.add_param("Om0", Om0) - - self._comoving_distance_helper_x_grid = _comoving_distance_helper_x_grid.to( - dtype=torch.float32 - ) - self._comoving_distance_helper_y_grid = _comoving_distance_helper_y_grid.to( - dtype=torch.float32 - ) - - def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None - ): - super().to(device, dtype) - self._comoving_distance_helper_y_grid = ( - self._comoving_distance_helper_y_grid.to(device, dtype) - ) - self._comoving_distance_helper_x_grid = ( - self._comoving_distance_helper_x_grid.to(device, dtype) - ) - - def hubble_distance(self, h0): - """ - Calculate the Hubble distance. - - Parameters - ---------- - h0: Tensor - Hubble constant. - - Returns - ------- - Tensor - Hubble distance. - """ - return c_Mpc_s / (100 * km_to_Mpc) / h0 - - @unpack - def critical_density( - self, - z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Tensor = None, - critical_density_0: Tensor = None, - Om0: Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - Calculate the critical density at redshift z. - - Parameters - ---------- - z: Tensor - Redshift. - params: (Packed, optional) - Dynamic parameter container for the computation. - - Returns - ------- - torch.Tensor - Critical density at redshift z. - """ - Ode0 = 1 - Om0 - return critical_density_0 * (Om0 * (1 + z) ** 3 + Ode0) # fmt: skip - - @unpack - def _comoving_distance_helper( - self, x: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: - """ - Helper method for computing comoving distances. - - Parameters - ---------- - x: Tensor - Input tensor. - - Returns - ------- - Tensor - Computed comoving distances. - """ - return interp1d( - self._comoving_distance_helper_x_grid, - self._comoving_distance_helper_y_grid, - torch.atleast_1d(x), - ).reshape(x.shape) - - @unpack - def comoving_distance( - self, - z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Tensor = None, - critical_density_0: Tensor = None, - Om0: Tensor = None, - **kwargs, - ) -> Tensor: - """ - Calculate the comoving distance to redshift z. - - Parameters - ---------- - z: Tensor - Redshift. - params: (Packed, optional) - Dynamic parameter container for the computation. - - Returns - ------- - Tensor - Comoving distance to redshift z. 
- """ - Ode0 = 1 - Om0 - ratio = (Om0 / Ode0) ** (1 / 3) - DH = self.hubble_distance(h0) - DC1z = self._comoving_distance_helper((1 + z) * ratio, params) - DC = self._comoving_distance_helper(ratio, params) - return DH * (DC1z - DC) / (Om0 ** (1 / 3) * Ode0 ** (1 / 6)) # fmt: skip - - @unpack - def transverse_comoving_distance( - self, - z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Tensor = None, - critical_density_0: Tensor = None, - Om0: Tensor = None, - **kwargs, - ) -> Tensor: - return self.comoving_distance(z, params, **kwargs) diff --git a/src/caustics/io.py b/src/caustics/io.py new file mode 100644 index 00000000..ddb517c1 --- /dev/null +++ b/src/caustics/io.py @@ -0,0 +1,126 @@ +from pathlib import Path +import json +import struct + +DEFAULT_ENCODING = "utf-8" +SAFETENSORS_METADATA = "__metadata__" + + +def _normalize_path(path: "str | Path") -> Path: + # Convert string path to Path object + if isinstance(path, str): + path = Path(path) + + # Get absolute path + return path.absolute() + + +def to_file( + path: "str | Path", data: "str | bytes", encoding: str = DEFAULT_ENCODING +) -> str: + """ + Save data string or bytes to specified file path + + Parameters + ---------- + path : str or Path + The path to save the data to + data : str | bytes + The data string or bytes to save to file + encoding : str, optional + The string encoding to use, by default "utf-8" + + Returns + ------- + str + The path string where the data is saved + """ + # TODO: Update to allow for remote paths saving + + # Convert string data to bytes + if isinstance(data, str): + data = data.encode(encoding) + + # Normalize path to pathlib.Path object + path = _normalize_path(path) + + with open(path, "wb") as f: + f.write(data) + + return str(path.absolute()) + + +def from_file(path: "str | Path") -> bytes: + """ + Load data from specified file path + + Parameters + ---------- + path : str or Path + The path to load the data from + + Returns + ------- + bytes + The data bytes loaded from the file + """ + # TODO: Update to allow for remote paths loading + + # Normalize path to pathlib.Path object + path = _normalize_path(path) + + return path.read_bytes() + + +def _get_safetensors_header(path: "str | Path") -> dict: + """ + Read specified file header to a dictionary + + Parameters + ---------- + path : str or Path + The path to get header from + + Returns + ------- + dict + The header dictionary + """ + # TODO: Update to allow for remote paths loading of header + + # Normalize path to pathlib.Path object + path = _normalize_path(path) + + # Doing this avoids reading the whole safetensors + # file in case that it's large + with open(path, "rb") as f: + # Get the size of the header by reading first 8 bytes + (length_of_header,) = struct.unpack(" dict: + """ + Get the metadata from the specified file path + + Parameters + ---------- + path : str or Path + The path to get the metadata from + + Returns + ------- + dict + The metadata dictionary + """ + header = _get_safetensors_header(path) + + # Only return the metadata + # if it's not even there, just return blank dict + return header.get(SAFETENSORS_METADATA, {}) diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index 9232e17e..9dced0dd 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="call-overload" from abc import abstractmethod from typing import Optional, Union from functools import partial @@ -21,7 +22,7 @@ class Lens(Parametrized): Base class for all 
lenses """ - def __init__(self, cosmology: Cosmology, name: str = None): + def __init__(self, cosmology: Cosmology, name: Optional[str] = None): """ Initializes a new instance of the Lens class. @@ -99,9 +100,7 @@ def magnification( Tensor Gravitational magnification at the given coordinates. """ - return get_magnification( - partial(self.raytrace, params=params), x, y, z_s, **kwargs - ) + return get_magnification(partial(self.raytrace, params=params), x, y, z_s) @unpack def forward_raytrace( @@ -517,7 +516,7 @@ def _jacobian_lens_equation_finitediff( J = self._jacobian_effective_deflection_angle_finitediff( x, y, z_s, pixelscale, params, **kwargs ) - return torch.eye(2) - J + return torch.eye(2).to(J.device) - J @unpack def _jacobian_lens_equation_autograd( @@ -537,7 +536,7 @@ def _jacobian_lens_equation_autograd( J = self._jacobian_effective_deflection_angle_autograd( x, y, z_s, params, **kwargs ) - return torch.eye(2) - J.detach() + return torch.eye(2).to(J.device) - J.detach() @unpack def effective_convergence_div( @@ -609,7 +608,7 @@ def __init__( self, cosmology: Cosmology, z_l: Optional[Union[Tensor, float]] = None, - name: str = None, + name: Optional[str] = None, ): super().__init__(cosmology=cosmology, name=name) self.add_param("z_l", z_l) @@ -622,7 +621,7 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, + z_l: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -662,7 +661,7 @@ def physical_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, + z_l: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -766,7 +765,7 @@ def surface_density( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, + z_l: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -844,7 +843,7 @@ def time_delay( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, + z_l: Optional[Tensor] = None, shapiro_time_delay: bool = True, geometric_time_delay: bool = True, **kwargs, @@ -1017,7 +1016,7 @@ def _jacobian_lens_equation_finitediff( J = self._jacobian_deflection_angle_finitediff( x, y, z_s, pixelscale, params, **kwargs ) - return torch.eye(2) - J + return torch.eye(2).to(J.device) - J @unpack def _jacobian_lens_equation_autograd( @@ -1035,4 +1034,4 @@ def _jacobian_lens_equation_autograd( """ # Build Jacobian J = self._jacobian_deflection_angle_autograd(x, y, z_s, params, **kwargs) - return torch.eye(2) - J.detach() + return torch.eye(2).to(J.device) - J.detach() diff --git a/src/caustics/lenses/epl.py b/src/caustics/lenses/epl.py index ab177b54..d4b09577 100644 --- a/src/caustics/lenses/epl.py +++ b/src/caustics/lenses/epl.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" from typing import Optional, Union import torch @@ -80,7 +81,7 @@ def __init__( t: Optional[Union[Tensor, float]] = None, s: float = 0.0, n_iter: int = 18, - name: str = None, + name: Optional[str] = None, ): """ Initialize an EPL lens model. 
@@ -137,13 +138,13 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, - t: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, + t: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -220,13 +221,13 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, - t: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, + t: Optional[Tensor] = None, **kwargs, ): """ @@ -251,7 +252,7 @@ def potential( ax, ay = self.reduced_deflection_angle(x, y, z_s, params) ax, ay = derotate(ax, ay, -phi) x, y = translate_rotate(x, y, x0, y0, phi) - return (x * ax + y * ay) / (2 - t) + return (x * ax + y * ay) / (2 - t) # fmt: skip @unpack def convergence( @@ -261,15 +262,15 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, - t: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, + t: Optional[Tensor] = None, **kwargs, - ): + ) -> Tensor: """ Compute the convergence of the lens, which describes the local density of the lens. 
diff --git a/src/caustics/lenses/external_shear.py b/src/caustics/lenses/external_shear.py index 6aee98a7..3f4c226a 100644 --- a/src/caustics/lenses/external_shear.py +++ b/src/caustics/lenses/external_shear.py @@ -50,7 +50,7 @@ def __init__( gamma_1: Optional[Union[Tensor, float]] = None, gamma_2: Optional[Union[Tensor, float]] = None, s: float = 0.0, - name: str = None, + name: Optional[str] = None, ): super().__init__(cosmology, z_l, name=name) @@ -68,11 +68,11 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - gamma_1: Tensor = None, - gamma_2: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + gamma_1: Optional[Tensor] = None, + gamma_2: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -113,11 +113,11 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - gamma_1: Tensor = None, - gamma_2: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + gamma_1: Optional[Tensor] = None, + gamma_2: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -151,11 +151,11 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - gamma_1: Tensor = None, - gamma_2: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + gamma_1: Optional[Tensor] = None, + gamma_2: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/mass_sheet.py b/src/caustics/lenses/mass_sheet.py index d51c5a0c..612e5e2d 100644 --- a/src/caustics/lenses/mass_sheet.py +++ b/src/caustics/lenses/mass_sheet.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" from typing import Optional, Union import torch @@ -48,7 +49,7 @@ def __init__( x0: Optional[Union[Tensor, float]] = None, y0: Optional[Union[Tensor, float]] = None, surface_density: Optional[Union[Tensor, float]] = None, - name: str = None, + name: Optional[str] = None, ): super().__init__(cosmology, z_l, name=name) @@ -64,10 +65,10 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - surface_density: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + surface_density: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -103,10 +104,10 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - surface_density: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + surface_density: Optional[Tensor] = None, **kwargs, ) -> Tensor: # Meneghetti eq 3.81 @@ -120,10 +121,10 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - surface_density: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + surface_density: Optional[Tensor] = None, **kwargs, ) -> Tensor: # Essentially by definition diff --git a/src/caustics/lenses/multiplane.py b/src/caustics/lenses/multiplane.py index 8ae1b641..90b6f812 100644 --- a/src/caustics/lenses/multiplane.py +++ 
b/src/caustics/lenses/multiplane.py @@ -32,7 +32,9 @@ class Multiplane(ThickLens): List of thin lenses. """ - def __init__(self, cosmology: Cosmology, lenses: list[ThinLens], name: str = None): + def __init__( + self, cosmology: Cosmology, lenses: list[ThinLens], name: Optional[str] = None + ): super().__init__(cosmology, name=name) self.lenses = lenses for lens in lenses: @@ -118,9 +120,7 @@ def _raytrace_helper( ) TD += (-tau_ij * beta_ij * arcsec_to_rad**2) * potential if geometric_time_delay: - TD += (tau_ij * arcsec_to_rad**2 * 0.5) * ( - alpha_x**2 + alpha_y**2 - ) + TD += (tau_ij * arcsec_to_rad**2 * 0.5) * (alpha_x**2 + alpha_y**2) # Propagate rays to next plane (basically eq 18) X = X + D * theta_x * arcsec_to_rad diff --git a/src/caustics/lenses/nfw.py b/src/caustics/lenses/nfw.py index cfe68f18..148ac1ae 100644 --- a/src/caustics/lenses/nfw.py +++ b/src/caustics/lenses/nfw.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator,union-attr" from math import pi from typing import Optional, Union @@ -82,7 +83,7 @@ def __init__( c: Optional[Union[Tensor, float]] = None, s: float = 0.0, use_case="batchable", - name: str = None, + name: Optional[str] = None, ): """ Initialize an instance of the NFW lens class. @@ -133,11 +134,11 @@ def get_scale_radius( self, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -168,11 +169,11 @@ def get_scale_density( self, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -201,11 +202,11 @@ def get_convergence_s( z_s, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -388,11 +389,11 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -437,11 +438,11 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -480,11 +481,11 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - m: Tensor = None, - c: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + m: Optional[Tensor] = None, + c: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git 
a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index 10cda750..af2b84f5 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="index" from math import pi from typing import Optional @@ -36,7 +37,7 @@ def __init__( convolution_mode: str = "fft", use_next_fast_len: bool = True, padding: str = "zero", - name: str = None, + name: Optional[str] = None, ): """Strong lensing with user provided kappa map @@ -260,10 +261,10 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - convergence_map: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + convergence_map: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -367,10 +368,10 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - convergence_map: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + convergence_map: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -453,10 +454,10 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - convergence_map: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + convergence_map: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/point.py b/src/caustics/lenses/point.py index b640f84d..619a1482 100644 --- a/src/caustics/lenses/point.py +++ b/src/caustics/lenses/point.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" from typing import Optional, Union import torch @@ -48,7 +49,7 @@ def __init__( y0: Optional[Union[Tensor, float]] = None, th_ein: Optional[Union[Tensor, float]] = None, s: float = 0.0, - name: str = None, + name: Optional[str] = None, ): """ Initialize the Point class. 
@@ -85,10 +86,10 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -124,10 +125,10 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -161,10 +162,10 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/pseudo_jaffe.py b/src/caustics/lenses/pseudo_jaffe.py index 671333bf..ebf1cb3b 100644 --- a/src/caustics/lenses/pseudo_jaffe.py +++ b/src/caustics/lenses/pseudo_jaffe.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" from math import pi from typing import Optional, Union @@ -60,7 +61,7 @@ def __init__( core_radius: Optional[Union[Tensor, float]] = None, scale_radius: Optional[Union[Tensor, float]] = None, s: float = 0.0, - name: str = None, + name: Optional[str] = None, ): """ Initialize the PseudoJaffe class. @@ -101,12 +102,12 @@ def get_convergence_0( z_s, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - core_radius: Tensor = None, - scale_radius: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + core_radius: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, **kwargs, ): d_l = self.cosmology.angular_diameter_distance(z_l, params) @@ -120,12 +121,12 @@ def mass_enclosed_2d( z_s, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - core_radius: Tensor = None, - scale_radius: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + core_radius: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, **kwargs, ): """ @@ -199,12 +200,12 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - core_radius: Tensor = None, - scale_radius: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + core_radius: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """Calculate the deflection angle. 
@@ -241,12 +242,12 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - core_radius: Tensor = None, - scale_radius: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + core_radius: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -298,12 +299,12 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - core_radius: Tensor = None, - scale_radius: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + core_radius: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/sie.py b/src/caustics/lenses/sie.py index 7ea2a5e7..85f6dde9 100644 --- a/src/caustics/lenses/sie.py +++ b/src/caustics/lenses/sie.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator,union-attr" from typing import Optional, Union from torch import Tensor @@ -56,7 +57,7 @@ def __init__( phi: Optional[Union[Tensor, float]] = None, b: Optional[Union[Tensor, float]] = None, s: float = 0.0, - name: str = None, + name: Optional[str] = None, ): """ Initialize the SIE lens model. @@ -98,12 +99,12 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -141,12 +142,12 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - x0: Tensor = None, - z_l: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, + x0: Optional[Tensor] = None, + z_l: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -181,12 +182,12 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - b: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + b: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/singleplane.py b/src/caustics/lenses/singleplane.py index 9644f7d2..f9b858c1 100644 --- a/src/caustics/lenses/singleplane.py +++ b/src/caustics/lenses/singleplane.py @@ -27,7 +27,11 @@ class SinglePlane(ThinLens): """ def __init__( - self, cosmology: Cosmology, lenses: list[ThinLens], name: str = None, **kwargs + self, + cosmology: Cosmology, + lenses: list[ThinLens], + name: Optional[str] = None, + **kwargs, ): """ Initialize the SinglePlane lens model. 
diff --git a/src/caustics/lenses/sis.py b/src/caustics/lenses/sis.py index 15a54172..df4f4d76 100644 --- a/src/caustics/lenses/sis.py +++ b/src/caustics/lenses/sis.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" from typing import Optional, Union from torch import Tensor @@ -47,7 +48,7 @@ def __init__( y0: Optional[Union[Tensor, float]] = None, th_ein: Optional[Union[Tensor, float]] = None, s: float = 0.0, - name: str = None, + name: Optional[str] = None, ): """ Initialize the SIS lens model. @@ -67,10 +68,10 @@ def reduced_deflection_angle( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -106,10 +107,10 @@ def potential( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -143,10 +144,10 @@ def convergence( z_s: Tensor, *args, params: Optional["Packed"] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - th_ein: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + th_ein: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/lenses/tnfw.py b/src/caustics/lenses/tnfw.py index 3c9478d6..cb613687 100644 --- a/src/caustics/lenses/tnfw.py +++ b/src/caustics/lenses/tnfw.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator,union-attr" from math import pi from typing import Optional, Union @@ -95,7 +96,7 @@ def __init__( s: float = 0.0, interpret_m_total_mass: bool = True, use_case="batchable", - name: str = None, + name: Optional[str] = None, ): """ Initialize an instance of the TNFW lens class. 
@@ -150,12 +151,12 @@ def get_concentration( self, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -194,12 +195,12 @@ def get_truncation_radius( self, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -234,12 +235,12 @@ def get_M0( self, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -280,12 +281,12 @@ def get_scale_density( self, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -324,12 +325,12 @@ def convergence( z_s: Tensor, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -384,12 +385,12 @@ def mass_enclosed_2d( z_s: Tensor, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ @@ -437,12 +438,12 @@ def physical_deflection_angle( z_s: Tensor, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """Compute the physical deflection angle (arcsec) for this lens at @@ -493,12 +494,12 @@ def potential( z_s: Tensor, *args, params: Optional[Packed] = None, - z_l: Tensor = None, - x0: Tensor = None, - y0: Tensor = None, - mass: Tensor = None, - scale_radius: Tensor = None, - tau: Tensor = None, + z_l: Optional[Tensor] = 
None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + mass: Optional[Tensor] = None, + scale_radius: Optional[Tensor] = None, + tau: Optional[Tensor] = None, **kwargs, ) -> Tensor: """ diff --git a/src/caustics/light/pixelated.py b/src/caustics/light/pixelated.py index c445c4d0..7ea08408 100644 --- a/src/caustics/light/pixelated.py +++ b/src/caustics/light/pixelated.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr" from typing import Optional, Union from torch import Tensor @@ -42,7 +43,7 @@ def __init__( y0: Optional[Union[Tensor, float]] = None, pixelscale: Optional[Union[Tensor, float]] = None, shape: Optional[tuple[int, ...]] = None, - name: str = None, + name: Optional[str] = None, ): """ Constructs the `Pixelated` object with the given parameters. @@ -83,10 +84,10 @@ def brightness( y, *args, params: Optional["Packed"] = None, - x0: Tensor = None, - y0: Tensor = None, - image: Tensor = None, - pixelscale: Tensor = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + image: Optional[Tensor] = None, + pixelscale: Optional[Tensor] = None, **kwargs, ): """ diff --git a/src/caustics/light/sersic.py b/src/caustics/light/sersic.py index 11bc5e05..a5b8067c 100644 --- a/src/caustics/light/sersic.py +++ b/src/caustics/light/sersic.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator,union-attr" from typing import Optional, Union from torch import Tensor @@ -53,7 +54,7 @@ def __init__( Ie: Optional[Union[Tensor, float]] = None, s: float = 0.0, use_lenstronomy_k=False, - name: str = None, + name: Optional[str] = None, ): """ Constructs the `Sersic` object with the given parameters. @@ -100,13 +101,13 @@ def brightness( y, *args, params: Optional["Packed"] = None, - x0: Tensor = None, - y0: Tensor = None, - q: Tensor = None, - phi: Tensor = None, - n: Tensor = None, - Re: Tensor = None, - Ie: Tensor = None, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + q: Optional[Tensor] = None, + phi: Optional[Tensor] = None, + n: Optional[Tensor] = None, + Re: Optional[Tensor] = None, + Ie: Optional[Tensor] = None, **kwargs, ): """ diff --git a/src/caustics/parameter.py b/src/caustics/parameter.py index 933de529..f86c874b 100644 --- a/src/caustics/parameter.py +++ b/src/caustics/parameter.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr" from typing import Optional, Union import torch diff --git a/src/caustics/parametrized.py b/src/caustics/parametrized.py index 5f9fab8f..6f00ccc8 100644 --- a/src/caustics/parametrized.py +++ b/src/caustics/parametrized.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="var-annotated,index,type-arg" from collections import OrderedDict from math import prod from typing import Optional, Union @@ -59,7 +60,7 @@ class Parametrized: Number of static parameters. 
""" - def __init__(self, name: str = None): + def __init__(self, name: Optional[str] = None): if name is None: name = self._default_name() check_valid_name(name) diff --git a/src/caustics/sims/simulator.py b/src/caustics/sims/simulator.py index c005f93e..b12e1457 100644 --- a/src/caustics/sims/simulator.py +++ b/src/caustics/sims/simulator.py @@ -1,4 +1,9 @@ +from typing import Dict +from torch import Tensor + from ..parametrized import Parametrized +from .state_dict import StateDict +from ..namespace_dict import NestedNamespaceDict __all__ = ("Simulator",) @@ -25,3 +30,68 @@ def __call__(self, *args, **kwargs): rest_args = tuple() return self.forward(packed_args, *rest_args, **kwargs) + + @staticmethod + def __set_module_params(module: Parametrized, params: Dict[str, Tensor]): + for k, v in params.items(): + setattr(module, k, v) + + def state_dict(self) -> StateDict: + return StateDict.from_params(self.params) + + def load_state_dict(self, file_path: str) -> "Simulator": + """ + Loads and then sets the state of the simulator from a file + + Parameters + ---------- + file_path : str | Path + The file path to a safetensors file + to load the state from + + Returns + ------- + Simulator + The simulator with the loaded state + """ + loaded_state_dict = StateDict.load(file_path) + self.set_state_dict(loaded_state_dict) + return self + + def set_state_dict(self, state_dict: StateDict) -> "Simulator": + """ + Sets the state of the simulator from a state dict + + Parameters + ---------- + state_dict : StateDict + The state dict to load from + + Returns + ------- + Simulator + The simulator with the loaded state + """ + # TODO: Do some checks for the state dict metadata + + # Convert to nested namespace dict + param_dicts = NestedNamespaceDict(state_dict) + + # Grab params for the current module + self_params = param_dicts.pop(self.name) + + def _set_params(module): + # Start from root, and move down the DAG + if module.name in param_dicts: + module_params = param_dicts[module.name] + self.__set_module_params(module, module_params) + if module._childs != {}: + for child in module._childs.values(): + _set_params(child) + + # Set the parameters of the current module + self.__set_module_params(self, self_params) + + # Set the parameters of the children modules + _set_params(self) + return self diff --git a/src/caustics/sims/state_dict.py b/src/caustics/sims/state_dict.py new file mode 100644 index 00000000..e8ca97c6 --- /dev/null +++ b/src/caustics/sims/state_dict.py @@ -0,0 +1,310 @@ +from datetime import datetime as dt +from collections import OrderedDict +from typing import Any, Dict, Optional +from pathlib import Path + +from torch import Tensor +import torch +from .._version import __version__ +from ..namespace_dict import NamespaceDict, NestedNamespaceDict +from .. import io + +from safetensors.torch import save, load_file + +IMMUTABLE_ERR = TypeError("'StateDict' cannot be modified after creation.") +PARAM_KEYS = ["dynamic", "static"] + + +def _sanitize(tensors_dict: Dict[str, Optional[Tensor]]) -> Dict[str, Tensor]: + """ + Sanitize the input dictionary of tensors by + replacing Nones with tensors of size 0. + + Parameters + ---------- + tensors_dict : dict + A dictionary of tensors, including None. + + Returns + ------- + dict + A dictionary of tensors, with empty tensors + replaced by tensors of size 0. 
+ """ + return { + k: v if isinstance(v, Tensor) else torch.ones(0) + for k, v in tensors_dict.items() + } + + +def _merge_and_flatten(params: "NamespaceDict | NestedNamespaceDict") -> NamespaceDict: + """ + Extract the parameters from a nested dictionary + of parameters and merge them into a single + dictionary of parameters. + + Parameters + ---------- + params : NamespaceDict | NestedNamespaceDict + The nested dictionary of parameters + that includes both "static" and "dynamic". + + Returns + ------- + NamespaceDict + The merged dictionary of parameters. + + Raises + ------ + TypeError + If the input ``params`` is not a + ``NamespaceDict`` or ``NestedNamespaceDict``. + ValueError + If the input ``params`` is a ``NestedNamespaceDict`` + but does not have the keys ``"static"`` and ``"dynamic"``. + """ + if not isinstance(params, (NamespaceDict, NestedNamespaceDict)): + raise TypeError("params must be a NamespaceDict or NestedNamespaceDict") + + if isinstance(params, NestedNamespaceDict): + # In this case, params is the full parameters + # with both "static" and "dynamic" keys + if sorted(params.keys()) != PARAM_KEYS: + raise ValueError(f"params must have keys {PARAM_KEYS}") + + # Extract the "static" and "dynamic" parameters + param_dicts = list(params.values()) + + # Merge the "static" and "dynamic" dictionaries + # to a single merged dictionary + final_dict = NestedNamespaceDict() + for pdict in param_dicts: + for k, v in pdict.items(): + if k not in final_dict: + final_dict[k] = v + else: + final_dict[k] = {**final_dict[k], **v} + + # Flatten the dictionary to a single level + params = final_dict.flatten() + return params + + +def _get_param_values(flat_params: "NamespaceDict") -> Dict[str, Optional[Tensor]]: + """ + Get the values of the parameters from a + flattened dictionary of parameters. + + Parameters + ---------- + flat_params : NamespaceDict + A flattened dictionary of parameters. + + Returns + ------- + Dict[str, Optional[Tensor]] + A dictionary of parameter values, + these values can be a tensor or None. + """ + return {k: v.value for k, v in flat_params.items()} + + +def _extract_tensors_dict( + params: "NamespaceDict | NestedNamespaceDict", +) -> Dict[str, Optional[Tensor]]: + """ + Extract the tensors from a nested dictionary + of parameters and merge them into a single + dictionary of parameters. Then return a + dictionary of tensors by getting the parameter + tensor values. + + Parameters + ---------- + params : NamespaceDict | NestedNamespaceDict + The nested dictionary of parameters + that includes both "static" and "dynamic". + + Returns + ------- + dict + A dictionary of tensors. + """ + all_params = _merge_and_flatten(params) + return _get_param_values(all_params) + + +class ImmutableODict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._created = True + + def __delitem__(self, _) -> None: + raise IMMUTABLE_ERR + + def __setitem__(self, key: str, value: Any) -> None: + if hasattr(self, "_created"): + raise IMMUTABLE_ERR + super().__setitem__(key, value) + + def __setattr__(self, name, value) -> None: + if hasattr(self, "_created"): + raise IMMUTABLE_ERR + return super().__setattr__(name, value) + + +class StateDict(ImmutableODict): + """A dictionary object that is immutable after creation. + This is used to store the parameters of a simulator at a given + point in time.
+ + Methods + ------- + to_params() + Convert the state dict to a dictionary of parameters. + """ + + __slots__ = ("_metadata", "_created", "_created_time") + + def __init__(self, metadata=None, *args, **kwargs): + # Get created time + self._created_time = dt.utcnow() + # Create metadata + _meta = { + "software_version": __version__, + "created_time": self._created_time.isoformat(), + } + if metadata: + _meta.update(metadata) + + # Set metadata + self._metadata = ImmutableODict(_meta) + + # Now create the object, this will set _created + # to True, and prevent any further modification + super().__init__(*args, **kwargs) + + def __delitem__(self, _) -> None: + raise IMMUTABLE_ERR + + def __setitem__(self, key: str, value: Any) -> None: + if hasattr(self, "_created"): + raise IMMUTABLE_ERR + super().__setitem__(key, value) + + @classmethod + def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"): + """Class method to create a StateDict + from a dictionary of parameters + + Parameters + ---------- + params : NamespaceDict + A dictionary of parameters, + can either be the full parameters + that are "static" and "dynamic", + or "static" only. + + Returns + ------- + StateDict + A state dictionary object + """ + tensors_dict = _extract_tensors_dict(params) + return cls(**tensors_dict) + + def to_params(self) -> NestedNamespaceDict: + """ + Convert the state dict to + a nested dictionary of parameters. + + Returns + ------- + NestedNamespaceDict + A nested dictionary of parameters. + """ + from ..parameter import Parameter + + params = NamespaceDict() + for k, v in self.items(): + if v.nelement() == 0: + # Set to None if the tensor is empty + v = None + params[k] = Parameter(v) + return NestedNamespaceDict(params) + + def save(self, file_path: "str | Path | None" = None) -> str: + """ + Saves the state dictionary to an optional + ``file_path`` as safetensors format. + If ``file_path`` is not given, + this will default to a file in + the current working directory. + + *Note: The path specified must + have a '.st' extension.* + + Parameters + ---------- + file_path : str, optional + The file path to save the + state dictionary to, by default None + + Returns + ------- + str + The final path of the saved file + """ + input_path: Path + + if not file_path: + input_path = Path.cwd() / self.__st_file + elif isinstance(file_path, str): + input_path = Path(file_path) + else: + input_path = file_path + + ext = ".st" + if input_path.suffix != ext: + raise ValueError(f"File must have '{ext}' extension") + + return io.to_file(input_path, self._to_safetensors()) + + @classmethod + def load(cls, file_path: str) -> "StateDict": + """ + Loads the state dictionary from a + specified ``file_path``. + + Parameters + ---------- + file_path : str + The file path to load the + state dictionary from. 
+ + Returns + ------- + StateDict + The loaded state dictionary + """ + # TODO: Need to rethink this for remote paths + + # Load just the metadata + metadata = io.get_safetensors_metadata(file_path) + + # Load the full data to cpu first + st_dict = load_file(file_path) + st_dict = {k: v if v.nelement() > 0 else None for k, v in st_dict.items()} + return cls(metadata=metadata, **st_dict) + + @property + def __st_file(self) -> str: + file_format = "%Y%m%dT%H%M%S_caustics.st" + return self._created_time.strftime(file_format) + + def _to_safetensors(self) -> bytes: + return save(_sanitize(self), metadata=self._metadata) diff --git a/src/caustics/utils.py b/src/caustics/utils.py index 42c04d1b..e16775f0 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="misc" from math import pi from typing import Callable, Optional, Tuple, Union from functools import partial, lru_cache diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..5c764be4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Add the helpers directory to the path so we can import the helpers +sys.path.append(os.path.join(os.path.dirname(__file__), "utils")) diff --git a/tests/sims/conftest.py b/tests/sims/conftest.py new file mode 100644 index 00000000..7f45cce7 --- /dev/null +++ b/tests/sims/conftest.py @@ -0,0 +1,40 @@ +import pytest + +from caustics.sims.simulator import Simulator +from caustics.lenses import EPL +from caustics.light import Sersic +from caustics.cosmology import FlatLambdaCDM + + +@pytest.fixture +def test_epl_values(): + return { + "z_l": 0.5, + "phi": 0.0, + "b": 1.0, + "t": 1.0, + } + + +@pytest.fixture +def test_sersic_values(): + return { + "q": 0.9, + "phi": 0.3, + "n": 1.0, + } + + +@pytest.fixture +def simple_common_sim(test_epl_values, test_sersic_values): + class Sim(Simulator): + def __init__(self): + super().__init__() + self.cosmo = FlatLambdaCDM(h0=None) + self.epl = EPL(self.cosmo, **test_epl_values) + self.sersic = Sersic(**test_sersic_values) + self.add_param("z_s", 1.0) + + sim = Sim() + yield sim + del sim diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py new file mode 100644 index 00000000..611f54b6 --- /dev/null +++ b/tests/sims/test_simulator.py @@ -0,0 +1,92 @@ +import pytest +from pathlib import Path +import sys + +import torch + +from caustics.sims.state_dict import ( + StateDict, + _extract_tensors_dict, +) + + +@pytest.fixture +def state_dict(simple_common_sim): + return simple_common_sim.state_dict() + + +@pytest.fixture +def expected_tensors(simple_common_sim): + tensors_dict = _extract_tensors_dict(simple_common_sim.params) + return tensors_dict + + +class TestSimulator: + def test_state_dict(self, state_dict, expected_tensors): + # Check state_dict type and default keys + assert isinstance(state_dict, StateDict) + + # Trying to modify state_dict should raise TypeError + with pytest.raises(TypeError): + state_dict["params"] = -1 + + # Check _metadata keys + assert "software_version" in state_dict._metadata + assert "created_time" in state_dict._metadata + + # Check params + assert dict(state_dict) == expected_tensors + + def test_set_module_params(self, simple_common_sim): + params = {"param1": torch.as_tensor(1), "param2": torch.as_tensor(2)} + # Call the __set_module_params method + simple_common_sim._Simulator__set_module_params(simple_common_sim, params) + + # Check if the module attributes have been set correctly + assert 
simple_common_sim.param1 == params["param1"] + assert simple_common_sim.param2 == params["param2"] + + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) + def test_load_state_dict(self, simple_common_sim): + simple_common_sim.epl.x0 = 0.0 + simple_common_sim.sersic.x0 = 1.0 + + fpath = simple_common_sim.state_dict().save() + loaded_state_dict = StateDict.load(fpath) + + # Change a value in the simulator + simple_common_sim.z_s = 3.0 + simple_common_sim.epl.x0 = None + simple_common_sim.sersic.x0 = None + + # Ensure that the simulator has been changed + assert ( + loaded_state_dict[f"{simple_common_sim.name}.z_s"] + != simple_common_sim.z_s.value + ) + assert ( + loaded_state_dict[f"{simple_common_sim.epl.name}.x0"] + != simple_common_sim.epl.x0.value + ) + assert ( + loaded_state_dict[f"{simple_common_sim.sersic.name}.x0"] + != simple_common_sim.sersic.x0.value + ) + + # Load the state dict from file + simple_common_sim.load_state_dict(fpath) + + # Once loaded, the values should be the same + assert ( + loaded_state_dict[f"{simple_common_sim.name}.z_s"] + == simple_common_sim.z_s.value + ) + + assert simple_common_sim.epl.x0.value == torch.as_tensor(0.0) + assert simple_common_sim.sersic.x0.value == torch.as_tensor(1.0) + + # Cleanup after + Path(fpath).unlink() diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py new file mode 100644 index 00000000..a309c270 --- /dev/null +++ b/tests/sims/test_state_dict.py @@ -0,0 +1,189 @@ +from pathlib import Path +from tempfile import TemporaryDirectory +import sys + +import pytest +import torch +from collections import OrderedDict +from safetensors.torch import save, load +from datetime import datetime as dt +from caustics.parameter import Parameter +from caustics.namespace_dict import NamespaceDict, NestedNamespaceDict +from caustics.sims.state_dict import ( + ImmutableODict, + StateDict, + IMMUTABLE_ERR, + _sanitize, + _merge_and_flatten, + _get_param_values, +) +from caustics import __version__ + + +class TestImmutableODict: + def test_constructor(self): + odict = ImmutableODict(a=1, b=2, c=3) + assert isinstance(odict, OrderedDict) + assert odict == {"a": 1, "b": 2, "c": 3} + assert hasattr(odict, "_created") + assert odict._created is True + + def test_setitem(self): + odict = ImmutableODict() + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + odict["key"] = "value" + + def test_delitem(self): + odict = ImmutableODict(key="value") + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + del odict["key"] + + def test_setattr(self): + odict = ImmutableODict() + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + odict.meta = {"key": "value"} + + +class TestStateDict: + simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)} + + @pytest.fixture(scope="class") + def simple_state_dict(self): + return StateDict(**self.simple_tensors) + + def test_constructor(self): + time_format = "%Y-%m-%dT%H:%M:%S" + time_str_now = dt.utcnow().strftime(time_format) + state_dict = StateDict(**self.simple_tensors) + + # Get the created time and format to nearest seconds + sd_ct_dt = dt.fromisoformat(state_dict._metadata["created_time"]) + sd_ct_str = sd_ct_dt.strftime(time_format) + + # Check the default metadata and content + assert hasattr(state_dict, "_metadata") + assert state_dict._created is True + assert state_dict._metadata["software_version"] == __version__ + assert sd_ct_str == time_str_now +
assert dict(state_dict) == self.simple_tensors + + def test_constructor_with_metadata(self): + time_format = "%Y-%m-%dT%H:%M:%S" + time_str_now = dt.utcnow().strftime(time_format) + metadata = {"created_time": time_str_now, "software_version": "0.0.1"} + state_dict = StateDict(metadata=metadata, **self.simple_tensors) + + assert isinstance(state_dict._metadata, ImmutableODict) + assert dict(state_dict._metadata) == dict(metadata) + + def test_setitem(self, simple_state_dict): + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + simple_state_dict["var1"] = torch.as_tensor(3.0) + + def test_delitem(self, simple_state_dict): + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + del simple_state_dict["var1"] + + def test_from_params(self, simple_common_sim): + params: NestedNamespaceDict = simple_common_sim.params + all_params = _merge_and_flatten(params) + tensors_dict = _get_param_values(all_params) + + expected_state_dict = StateDict(**tensors_dict) + + # Full parameters + state_dict = StateDict.from_params(params) + assert state_dict == expected_state_dict + + # Static only + state_dict = StateDict.from_params(all_params) + assert state_dict == expected_state_dict + + # Check for TypeError when passing a plain dict + with pytest.raises(TypeError): + StateDict.from_params({"a": 1, "b": 2}) + + # Check for ValueError when passing a NestedNamespaceDict + # without the "static" and "dynamic" keys + with pytest.raises(ValueError): + StateDict.from_params(NestedNamespaceDict({"a": 1, "b": 2})) + + def test_to_params(self): + params_with_none = {"var3": torch.ones(0), **self.simple_tensors} + state_dict = StateDict(**params_with_none) + params = StateDict(**params_with_none).to_params() + assert isinstance(params, NamespaceDict) + + for k, v in params.items(): + tensor_value = state_dict[k] + if tensor_value.nelement() > 0: + assert isinstance(v, Parameter) + assert v.value == tensor_value + + def test__to_safetensors(self): + state_dict = StateDict(**self.simple_tensors) + # Save to safetensors + tensors_bytes = state_dict._to_safetensors() + expected_bytes = save(_sanitize(state_dict), metadata=state_dict._metadata) + + # Reload back to a tensors dict + # this is done because the information + # might be stored in different arrangements + # within the safetensors bytes + loaded_tensors = load(tensors_bytes) + loaded_expected_tensors = load(expected_bytes) + assert loaded_tensors == loaded_expected_tensors + + def test_st_file_string(self, simple_state_dict): + file_format = "%Y%m%dT%H%M%S_caustics.st" + expected_file = simple_state_dict._created_time.strftime(file_format) + + assert simple_state_dict._StateDict__st_file == expected_file + + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) + def test_save(self, simple_state_dict): + # Check for default save path + expected_fpath = Path.cwd() / simple_state_dict._StateDict__st_file + default_fpath = simple_state_dict.save() + + assert Path(default_fpath).exists() + assert default_fpath == str(expected_fpath.absolute()) + + # Cleanup after + Path(default_fpath).unlink() + + # Check for specified save path + with TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) + # Correct extension and path in a tempdir + fpath = tempdir / "test.st" + saved_path = simple_state_dict.save(str(fpath.absolute())) + + assert Path(saved_path).exists() + assert saved_path == str(fpath.absolute()) + + # Test save Path + fpath1 =
tempdir / "test1.st" + saved_path = simple_state_dict.save(fpath1) + assert Path(saved_path).exists() + assert saved_path == str(fpath1.absolute()) + + # Wrong extension + wrong_fpath = tempdir / "test.txt" + with pytest.raises(ValueError): + saved_path = simple_state_dict.save(str(wrong_fpath.absolute())) + + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) + def test_load(self, simple_state_dict): + fpath = simple_state_dict.save() + loaded_state_dict = StateDict.load(fpath) + assert loaded_state_dict == simple_state_dict + + # Cleanup after + Path(fpath).unlink() diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..d3116b7c --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,63 @@ +from pathlib import Path +import tempfile +import struct +import json +import torch +from safetensors.torch import save +from caustics.io import ( + _get_safetensors_header, + _normalize_path, + to_file, + from_file, + get_safetensors_metadata, +) + + +def test_normalize_path(): + path_obj = Path().joinpath("path", "to", "file.txt") + # Test with a string path + path_str = str(path_obj) + normalized_path = _normalize_path(path_str) + assert normalized_path == path_obj.absolute() + assert str(normalized_path) == str(path_obj.absolute()) + + # Test with a Path object + normalized_path = _normalize_path(path_obj) + assert normalized_path == path_obj.absolute() + + +def test_to_and_from_file(): + with tempfile.TemporaryDirectory() as tmpdir: + fpath = Path(tmpdir) / "test.txt" + data = "test data" + + # Test to file + ffile = to_file(fpath, data) + + assert Path(ffile).exists() + assert ffile == str(fpath.absolute()) + assert Path(ffile).read_text() == data + + # Test from file + assert from_file(fpath) == data.encode("utf-8") + + +def test_get_safetensors_metadata(): + with tempfile.TemporaryDirectory() as tmpdir: + fpath = Path(tmpdir) / "test.st" + meta_dict = {"meta": "data"} + tensors_bytes = save({"test1": torch.as_tensor(1.0)}, metadata=meta_dict) + fpath.write_bytes(tensors_bytes) + + # Manually get header + first_bytes_length = 8 + (length_of_header,) = struct.unpack("