Skip to content

Commit

Permalink
Finish formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Mar 4, 2025
1 parent ffce829 commit 3cf9e1f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 275 deletions.
266 changes: 0 additions & 266 deletions src/flowMC/resource/local_kernel/flowHMC.py

This file was deleted.

14 changes: 6 additions & 8 deletions src/flowMC/utils/PythonFunctionWrap.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Tuple, Union

Array = Any
PyTree = Union[Array, Iterable[Array], Dict[Any, Array], NamedTuple]
SampleStats = Dict[str, Array]
Extras = PyTree
from typing import Any, Callable, List, Tuple

import jax
import jax.numpy as jnp
Expand All @@ -19,6 +14,9 @@
from jaxtyping import PyTree


Array = Any


def wrap_python_log_prob_fn(python_log_prob_fn: Callable[..., Array]):
@custom_vmap
@wraps(python_log_prob_fn)
Expand Down Expand Up @@ -63,7 +61,7 @@ def eval_one(x):

def _tree_dtype(tree: PyTree) -> Any:
leaves, _ = tree_flatten(tree)
from_dtypes = [dtypes.dtype(l) for l in leaves]
from_dtypes = [dtypes.dtype(leaf) for leaf in leaves]
return dtypes.result_type(*from_dtypes)


Expand Down Expand Up @@ -94,7 +92,7 @@ def unravel_one(flat):
def _ravel_inner(lst: List[Array]) -> Tuple[Array, UnravelFn]:
if not lst:
return jnp.array([], jnp.float32), lambda _: []
from_dtypes = [dtypes.dtype(l) for l in lst]
from_dtypes = [dtypes.dtype(leaf) for leaf in lst]
to_dtype = dtypes.result_type(*from_dtypes)
shapes = [jnp.shape(x)[1:] for x in lst]
indices = np.cumsum([int(np.prod(s)) for s in shapes])
Expand Down
3 changes: 2 additions & 1 deletion src/flowMC/utils/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ def plot_summary(sampler: Sampler, **plotkwargs) -> None:
"""Create plots of the most important quantities in the summary.
Args:
training (bool, optional): If True, plot training quantities. If False, plot production quantities. Defaults to False.
training (bool, optional): If True, plot training quantities.
If False, plot production quantities. Defaults to False.
"""
keys = ["local_accs", "global_accs", "log_prob"]

Expand Down

0 comments on commit 3cf9e1f

Please sign in to comment.