diff --git a/pyproject.toml b/pyproject.toml index d0faf8f..f659a3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,10 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [project] name = "logbesselk" -version = "3.3.1" +version = "3.4.0" description = "Provide function to calculate the modified Bessel function of the second kind" license = "Apache-2.0" authors = [ @@ -8,6 +12,7 @@ authors = [ ] readme = "README.md" repository = "https://github.com/tk2lab/logbesselk" + classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -15,27 +20,44 @@ classifiers = [ ] requires-python = ">=3.10" +dependencies = [ +] -[project.optional-dependences] +[project.optional-dependencies] jax = [ - "jax>=0.4", - "jaxlib>=0.4", + "jax", + "jaxlib", +] +jaxcuda = [ + "jax[cuda12]", + "jaxlib", ] tensorflow = [ "tensorflow>=2.8", ] -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.envs.default] +installer = "uv" +python = "3.12" +dependencies = [ + "pandas", + "pytest", +] -[tool.rye] -manaed = true +[tool.hatch.envs.jax] +features = ["jax"] +[tool.hatch.envs.jax.env-vars] +JAX_PLATFORMS = "cpu" -[[tool.rye.sources]] -name = "jax" -url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" -type = "find-links" +[tool.hatch.envs.jaxcuda] +template = "jax" +features = ["jaxcuda"] +[tool.hatch.envs.jaxcuda.env-vars] +JAX_PLATFORMS = "cuda" +CUDA_LAUNCH_BLOCKING = "1" [tool.pytest.ini_options] markers = [ @@ -58,61 +80,3 @@ select = ["E4", "E7", "E9", "F", "W", "N", "I"] combine-as-imports = true force-wrap-aliases = true split-on-trailing-comma = true - -[tool.tox] -legacy_tox_ini = """ -[tox] -min_version = 4.0 -rye_discovery = true -env_list = - py{310,311,312}-jax{latest,4} - py{310,311}-tf{latest,28,29,210,211,212,213,214,215,216} - py312-tf{latest,216} - eval_tf - eval_jax - -[gh-actions] -python = - 3.10: py310-jax4, py310-tf28 - 3.12: py312-jaxlatest, py312-tflatest - -[testenv:py{310,311,312}-jax{4,latest}] -deps = - jaxlatest: jax[cuda12] - jax4: jax[cuda12] (>=0.4,<0.5) - pytest - pandas -commands = - {envpython} -m pytest tests/test_jax.py {posargs} - -[testenv:py{310,311,312}-tf{latest,28,29,210,211,212,213,214,215,216}] -deps = - tflatest: tensorflow - tf216: tensorflow (>=2.16,<2.17) - tf215: tensorflow (>=2.15,<2.16) - tf214: tensorflow (>=2.14,<2.15) - tf213: tensorflow (>=2.13,<2.14) - tf212: tensorflow (>=2.12,<2.13) - tf211: tensorflow (>=2.11,<2.12) - tf210: tensorflow (>=2.10,<2.11) - tf29: tensorflow (>=2.9,<2.10) - tf28: tensorflow (>=2.8,<2.9) - pytest - pandas -commands = - {envpython} -m pytest tests/test_tensorflow.py {posargs} - -[testenv:eval_jax] -deps = - jax[cuda12] - pandas -commands = - {envpython} eval/eval_jax.py - -[testenv:eval_tf] -deps = - tensorflow - pandas -commands = - {envpython} eval/eval_tensorflow.py -""" diff --git a/requirements-dev.lock b/requirements-dev.lock deleted file mode 100644 index 8f23096..0000000 --- a/requirements-dev.lock +++ /dev/null @@ -1,10 +0,0 @@ -# generated by rye -# use `rye lock` or `rye sync` to update this lockfile -# -# last locked with the following flags: -# pre: false -# features: [] -# all-features: false -# with-sources: false - --e file:. diff --git a/requirements.lock b/requirements.lock deleted file mode 100644 index 8f23096..0000000 --- a/requirements.lock +++ /dev/null @@ -1,10 +0,0 @@ -# generated by rye -# use `rye lock` or `rye sync` to update this lockfile -# -# last locked with the following flags: -# pre: false -# features: [] -# all-features: false -# with-sources: false - --e file:. diff --git a/src/logbesselk/jax/ica.py b/src/logbesselk/jax/ica.py index 3fdfe06..a0f5652 100644 --- a/src/logbesselk/jax/ica.py +++ b/src/logbesselk/jax/ica.py @@ -32,9 +32,6 @@ def log_bessel_k(v, x): Combination of Integrate, Continued fraction and Asymptotic expansion. """ - def small_x_case(): - return log_k_small_x(v, x) - def large_x_case(): def small_v_case(): n = fround(v) @@ -47,17 +44,15 @@ def large_v_case(): v_ = lax.cond(large_v, lambda: v.astype(dtype), lambda: dtype(0)) return log_k_large_v(v_, x) + dtype = result_type(v, x) + large_v_ = v >= 25 + small_v = finite & ~large_v_ + large_v = finite & large_v_ return lax.cond(small_v, small_v_case, large_v_case) - dtype = result_type(v, x) - finite = is_finite(v) & is_finite(x) & (x > 0) - large_v_ = v >= 25 - large_x_ = x >= 100 - - small_x = finite & ~large_x_ - small_v = finite & large_x_ & ~large_v_ - large_v = finite & large_x_ & large_v_ - return lax.cond(small_x, small_x_case, large_x_case) + out = log_k_small_x(v, x) + finite = is_finite(out) + return lax.cond(finite, lambda: out, large_x_case) def bessel_kratio(v, x, d=1):