Skip to content

Commit

Permalink
add example and code
Browse files Browse the repository at this point in the history
  • Loading branch information
cchapmanbird committed May 10, 2023
1 parent d49786a commit 3d802c7
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 2 deletions.
1 change: 1 addition & 0 deletions CDF_snake/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from CDF_snake.interpolate import *
39 changes: 39 additions & 0 deletions CDF_snake/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
try:
import cupy as cp
CDF_SNAKE_GPU_AVAILABLE = True
except ImportError or ModuleNotFoundError:
CDF_SNAKE_GPU_AVAILABLE = False

class CDFSnake:
def __init__(self, grid_values, pdfs, normalise_cdfs=False, use_gpu=False, cache_cdfs=False) -> None:
if use_gpu and CDF_SNAKE_GPU_AVAILABLE:
xp = cp
else:
xp = np

self.xp = xp
self.grid_values = self.xp.asarray(grid_values)
self.dx = self.xp.append(0, self.xp.diff(self.grid_values))
self.pdfs = self.xp.asarray(pdfs)
self.num_pdfs = self.pdfs.shape[1]
self.arange = self.xp.arange(self.num_pdfs)
self.max_grid_value = self.grid_values.max()

self.grid_snake = (self.grid_values[:,None] + self.max_grid_value*self.arange[None,:]).T.flatten()
self.construct_snake(normalise_cdfs, cache_cdfs)

def construct_snake(self, normalise_cdfs, cache_cdfs):
cdfs = self.xp.nan_to_num(self.xp.cumsum(self.pdfs, axis=0)*self.dx[:,None])
if normalise_cdfs:
cdfs /= cdfs.max(axis=0)
if cache_cdfs:
self.cdfs = cdfs
self.cdf_snake = (cdfs + self.arange[None,:]).T.flatten()

def sample_snake(self):
self.uniform_samples = self.xp.random.uniform(0,1, size=self.num_pdfs)
self.random_sample_snake = self.uniform_samples + self.arange

inverse_samples = self.xp.interp(self.random_sample_snake, self.cdf_snake, self.grid_snake) - self.max_grid_value*self.arange
return inverse_samples
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# snake_interpolate
A handy tool for performing efficient one-dimensional interpolation over many x-grids that share a common y-grid.
# CDF_snake
A handy tool for performing efficient inverse-transform sampling of one-dimensional conditional probability distributions.
215 changes: 215 additions & 0 deletions examples/demo.ipynb

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -- build

[build-system]
requires = [
"setuptools",
"setuptools_scm[toml]",
"wheel",
]
build-backend = "setuptools.build_meta"

[project]
name = "CDF_snake"
description = "A handy tool for performing efficient inverse-transform sampling of one-dimensional conditional probability distributions."
readme = "README.md"
authors = [
{ name = "Christian Chapman-Bird", email = "c.chapman-bird.1@research.gla.ac.uk" },
]
license = { text = "GPL-3.0-or-later" }
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
]

# requirements
requires-python = ">=3.8"
dependencies = [
"numpy >=1.17",
]

# dynamic properties set by tools
dynamic = [
"version",
]

[project.urls]
"Bug Tracker" = "https://github.com/CChapmanbird/CDF_snake/issues"
"Source Code" = "https://github.com/CChapmanbird/CDF_snake"

[tool.setuptools]
license-files = [ "LICENSE" ]

[tool.setuptools_scm]

0 comments on commit 3d802c7

Please sign in to comment.