Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions powerbox_jax/powerbox_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import jax.numpy as np
from powerbox_jax import dft
from powerbox_jax.tools import _magnitude_grid

import numpyro
import numpyro.distributions as dist

def _make_hermitian(mag, pha):
r"""
Expand Down Expand Up @@ -57,9 +58,6 @@ class PowerBox(object):
vol_normalised_power : bool, optional
Whether the input power spectrum, ``pk``, is volume-weighted. Default True because of standard cosmological
usage.
seed: int, optional
A random seed to define the initial conditions. If not set, it will remain random, and each call to eg.
:meth:`delta_x()` will produce a *different* realisation.
Notes
-----
A number of conventions need to be listed.
Expand All @@ -75,8 +73,6 @@ class PowerBox(object):
.. note:: None of the n-dimensional arrays that are created within the class are stored, due to the inefficiency
in memory consumption that this would imply. Thus, each large array is created and *returned* by their
respective method, to be stored/discarded by the user.
.. warning:: Due to the above note, repeated calls to eg. :meth:`delta_x()` will produce *different* realisations
of the real-space field, unless the `seed` parameter is set in the constructor.
Examples
--------
To create a 3-dimensional box of gaussian over-densities, gridded into 100 bins, with cosmological conventions,
Expand All @@ -91,7 +87,7 @@ class PowerBox(object):
>>> plt.imshow(pb.delta_x())
"""

def __init__(self, N, pk, key, dim=2, boxlength=1.0, supplied_freqs=None, ensure_physical=False, a=1., b=1.,
def __init__(self, N, pk, dim=2, boxlength=1.0, supplied_freqs=None, ensure_physical=False, a=1., b=1.,
vol_normalised_power=True):

self.N = N
Expand All @@ -111,8 +107,6 @@ def __init__(self, N, pk, key, dim=2, boxlength=1.0, supplied_freqs=None, ensure
self.ensure_physical = ensure_physical
self.Ntot = self.N ** self.dim

self.seed = key

if N % 2 == 0:
self._even = True
else:
Expand Down Expand Up @@ -156,7 +150,7 @@ def get_freqs(self):
dk = np.array(Lk) / np.array(_N)

_myfreq = lambda n,d: dft.fftfreq(n, d=d, b=self.fourier_b)
freq = jax.tree_multimap(_myfreq, list(_N), list(dk))
freq = jax.tree_multimap(_myfreq, list(self.shape), list(dk))
return freq, axes, left_edge

@property
Expand All @@ -176,10 +170,13 @@ def x(self):

def gauss_hermitian(self):
"A random array which has Gaussian magnitudes and Hermitian symmetry"

key,rng = jax.random.split(self.seed)
mag = jax.random.normal(key, shape=(self.n,) * self.dim)
pha = 2 * np.pi * jax.random.uniform(rng, shape=(self.n,) * self.dim)
shape = (self.n,) * self.dim
mag = numpyro.sample('gauss_hermitian_mag', dist.Independent(dist.Normal(np.zeros(shape),
np.ones(shape)) ,
self.dim))
pha = numpyro.sample('gauss_hermitian_pha', dist.Independent(dist.Uniform(np.zeros(shape),
2 * np.pi * np.ones(shape)),
self.dim))

dk = _make_hermitian(mag, pha)

Expand Down Expand Up @@ -235,7 +232,7 @@ def delta_x(self):

return dk

def create_discrete_sample(self, key, nbar, randomise_in_cell=True, min_at_zero=False,
def create_discrete_sample(self, nbar, randomise_in_cell=True, min_at_zero=False,
store_pos=False):
r"""
Assuming that the real-space signal represents an over-density with respect to some mean, create a sample
Expand All @@ -261,7 +258,7 @@ def create_discrete_sample(self, key, nbar, randomise_in_cell=True, min_at_zero=
dx = self.delta_x()
dx = (dx + 1) * self.dx ** self.dim * nbar
n = dx
self.n_per_cell = jax.random.poisson(key, n.flatten(), shape=n.flatten().shape)
self.n_per_cell = numpyro.sample('n_per_cell', dist.Poisson(n))

# Get all source positions
args = [self.x] * self.dim
Expand All @@ -271,8 +268,8 @@ def create_discrete_sample(self, key, nbar, randomise_in_cell=True, min_at_zero=
tracer_positions = tracer_positions.repeat(self.n_per_cell.flatten(), axis=0)

if randomise_in_cell:
key,rng = jax.random.split(key)
tracer_positions += jax.random.uniform(key, shape=(np.sum(self.n_per_cell), self.dim)) * self.dx
ntot = np.sum(self.n_per_cell)
tracer_positions += numpyro.sample('tracer_shifts', dist.Uniform(np.zeros(ntot, self.dim), np.ones(ntot, self.dim)* self.dx ))

if min_at_zero:
tracer_positions += self.boxlength / 2.0
Expand Down Expand Up @@ -314,7 +311,7 @@ class LogNormalPowerBox(PowerBox):
"""

def __init__(self, *args, **kwargs):
super(LogNormalPowerBox, self).__init__(*args, **kwargs)
super(self.__class__, self).__init__(*args, **kwargs)

def correlation_array(self):
"The correlation function from the input power, on the grid"
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"tqdm>=4.48.2",
"numpy>=1.16.0",
"scipy>=1.4.1",
"numpyro",
"corner",
"matplotlib"],
)