Skip to content

Commit

Permalink
Functionality for minimizing TPD in envelope2D solver (#63)
Browse files Browse the repository at this point in the history
* tpd opt using adeptmodules

* pm branch with some bug fixes

* complex

* refactor

* PRNGKey class for filtering equinox modules

* added SrunLauncher + LocalProvider flow to PARSL

* arparse for opt and new cpu run

* refactor for utils
  • Loading branch information
joglekara committed Sep 9, 2024
1 parent 91a0924 commit 62071ec
Show file tree
Hide file tree
Showing 21 changed files with 704 additions and 569 deletions.
26 changes: 18 additions & 8 deletions adept/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from diffrax import Solution, Euler, RESULTS
from equinox import Module
from equinox import Module, filter_jit
import mlflow, jax, numpy as np
from jax import numpy as jnp

Expand Down Expand Up @@ -122,7 +122,16 @@ def init_state_and_args(self):
"""
return {}

def init_modules(self) -> Dict:
def init_modules(self) -> Dict[str, Module]:
"""
This function initializes the necessary (trainable) physics modules that are required to run the simulation. These can be modules that
change the initial conditions, or the driver (boundary conditions), or the metric calculation. These modules are usually `eqx.Module`s
so that you can take derivatives against the (parameters of the) modules.
Returns:
Dict: A dictionary of the (trainable) modules that are required to run the simulation
"""
return {}

def __call__(self, trainable_modules: Dict, args: Dict):
Expand Down Expand Up @@ -224,8 +233,8 @@ def setup(self, cfg: Dict, adept_module: ADEPTModule = None) -> Dict[str, Module
from adept.utils import get_cfg

with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested) as mlflow_run:
with tempfile.TemporaryDirectory(dir=self.base_tempdir) as temp_path:
cfg = get_cfg(artifact_uri=mlflow_run.info.artifact_uri, temp_path=temp_path)
# with tempfile.TemporaryDirectory(dir=self.base_tempdir) as temp_path:
# cfg = get_cfg(artifact_uri=mlflow_run.info.artifact_uri, temp_path=temp_path)
modules = self._setup_(cfg, td, adept_module)
mlflow.log_artifacts(td) # logs the temporary directory to mlflow

Expand Down Expand Up @@ -255,7 +264,8 @@ def _get_adept_module_(self, cfg: Dict) -> ADEPTModule:

elif cfg["solver"] == "envelope-2d":
from adept.lpse2d.modules.base import BaseLPSE2D as this_module
from adept.lpse2d.datamodel import ConfigModel

# from adept.lpse2d.datamodel import ConfigModel

# config = ConfigModel(**cfg)

Expand All @@ -272,7 +282,7 @@ def _setup_(self, cfg: Dict, td: str, adept_module: ADEPTModule = None, log: boo
if adept_module is None:
self.adept_module = self._get_adept_module_(cfg)
else:
self.adept_module = adept_module
self.adept_module = adept_module(cfg)

# dump raw config
if log:
Expand Down Expand Up @@ -331,7 +341,7 @@ def __call__(self, modules: Dict = None) -> Tuple[Solution, Dict, str]:
run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True
) as mlflow_run:
t0 = time.time()
run_output = self.adept_module(modules, None)
run_output = filter_jit(self.adept_module.__call__)(modules, None)
mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow

t0 = time.time()
Expand Down Expand Up @@ -366,7 +376,7 @@ def val_and_grad(self, modules: Dict = None) -> Tuple[float, Dict, Tuple[Solutio
run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True
) as mlflow_run:
t0 = time.time()
(val, run_output), grad = self.adept_module.vg(modules, None)
(val, run_output), grad = filter_jit(self.adept_module.vg)(modules, None)
flattened_grad, _ = jax.flatten_util.ravel_pytree(grad)
mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow
mlflow.log_metrics({"val": float(val), "l2-grad": float(np.linalg.norm(flattened_grad))})
Expand Down
2 changes: 1 addition & 1 deletion adept/lpse2d/core/laser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def laser_update(self, t: float, y: jnp.ndarray, light_wave: Dict) -> Tuple[jnp.
E0_static = (
(1 + 0j - wpe**2.0 / (self.w0 * (1 + light_wave["delta_omega"][:, None, None])) ** 2) ** -0.25
* self.E0_source
* light_wave["amplitudes"][:, None, None]
* jnp.sqrt(light_wave["intensities"][:, None, None])
* jnp.exp(1j * k0 * self.x[None, :, None] + 1j * light_wave["initial_phase"][:, None, None])
)
dE0y = E0_static * jnp.exp(-1j * light_wave["delta_omega"][:, None, None] * self.w0 * t)
Expand Down
File renamed without changes.
Loading

0 comments on commit 62071ec

Please sign in to comment.