Skip to content

Commit

Permalink
Merge pull request #152 from kazewong/144-get-rid-of-random_key_set
Browse files Browse the repository at this point in the history
144 get rid of random key set
  • Loading branch information
kazewong authored Mar 26, 2024
2 parents da8cbc3 + 0ec8c5f commit 7810044
Show file tree
Hide file tree
Showing 19 changed files with 383 additions and 784 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9","3.10","3.11"]
python-version: ["3.10","3.11"]

steps:
- uses: actions/checkout@v3
Expand Down
20 changes: 11 additions & 9 deletions example/dualmoon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys

import corner
import matplotlib.pyplot as plt
Expand All @@ -18,7 +17,7 @@ def target_dualmoon(x, data):
along the first and second dimension
"""
print("compile count")
term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2
term1 = 0.5 * ((jnp.linalg.norm(x - data['data']) - 2) / 0.1) ** 2
term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
return -(term1 - logsumexp(term2) - logsumexp(term3))
Expand All @@ -35,21 +34,23 @@ def target_dualmoon(x, data):
num_epochs = 30
batch_size = 10000

data = jnp.zeros(n_dim)
data = {'data':jnp.zeros(n_dim)}

rng_key_set = initialize_rng_keys(n_chains, 42)
model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))
rng_key = jax.random.PRNGKey(42)
rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, subkey)

initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1

MALA_Sampler = MALA(target_dualmoon, True, {"step_size": 0.1})

print("Initializing sampler class")

nf_sampler = Sampler(
n_dim,
rng_key_set,
jnp.zeros(5),
rng_key,
data,
MALA_Sampler,
model,
n_loop_training=n_loop_training,
Expand All @@ -67,7 +68,8 @@ def target_dualmoon(x, data):
nf_sampler.sample(initial_position, data)
summary = nf_sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
nf_samples = nf_sampler.sample_flow(10000)
rng_key, subkey = jax.random.split(rng_key)
nf_samples = nf_sampler.sample_flow(subkey, 10000)

print(
"chains shape: ",
Expand Down
244 changes: 34 additions & 210 deletions example/notebook/analyzingChains.ipynb

Large diffs are not rendered by default.

195 changes: 30 additions & 165 deletions example/notebook/dualmoon.ipynb

Large diffs are not rendered by default.

144 changes: 25 additions & 119 deletions example/notebook/maximizing_likelihood.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions example/train_normalizing_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
data = make_moons(n_samples=20000, noise=0.05)
data = jnp.array(data[0])

key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3)
key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)

model = MaskedCouplingRQSpline(
2,
Expand All @@ -40,4 +40,4 @@

key, model, loss = train_flow(rng, model, data, num_epochs, batch_size, verbose=True)

nf_samples = model.sample(PRNGKeyArray(124098), 5000)
nf_samples = model.sample(jax.random.PRNGKey(124098), 5000)
4 changes: 2 additions & 2 deletions src/flowMC/nfmodel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def n_features(self) -> int:
def save_model(self, path: str):
eqx.tree_serialise_leaves(path + ".eqx", self)

def load_model(self, path: str) -> eqx.Module:
return eqx.tree_deserialise_leaves(path + ".eqx", self)
def load_model(self, path: str):
self = eqx.tree_deserialise_leaves(path + ".eqx", self)


class Bijection(eqx.Module):
Expand Down
6 changes: 3 additions & 3 deletions src/flowMC/nfmodel/realNVP.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class RealNVP(NFModel):
MLP is needed to make sure the scaling between layers are more or less the same.
Args:
n_layer: (int) The number of affine coupling layers.
n_layers: (int) The number of affine coupling layers.
n_features: (int) The number of features in the input.
n_hidden: (int) The number of hidden units in the MLP.
dt: (Float) Scaling factor for the affine coupling layer.
Expand Down Expand Up @@ -128,7 +128,7 @@ def data_cov(self):
return jax.lax.stop_gradient(self._data_cov)

def __init__(
self, n_features: int, n_layer: int, n_hidden: int, key: PRNGKeyArray, **kwargs
self, n_features: int, n_layers: int, n_hidden: int, key: PRNGKeyArray, **kwargs
):

if kwargs.get("base_dist") is not None:
Expand All @@ -150,7 +150,7 @@ def __init__(

self._n_features = n_features
affine_coupling = []
for i in range(n_layer):
for i in range(n_layers):
key, scale_subkey, shift_subkey = jax.random.split(key, 3)
mask = np.ones(n_features)
mask[int(n_features / 2) :] = 0
Expand Down
10 changes: 5 additions & 5 deletions src/flowMC/sampler/NF_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class NFProposal(ProposalBase):
model: NFModel

def __init__(
self, logpdf: Callable, jit: bool, model: NFModel, n_sample_max: int = 10000
self, logpdf: Callable, jit: bool, model: NFModel, n_flow_sample: int = 10000
):
super().__init__(logpdf, jit, {})
self.model = model
self.n_sample_max = n_sample_max
self.n_flow_sample = n_flow_sample
self.update_vmap = jax.vmap(self.update, in_axes=(None, (0)))
if self.jit is True:
self.update_vmap = jax.jit(self.update_vmap)
Expand Down Expand Up @@ -179,9 +179,9 @@ def sample_flow(
n_chains = initial_position.shape[0]
n_dim = initial_position.shape[-1]
total_size = initial_position.shape[0] * n_steps
if total_size > self.n_sample_max:
if total_size > self.n_flow_sample:
rng_key = rng_key
n_batch = ceil(total_size / self.n_sample_max)
n_batch = ceil(total_size / self.n_flow_sample)
n_sample = total_size // n_batch
proposal_position = jnp.zeros(
(n_batch, n_sample, initial_position.shape[-1])
Expand Down Expand Up @@ -218,5 +218,5 @@ def sample_flow(
def tree_flatten(self):
children, aux_data = super().tree_flatten()
aux_data["model"] = self.model
aux_data["n_sample_max"] = self.n_sample_max
aux_data["n_sample_max"] = self.n_flow_sample
return (children, aux_data)
1 change: 1 addition & 0 deletions src/flowMC/sampler/Proposal_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def sample(
data: PyTree,
verbose: bool = False,
) -> tuple[
PRNGKeyArray,
Float[Array, "n_chains n_steps n_dim"],
Float[Array, "n_chains n_steps 1"],
Int[Array, "n_chains n_steps 1"],
Expand Down
Loading

0 comments on commit 7810044

Please sign in to comment.