Commit 3316da2
chore: misc style and docs format tidying
benlipkin committed Nov 23, 2024
1 parent c9a8a67 commit 3316da2
Showing 4 changed files with 17 additions and 41 deletions.
22 changes: 4 additions & 18 deletions decoding/estimators.py
@@ -24,20 +24,12 @@
 
 import jax.numpy as jnp
 
-from decoding.pmf import (
-    LogPMF,
-    ScoredItem,
-    make_scored_items,
-    sort_scored_items,
-)
+from decoding.pmf import LogPMF, ScoredItem, make_scored_items, sort_scored_items
 from decoding.types import FS, NUM, T_, T
 
 
 def MBR(
-    d: LogPMF[T],
-    *,
-    utility: Callable[[T, T], NUM],
-    parallelize: bool = False,
+    d: LogPMF[T], *, utility: Callable[[T, T], NUM], parallelize: bool = False
 ) -> list[ScoredItem[T]]:
     """
     Calculate the Minimum Bayes Risk (MBR) estimator for a given distribution
@@ -70,10 +62,7 @@ def _risk(c1: T) -> FS:


 def commutativeMBR(
-    d: LogPMF[T],
-    *,
-    utility: Callable[[T, T], NUM],
-    parallelize: bool = False,
+    d: LogPMF[T], *, utility: Callable[[T, T], NUM], parallelize: bool = False
 ) -> list[ScoredItem[T]]:
     """
     Variant of `MBR` for commutative utility functions.
@@ -113,10 +102,7 @@ def _risk(c1: T) -> FS:


 def linearMBR(
-    d: LogPMF[T],
-    *,
-    utility: Callable[[T], NUM],
-    parallelize: bool = False,
+    d: LogPMF[T], *, utility: Callable[[T], NUM], parallelize: bool = False
 ) -> list[ScoredItem[T]]:
     """
     Variant of `MBR` for cases that can be executed in linear time.
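
For reference, here is a minimal usage sketch of the three estimators whose signatures are collapsed above. Only the import paths and the keyword-only signatures come from this diff; the toy distribution and both utility functions are illustrative assumptions.

# Hedged sketch: exercising MBR, commutativeMBR, and linearMBR as declared in
# decoding/estimators.py. The toy data and both utilities are assumptions.
import jax.numpy as jnp

from decoding.estimators import MBR, commutativeMBR, linearMBR
from decoding.pmf import LogPMF

# Toy distribution over three candidate strings (probabilities sum to 1).
d = LogPMF(logp=jnp.log(jnp.asarray([0.5, 0.25, 0.25])), items=["abc", "abd", "xyz"])


def agreement(c1: str, c2: str) -> int:
    # Symmetric pairwise utility: count of positions where the strings agree.
    return sum(a == b for a, b in zip(c1, c2))


# Pairwise utility; MBR scores each candidate by its expected value under d.
ranked = MBR(d, utility=agreement)

# agreement(x, y) == agreement(y, x), so the commutative variant applies too.
ranked_comm = commutativeMBR(d, utility=agreement)

# linearMBR takes a unary utility (Callable[[T], NUM]) and runs in linear time.
ranked_lin = linearMBR(d, utility=len)

Each call returns a list[ScoredItem[T]]; given the sort_scored_items import, the results are presumably returned already sorted by score.
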
7 changes: 1 addition & 6 deletions decoding/experimental.py
@@ -15,12 +15,7 @@
     _TreeSearch,  # type: ignore[reportPrivateUsage]
 )
 from decoding.models import LanguageModel
-from decoding.pmf import (
-    LogPMF,
-    ScoredItem,
-    make_scored_items,
-    sort_scored_items,
-)
+from decoding.pmf import LogPMF, ScoredItem, make_scored_items, sort_scored_items
 from decoding.scorers import Scorer
 
 
7 changes: 1 addition & 6 deletions decoding/scorers.py
@@ -20,12 +20,7 @@ class supports constructors that enable the preparation of this scoring function
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 
-from decoding.pmf import (
-    LogPMF,
-    ScoredItem,
-    make_scored_items,
-    sort_scored_items,
-)
+from decoding.pmf import LogPMF, ScoredItem, make_scored_items, sort_scored_items
 from decoding.types import NUM
 
 
22 changes: 11 additions & 11 deletions tests/test_pmf.py
@@ -18,12 +18,12 @@
 from decoding.utils import getkey, logsoftmax
 
 
-def _make_random_catlogpmf(size: int = 10) -> LogPMF[int]:
+def _make_random_logpmf(size: int = 10) -> LogPMF[int]:
     logp = logsoftmax(jr.normal(getkey(), (size,)))
     return LogPMF(logp=logp, items=list(range(size)))
 
 
-def test_categoricallogpmf() -> None:
+def test_categorical_logpmf() -> None:
     with pytest.raises(ValueError, match="LogProbs must be 1D"):
         LogPMF(logp=jnp.asarray([[0.0], [0.0]]), items=[0, 1])
     with pytest.raises(ValueError, match="LogProbs and Categories must match length"):
@@ -102,7 +102,7 @@ def test_entropy() -> None:
     d = LogPMF(logp=jnp.log(p), items=c)
     assert entropy(d) > 0.0
 
-    d = _make_random_catlogpmf()
+    d = _make_random_logpmf()
     surps = jnp.asarray([surprise(d, i) for i in range(len(d.items))])
     h = entropy(d)
     assert jnp.isclose(h, jnp.sum(jnp.exp(d.logp) * surps))
@@ -156,8 +156,8 @@ def test_cross_entropy() -> None:
     assert ce_pq == jnp.inf
     assert ce_qp / jnp.log(2) == 1.0
 
-    d_p = _make_random_catlogpmf()
-    d_q = _make_random_catlogpmf()
+    d_p = _make_random_logpmf()
+    d_q = _make_random_logpmf()
     h_p = entropy(d_p)
     h_q = entropy(d_q)
     kl_pq = kl_divergence(d_p, d_q)
@@ -185,8 +185,8 @@ def test_js_divergence() -> None:
     jsd = js_divergence(d_p, d_q)
     assert 0.0 < jsd < jnp.inf
 
-    d_p = _make_random_catlogpmf()
-    d_q = _make_random_catlogpmf()
+    d_p = _make_random_logpmf()
+    d_q = _make_random_logpmf()
     jsd_pq = js_divergence(d_p, d_q)
     jsd_qp = js_divergence(d_q, d_p)
     assert jnp.isclose(jsd_pq, jsd_qp)
@@ -204,16 +204,16 @@ def test_js_distance() -> None:
     jsm = js_distance(d_p, d_q)
     assert 0.0 < jsm < jnp.inf
 
-    d_p = _make_random_catlogpmf()
-    d_q = _make_random_catlogpmf()
+    d_p = _make_random_logpmf()
+    d_q = _make_random_logpmf()
     jsm_pq = js_distance(d_p, d_q)
     jsm_qp = js_distance(d_q, d_p)
     assert jnp.isclose(jsm_pq, jsm_qp)
 
 
 def test_jax_compare() -> None:
-    d_p = _make_random_catlogpmf()
-    d_q = _make_random_catlogpmf()
+    d_p = _make_random_logpmf()
+    d_q = _make_random_logpmf()
 
     assert jnp.isclose(
         entropy(d_p),
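
The renamed helper and the entropy check it feeds are easy to reproduce outside the test harness. A minimal standalone sketch follows; it assumes entropy and surprise are exported by decoding.pmf (their import sits above the first hunk shown here) and otherwise only repeats code visible in the diff.

# Standalone sketch of the renamed helper plus the identity asserted in
# test_entropy: H(d) == sum_x p(x) * surprise(d, x).
# Assumption: entropy and surprise live in decoding.pmf (import not shown in the hunk).
import jax.numpy as jnp
import jax.random as jr

from decoding.pmf import LogPMF, entropy, surprise
from decoding.utils import getkey, logsoftmax


def _make_random_logpmf(size: int = 10) -> LogPMF[int]:
    # Random categorical distribution over `size` integer items.
    logp = logsoftmax(jr.normal(getkey(), (size,)))
    return LogPMF(logp=logp, items=list(range(size)))


d = _make_random_logpmf()
surps = jnp.asarray([surprise(d, i) for i in range(len(d.items))])
assert jnp.isclose(entropy(d), jnp.sum(jnp.exp(d.logp) * surps))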
