Skip to content

Commit

Permalink
Merge pull request #421 from hyanwong/ruff
Browse files Browse the repository at this point in the history
Switch to ruff linting
  • Loading branch information
hyanwong authored Jul 25, 2024
2 parents 0e25baf + feb78ff commit 17b15cb
Show file tree
Hide file tree
Showing 36 changed files with 338 additions and 467 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- run:
name: Lint Python
command: |
flake8 --max-line-length 89 tsdate setup.py tests
ruff check --line-length 90 tsdate setup.py tests
- save_cache:
key: tsdate-{{ checksum "data/prior_1000df.bak" }}
paths:
Expand Down
25 changes: 6 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,10 @@ repos:
- id: mixed-line-ending
- id: check-case-conflict
- id: check-yaml
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.10.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus]
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
args: [--config=.flake8]
additional_dependencies: ["flake8-bugbear==23.7.10", "flake8-builtins==2.1.0"]
- id: ruff
args: [ "--fix", "--config", "ruff.toml" ]
- id: ruff-format
args: [ "--config", "ruff.toml" ]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
tskit>=0.5.0
tsinfer>=0.3.0
flake8
ruff
numpy
tqdm
daiquiri
Expand Down
12 changes: 12 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
line-length = 90

[lint]
select = ["E", "F", "B", "W", "I", "N", "UP", "A", "RUF", "PT", "NPY"]
# N802, N803, N806: allow capitalised function, argument and variable names
# E741: allow "l" as a variable name
# PT011: allow pytest.raises without a match argument
# PT009: allow unittest-style assertions
ignore = ["N803", "N806", "N802", "E741", "PT011", "PT009"]

[lint.isort]
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
known-first-party = ["tsdate"]
30 changes: 16 additions & 14 deletions tests/distribution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
Utility functions to construct distributions used in variational inference,
for testing purposes
"""

import mpmath
import numpy as np
import scipy.integrate
import scipy.special

from tsdate import approx
from tsdate import hypergeo
from tsdate import approx, hypergeo


def kl_divergence(p, logq):
Expand Down Expand Up @@ -81,9 +81,7 @@ def pr_a(a, n, k):
if n == k:
return pr_t_bar_a(t, 1)
else:
return np.sum(
[pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)]
)
return np.sum([pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)])


class TiltedGammaDiff:
Expand Down Expand Up @@ -114,8 +112,12 @@ def _U(a, b, z):
return float(val)

def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3, reorder=True):
assert shape1 > 0 and shape2 > 0 and shape3 > 0
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
assert shape1 > 0
assert shape2 > 0
assert shape3 > 0
assert rate1 >= 0
assert rate2 > 0
assert rate3 >= 0
# for convergence of 2F1, we need rate2 > rate3. Invariant
# transformations of 2F1 allow us to switch arguments, with
# appropriate rescaling
Expand Down Expand Up @@ -369,8 +371,12 @@ def _M(a, b, x):
return float(val)

def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3):
assert shape1 > 0 and shape2 > 0 and shape3 > 0
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
assert shape1 > 0
assert shape2 > 0
assert shape3 > 0
assert rate1 >= 0
assert rate2 > 0
assert rate3 >= 0
# for numeric stability of hypergeometric we need rate2 > rate1
# as this is a convolution, the order of (1) and (2) doesn't matter
self.reparametrize = rate1 > rate2
Expand Down Expand Up @@ -481,11 +487,7 @@ def sufficient_statistics(self):
+ scipy.special.betaln(self.shape1, self.shape2)
)
x = dF_dz * T / S**2 + B / S
xsq = (
d2F_dz2 * T**2 / S**4
+ B * (B + 1) / S**2
+ 2 * dF_dz * (1 + B) * T / S**3
)
xsq = d2F_dz2 * T**2 / S**4 + B * (B + 1) / S**2 + 2 * dF_dz * (1 + B) * T / S**3
logx = dF_db + scipy.special.digamma(B) - np.log(S)
return logconst, x, xsq, logx

Expand Down
15 changes: 5 additions & 10 deletions tests/exact_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Moments for EP updates using exact hypergeometric evaluations rather than
a Laplace approximation; intended for testing and accuracy benchmarking.
"""

from math import exp
from math import log

Expand Down Expand Up @@ -181,9 +182,7 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
d3 = d2 * (a + 1) / (c + 1)
mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2
sq_m = (
d1 * exp(f020 - f000) / 3
+ d2 * exp(f121 - f000) / 3
+ d3 * exp(f222 - f000) / 3
d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3
)
va_m = sq_m - mn_m**2
return mn_m, va_m
Expand Down Expand Up @@ -499,9 +498,7 @@ def test_rootward_moments(self, pars):
)[0]
assert np.isclose(logconst, np.log(ck_normconst))
ck_t_i = scipy.integrate.quad(
lambda t_i: t_i
* self.pdf_rootward(t_i, t_j, *pars_redux)
/ ck_normconst,
lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst,
t_j,
np.inf,
epsabs=0,
Expand Down Expand Up @@ -752,8 +749,7 @@ def f(t_i, t_j): # conditional moments
)[0]
ck_mn = (
scipy.integrate.quad(
lambda t_i: f(t_i, t_j)[0]
* self.pdf_rootward(t_i, t_j, *pars_redux),
lambda t_i: f(t_i, t_j)[0] * self.pdf_rootward(t_i, t_j, *pars_redux),
t_j,
np.inf,
)[0]
Expand All @@ -762,8 +758,7 @@ def f(t_i, t_j): # conditional moments
assert np.isclose(mn, ck_mn)
ck_va = (
scipy.integrate.quad(
lambda t_i: f(t_i, t_j)[1]
* self.pdf_rootward(t_i, t_j, *pars_redux),
lambda t_i: f(t_i, t_j)[1] * self.pdf_rootward(t_i, t_j, *pars_redux),
t_j,
np.inf,
)[0]
Expand Down
17 changes: 11 additions & 6 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Test cases for tsdate accuracy.
"""

import json
import os

Expand All @@ -40,7 +41,7 @@ class TestAccuracy:
Test for some of the basic functions used in tsdate
"""

@pytest.mark.makefiles
@pytest.mark.makefiles()
def test_make_static_files(self, request):
"""
The function used to create the tree sequences for accuracy testing.
Expand Down Expand Up @@ -75,7 +76,13 @@ def test_make_static_files(self, request):
ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees"))

@pytest.mark.parametrize(
"ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained",
(
"ts_name",
"min_r2_ts",
"min_r2_unconstrained",
"min_spear_ts",
"min_spear_unconstrained",
),
[
("one_tree", 0.98601, 0.98601, 0.97719, 0.97719),
("few_trees", 0.98220, 0.98220, 0.97744, 0.97744),
Expand All @@ -91,9 +98,7 @@ def test_basic(
min_spear_unconstrained,
request,
):
ts = tskit.load(
os.path.join(request.fspath.dirname, "data", ts_name + ".trees")
)
ts = tskit.load(os.path.join(request.fspath.dirname, "data", ts_name + ".trees"))

sim_ancestry_parameters = json.loads(ts.provenance(0).record)["parameters"]
assert sim_ancestry_parameters["command"] == "sim_ancestry"
Expand Down Expand Up @@ -144,7 +149,7 @@ def test_scaling(self, Ne):
assert 0.9 < dts.node(dts.first().root).time / (2 * Ne) < 1.1

@pytest.mark.parametrize(
"bkwd_rate, trio_tmrca",
("bkwd_rate", "trio_tmrca"),
[ # calculated from simulations
(-1.0, 0.76),
(-0.9, 0.79),
Expand Down
46 changes: 24 additions & 22 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,32 @@
"""
Test cases for the gamma-variational approximations in tsdate
"""

from math import sqrt

import numpy as np
import pytest
import scipy.integrate
import scipy.special
import scipy.stats
from exact_moments import leafward_moments
from exact_moments import moments
from exact_moments import mutation_block_moments
from exact_moments import mutation_edge_moments
from exact_moments import mutation_leafward_moments
from exact_moments import mutation_moments
from exact_moments import mutation_rootward_moments
from exact_moments import mutation_sideways_moments
from exact_moments import mutation_twin_moments
from exact_moments import mutation_unphased_moments
from exact_moments import rootward_moments
from exact_moments import sideways_moments
from exact_moments import twin_moments
from exact_moments import unphased_moments
from exact_moments import (
leafward_moments,
moments,
mutation_block_moments,
mutation_edge_moments,
mutation_leafward_moments,
mutation_moments,
mutation_rootward_moments,
mutation_sideways_moments,
mutation_twin_moments,
mutation_unphased_moments,
rootward_moments,
sideways_moments,
twin_moments,
unphased_moments,
)

from tsdate import approx
from tsdate import hypergeo
from tsdate import approx, hypergeo

# TODO: better test set?
_gamma_trio_test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate]
Expand Down Expand Up @@ -294,9 +296,7 @@ def test_average_gammas(self):
E_x = np.mean(shape + 1)
E_logx = np.mean(scipy.special.digamma(shape + 1))
assert np.isclose(E_x, (avg_shape + 1) / avg_rate)
assert np.isclose(
E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)
)
assert np.isclose(E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate))


class TestKLMinimizationFailed:
Expand All @@ -305,7 +305,7 @@ class TestKLMinimizationFailed:
"""

def test_violates_jensen(self):
with pytest.raises(approx.KLMinimizationFailed, match="violates Jensen's"):
with pytest.raises(approx.KLMinimizationFailedError, match="violates Jensen's"):
approx.approximate_gamma_kl(1, 0)

def test_asymptotic_bound(self):
Expand All @@ -314,10 +314,12 @@ def test_asymptotic_bound(self):
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert alpha == alpha_bound and alpha > 1e4
assert alpha == alpha_bound
assert alpha > 1e4
# check that bound matches optimization result just under threshold
logx = -0.000051
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4
assert np.abs(alpha - alpha_bound) < 1
assert alpha < 1e4
1 change: 1 addition & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for the cache management code.
"""

import os
import pathlib
import unittest
Expand Down
21 changes: 10 additions & 11 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Test cases for the command line interface for tsdate.
"""

import json
import logging
from unittest import mock
Expand Down Expand Up @@ -75,11 +76,11 @@ def test_recombination_rate(self):
parser = cli.tsdate_cli_parser()
params = ["-m", "1e10"]
args = parser.parse_args(
["date", self.infile, self.output] + params + ["-r", "1e-100"]
["date", self.infile, self.output, *params, "-r", "1e-100"]
)
assert args.recombination_rate == 1e-100
args = parser.parse_args(
["date", self.infile, self.output] + params + ["--recombination-rate", "73"]
["date", self.infile, self.output, *params, "--recombination-rate", "73"]
)
assert args.recombination_rate == 73

Expand All @@ -97,24 +98,22 @@ def test_epsilon(self):
def test_num_threads(self):
parser = cli.tsdate_cli_parser()
params = ["--method", "maximization", "--num-threads"]
args = parser.parse_args(["date", self.infile, self.output] + params + ["1"])
args = parser.parse_args(["date", self.infile, self.output, *params, "1"])
assert args.num_threads == 1
args = parser.parse_args(["date", self.infile, self.output] + params + ["2"])
args = parser.parse_args(["date", self.infile, self.output, *params, "2"])
assert args.num_threads == 2

def test_probability_space(self):
parser = cli.tsdate_cli_parser()
params = ["--method", "inside_outside", "--probability-space"]
args = parser.parse_args(
["date", self.infile, self.output] + params + ["linear"]
)
args = parser.parse_args(["date", self.infile, self.output, *params, "linear"])
assert args.probability_space == "linear"
args = parser.parse_args(
["date", self.infile, self.output] + params + ["logarithmic"]
["date", self.infile, self.output, *params, "logarithmic"]
)
assert args.probability_space == "logarithmic"

@pytest.mark.parametrize("flag, log_status", logging_flags.items())
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
def test_verbosity(self, flag, log_status):
parser = cli.tsdate_cli_parser()
args = parser.parse_args(["preprocess", self.infile, self.output, flag])
Expand All @@ -130,7 +129,7 @@ def test_method(self, method):
params = ["-m", "1e-8", "--method", method]
if method != "variational_gamma":
params += ["-n", "10"]
args = parser.parse_args(["date", self.infile, self.output] + params)
args = parser.parse_args(["date", self.infile, self.output, *params])
assert args.method == method

def test_progress(self):
Expand Down Expand Up @@ -231,7 +230,7 @@ def test_no_output_variational_gamma(self, tmp_path, capfd):
assert out == ""
assert err == ""

@pytest.mark.parametrize("flag, log_status", logging_flags.items())
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
def test_verbosity(self, tmp_path, caplog, flag, log_status):
popsize = 10000
ts = msprime.simulate(
Expand Down
Loading

0 comments on commit 17b15cb

Please sign in to comment.