Skip to content

Commit

Permalink
Fix nan in CPU envs. (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab authored Jun 11, 2024
1 parent e0c75bc commit 0ecc1a0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 103 deletions.
106 changes: 35 additions & 71 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,41 +1,63 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "logbesselk"
version = "3.3.1"
version = "3.4.0"
description = "Provide function to calculate the modified Bessel function of the second kind"
license = "Apache-2.0"
authors = [
{ name = "TAKEKAWA Takashi", email = "takekawa@tk2lab.org" }
]
readme = "README.md"
repository = "https://github.com/tk2lab/logbesselk"

classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

requires-python = ">=3.10"
dependencies = [
]

[project.optional-dependences]
[project.optional-dependencies]
jax = [
"jax>=0.4",
"jaxlib>=0.4",
"jax",
"jaxlib",
]
jaxcuda = [
"jax[cuda12]",
"jaxlib",
]
tensorflow = [
"tensorflow>=2.8",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.envs.default]
installer = "uv"
python = "3.12"
dependencies = [
"pandas",
"pytest",
]

[tool.rye]
manaed = true
[tool.hatch.envs.jax]
features = ["jax"]
[tool.hatch.envs.jax.env-vars]
JAX_PLATFORMS = "cpu"

[[tool.rye.sources]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
type = "find-links"
[tool.hatch.envs.jaxcuda]
template = "jax"
features = ["jaxcuda"]
[tool.hatch.envs.jaxcuda.env-vars]
JAX_PLATFORMS = "cuda"
CUDA_LAUNCH_BLOCKING = "1"

[tool.pytest.ini_options]
markers = [
Expand All @@ -58,61 +80,3 @@ select = ["E4", "E7", "E9", "F", "W", "N", "I"]
combine-as-imports = true
force-wrap-aliases = true
split-on-trailing-comma = true

[tool.tox]
legacy_tox_ini = """
[tox]
min_version = 4.0
rye_discovery = true
env_list =
py{310,311,312}-jax{latest,4}
py{310,311}-tf{latest,28,29,210,211,212,213,214,215,216}
py312-tf{latest,216}
eval_tf
eval_jax
[gh-actions]
python =
3.10: py310-jax4, py310-tf28
3.12: py312-jaxlatest, py312-tflatest
[testenv:py{310,311,312}-jax{4,latest}]
deps =
jaxlatest: jax[cuda12]
jax4: jax[cuda12] (>=0.4,<0.5)
pytest
pandas
commands =
{envpython} -m pytest tests/test_jax.py {posargs}
[testenv:py{310,311,312}-tf{latest,28,29,210,211,212,213,214,215,216}]
deps =
tflatest: tensorflow
tf216: tensorflow (>=2.16,<2.17)
tf215: tensorflow (>=2.15,<2.16)
tf214: tensorflow (>=2.14,<2.15)
tf213: tensorflow (>=2.13,<2.14)
tf212: tensorflow (>=2.12,<2.13)
tf211: tensorflow (>=2.11,<2.12)
tf210: tensorflow (>=2.10,<2.11)
tf29: tensorflow (>=2.9,<2.10)
tf28: tensorflow (>=2.8,<2.9)
pytest
pandas
commands =
{envpython} -m pytest tests/test_tensorflow.py {posargs}
[testenv:eval_jax]
deps =
jax[cuda12]
pandas
commands =
{envpython} eval/eval_jax.py
[testenv:eval_tf]
deps =
tensorflow
pandas
commands =
{envpython} eval/eval_tensorflow.py
"""
10 changes: 0 additions & 10 deletions requirements-dev.lock

This file was deleted.

10 changes: 0 additions & 10 deletions requirements.lock

This file was deleted.

19 changes: 7 additions & 12 deletions src/logbesselk/jax/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def log_bessel_k(v, x):
Combination of Integrate, Continued fraction and Asymptotic expansion.
"""

def small_x_case():
return log_k_small_x(v, x)

def large_x_case():
def small_v_case():
n = fround(v)
Expand All @@ -47,17 +44,15 @@ def large_v_case():
v_ = lax.cond(large_v, lambda: v.astype(dtype), lambda: dtype(0))
return log_k_large_v(v_, x)

dtype = result_type(v, x)
large_v_ = v >= 25
small_v = finite & ~large_v_
large_v = finite & large_v_
return lax.cond(small_v, small_v_case, large_v_case)

dtype = result_type(v, x)
finite = is_finite(v) & is_finite(x) & (x > 0)
large_v_ = v >= 25
large_x_ = x >= 100

small_x = finite & ~large_x_
small_v = finite & large_x_ & ~large_v_
large_v = finite & large_x_ & large_v_
return lax.cond(small_x, small_x_case, large_x_case)
out = log_k_small_x(v, x)
finite = is_finite(out)
return lax.cond(finite, lambda: out, large_x_case)


def bessel_kratio(v, x, d=1):
Expand Down

0 comments on commit 0ecc1a0

Please sign in to comment.