Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simulator class #37

Merged
merged 4 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/caustic/lenses/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Any, Optional
from functools import partial

import torch
from torch import Tensor
Expand All @@ -11,7 +12,6 @@

__all__ = ("ThinLens", "ThickLens")


class ThickLens(Parametrized):
"""
Base class for modeling gravitational lenses that cannot be treated using the thin lens approximation.
Expand Down Expand Up @@ -117,8 +117,7 @@ def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor:
Returns:
Tensor: The gravitational lensing magnification at the given coordinates.
"""
return get_magnification(self.raytrace, thx, thy, z_s, x)

return get_magnification(partial(self.raytrace, x = x), thx, thy, z_s)

class ThinLens(Parametrized):
"""Base class for thin gravitational lenses.
Expand Down Expand Up @@ -328,4 +327,4 @@ def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor:
Returns:
Tensor: Gravitational magnification at the given coordinates.
"""
return get_magnification(self.raytrace, thx, thy, z_s, x)
return get_magnification(partial(self.raytrace, x = x), thx, thy, z_s)
14 changes: 7 additions & 7 deletions src/caustic/lenses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_pix_jacobian(
raytrace, thx, thy, z_s, x
raytrace, thx, thy, z_s
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
"""Computes the Jacobian matrix of the partial derivatives of the
image position with respect to the source position
Expand All @@ -27,11 +27,11 @@ def get_pix_jacobian(
The Jacobian matrix of the image position with respect to the source position at the given point.

"""
jac = torch.func.jacfwd(raytrace, (0, 1))(thx, thy, z_s, x) # type: ignore
jac = torch.func.jacfwd(raytrace, (0, 1))(thx, thy, z_s) # type: ignore
return jac


def get_pix_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
def get_pix_magnification(raytrace, thx, thy, z_s) -> Tensor:
"""
Computes the magnification at a single point on the lensing plane. The magnification is derived from the determinant
of the Jacobian matrix of the image position with respect to the source position.
Expand All @@ -46,11 +46,11 @@ def get_pix_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
Returns:
The magnification at the given point on the lensing plane.
"""
jac = get_pix_jacobian(raytrace, thx, thy, z_s, x)
jac = get_pix_jacobian(raytrace, thx, thy, z_s)
return 1 / (jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]).abs()


def get_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
def get_magnification(raytrace, thx, thy, z_s) -> Tensor:
"""
Computes the magnification over a grid on the lensing plane. This is done by calling `get_pix_magnification`
for each point on the grid.
Expand All @@ -65,6 +65,6 @@ def get_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
Returns:
A tensor representing the magnification at each point on the grid.
"""
return vmap_n(get_pix_magnification, 2, (None, 0, 0, None, None))(
raytrace, thx, thy, z_s, x
return vmap_n(get_pix_magnification, 2, (None, 0, 0, None))(
raytrace, thx, thy, z_s
)
1 change: 0 additions & 1 deletion src/caustic/packed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import OrderedDict


class Packed(OrderedDict):
"""
Dummy wrapper for `x` so other functions can check its type.
Expand Down
8 changes: 3 additions & 5 deletions src/caustic/parametrized.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

__all__ = ("Parametrized",)


class Parametrized:
"""
Represents a class with Param and Parametrized attributes, typically used to construct parts of a simulator
Expand Down Expand Up @@ -282,7 +281,7 @@ def pack(
ValueError: If the number of dynamic arguments does not match the expected number.
ValueError: If the input is a tensor and the shape does not match the expected shape.
"""
if isinstance(x, dict):
if isinstance(x, (dict, Packed)):
missing_names = [
name for name in chain([self.name], self._descendants) if name not in x
]
Expand All @@ -302,7 +301,7 @@ def pack(
# TODO: give component and arg names
raise ValueError(
f"{n_passed} dynamic args were passed, but {n_expected} are "
"required"
"required."
)

cur_offset = self.n_dynamic
Expand Down Expand Up @@ -559,8 +558,7 @@ def add_params(p: Parametrized, dot):
add_params(desc, dot)

return dot



# class ParametrizedList(Parametrized):
# """
# TODO
Expand Down
18 changes: 18 additions & 0 deletions src/caustic/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .parameterized import Parametrized

__all__ = ("Simulator", )

class Simulator(Parametrized):
"""A caustic simulator using Parametrized framework.

Defines a simulator class which is a callable function that
operates on the Parametrized framework. Users define the `forward`
method which takes as its first argument an object which can be
packed, all other args and kwargs are simply passed to the forward
method.

See `Parametrized` for details on how to add/access parameters.

"""
def __call__(self, *args, **kwargs):
return self.forward(self.pack(args[0]), *args[1:], **kwargs)