
api: refactor ScoredItem usage from Sample
benlipkin committed Nov 18, 2024
1 parent 019f314 commit 70f987a
Showing 13 changed files with 196 additions and 196 deletions.
20 changes: 10 additions & 10 deletions TUTORIAL.md
@@ -99,7 +99,7 @@ from nltk.sem.logic import LogicalExpressionException, LogicParser
 from decoding.generators import TreeSearch
 from decoding.estimators import SelfConsistency
 from decoding.models import LanguageModel
-from decoding.pmf import CategoricalLogPMF, Sample
+from decoding.pmf import CategoricalLogPMF, ScoredItem
 from decoding.scorers import Scorer

 # here's our prompt for the problem we'd like solved
@@ -139,26 +139,26 @@ def stop_pass(s: str) -> bool:
 # let's specify how to score particles at each step
 # note that compared to the previous example,
 # here instead of simply returning a float,
-# we're returning a `Sample`: a str with an associated utility
+# we're returning a `ScoredItem`: a str with an associated score
 # this will allow us to modify the state of the string
-def step_score_fn(s: str) -> Sample[str]:
+def step_score_fn(s: str) -> ScoredItem[str]:
     if stop_pass(s):
-        return Sample(item=s, utility=float("inf"))
+        return ScoredItem(item=s, score=float("inf"))
     lines = s.strip().split("\n")
     last_line = lines[-1]
     if last_line.startswith(("P:", "C:")):
         stmt = last_line[2:]
         try:
             parser.parse(stmt)
-            return Sample(item=s, utility=len(lines))
+            return ScoredItem(item=s, score=len(lines))
         except LogicalExpressionException:
             pass
     backtrack = "\n".join(lines[:-1]) + "\n"
-    return Sample(item=backtrack, utility=len(lines) - 1)
+    return ScoredItem(item=backtrack, score=len(lines) - 1)
 # the logic above is as follows:
-# - if a string passes the stop condition, set utility high to keep it
+# - if a string passes the stop condition, set the score high to keep it
 # - for the strings that are not done, try to parse the last line
-# - if it parses, keep it and update the utility to the number of passing lines
+# - if it parses, keep it and update the score to the number of passing lines
 # - if it fails, backtrack the string to the last passing line
 # using a very simple (~10 line) step function,
 # we've implemented a backtracking tree search algorithm
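To make the backtracking rule concrete, here is a toy trace (hypothetical input; this assumes `stop_pass` returns False for an unfinished proof):

```python
# the last line "P: man(socrates" fails to parse (unbalanced paren),
# so the particle is trimmed to its last passing line and scored by
# the number of lines that remain
s = "P: all x.(man(x) -> mortal(x))\nP: man(socrates\n"
scored = step_score_fn(s)
assert scored.item == "P: all x.(man(x) -> mortal(x))\n"
assert scored.score == 1
```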
@@ -171,7 +171,7 @@ step_scorer = Scorer.from_f_str_to_sample(step_score_fn, parallelize=True)

 # now let's specify our final score function
 # to resolve the beam of passing particles
-def final_score_fn(gens: CategoricalLogPMF[str]) -> list[Sample[str]]:
+def final_score_fn(gens: CategoricalLogPMF[str]) -> list[ScoredItem[str]]:
     def postproc(gen: str) -> str:
         try:
             new = gen[len(prompt) - 2 :]
@@ -212,7 +212,7 @@ final_scorer = Scorer.from_f_catlogpmf_to_batch_sample(final_score_fn)
 # we can access it

 # finally, let's wrap this all up in a `TreeSearch` generator
-def run(prompt: str) -> list[Sample[str]]:
+def run(prompt: str) -> list[ScoredItem[str]]:
     return TreeSearch(
         prompt=prompt,
         llm=llm,
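A minimal sketch of the renamed container and its sort helper, using only what this diff shows (the `item` and `score` fields, and `sort_samples` from `decoding.pmf`); the toy values are hypothetical, and best-first ordering is assumed from the estimators' docstrings:

```python
from decoding.pmf import ScoredItem, sort_samples

particles = [
    ScoredItem(item="backtracked particle", score=1.0),
    ScoredItem(item="finished proof", score=float("inf")),
]
best = sort_samples(particles)[0]  # assumed to return best-first
assert best.item == "finished proof"
```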
6 changes: 3 additions & 3 deletions decoding/__init__.py
@@ -14,13 +14,13 @@
 that can be used to sample controlled text from LMs.

 Supporting modules:
-- `decoding.pmf`: Data structures for working with probability mass functions
-    and methods for calculating information-theoretic quantities.
+- `decoding.pmf`: Data structures for probability mass functions and other collections
+    of measures as well as algorithms for calculating information-theoretic quantities.
 - `decoding.samplers`: Methods for sampling from distributions.
 - `decoding.estimators`: Decision rules for deriving point estimates from distributions.
     Supports a flexible Minimum Bayes Risk (MBR) interface that accepts arbitrary
     user-defined utility functions.
-- `decoding.metrics`: Metrics that may be useful for constructing utility functions.
+- `decoding.metrics`: Metrics that may be useful for constructing scoring functions.
 - `decoding.utils`: Miscellaneous helper functions for the library.
 """
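As a trivial sketch, the supporting modules listed above can be imported straight from the package (nothing assumed beyond the module names):

```python
from decoding import estimators, metrics, pmf, samplers, utils
```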
40 changes: 20 additions & 20 deletions decoding/estimators.py
@@ -2,17 +2,17 @@
 Methods for calculating point estimates from distributions.

 Estimators in this module operate over an instance of `decoding.pmf.CategoricalLogPMF`
-and return a list of `decoding.pmf.Sample` instances sorted by their utility. Each
-`decoding.pmf.Sample` instance contains an `item` and `utility` field. More about these
-data structures can be found in the `decoding.pmf` module.
+and return a list of `decoding.pmf.ScoredItem` instances sorted by their expected
+utility. Each `decoding.pmf.ScoredItem` instance contains an `item` and `score` field.
+More about these data structures can be found in the `decoding.pmf` module.

 The estimators in this module reflect variants of the Minimum Bayes Risk (MBR). The MBR
 is a decision-theoretic approach to point estimation that minimizes the expected loss
 of a decision rule. This module provides efficient implementations of MBR that account
 for the properties of arbitrary user-provided utility functions.

 The module also provides a `MAP` estimator, which is a special case of `MBR` where the
-utility function is a constant function, and a `SelfConsistency` estimator, which
+utility function is the identity function, and a `SelfConsistency` estimator, which
 applies a post-processing and filtering step before aggregating the resulting samples
 via a majority voting procedure.
 """
@@ -24,7 +24,7 @@

 import jax.numpy as jnp

-from decoding.pmf import CategoricalLogPMF, Sample, make_samples, sort_samples
+from decoding.pmf import CategoricalLogPMF, ScoredItem, make_samples, sort_samples
 from decoding.types import FS, NUM, T_, T


@@ -33,7 +33,7 @@ def MBR(
     *,
     utility: Callable[[T, T], NUM],
     parallelize: bool = False,
-) -> list[Sample[T]]:
+) -> list[ScoredItem[T]]:
     """
     Calculate the Minimum Bayes Risk (MBR) estimator for a given distribution
     and arbitrary user-provided utility function.
@@ -44,7 +44,7 @@
         parallelize: Whether to parallelize the utility calculation.

     Returns:
-        A sorted list of `decoding.pmf.Sample` instances by their utility.
+        A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility.

     Example:
         ```python
@@ -69,7 +69,7 @@ def commutativeMBR(
     *,
     utility: Callable[[T, T], NUM],
     parallelize: bool = False,
-) -> list[Sample[T]]:
+) -> list[ScoredItem[T]]:
     """
     Variant of `MBR` for commutative utility functions.
     By exploiting the commutative property of the utility function, this
@@ -81,7 +81,7 @@
         parallelize: Whether to parallelize the utility calculation.

     Returns:
-        A sorted list of `decoding.pmf.Sample` instances by their utility.
+        A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility.

     Example:
         ```python
@@ -112,7 +112,7 @@ def linearMBR(
     *,
     utility: Callable[[T], NUM],
     parallelize: bool = False,
-) -> list[Sample[T]]:
+) -> list[ScoredItem[T]]:
     """
     Variant of `MBR` for cases that can be executed in linear time.
     By exploiting utility functions that operate only on individual elements,
@@ -124,7 +124,7 @@
         parallelize: Whether to parallelize the utility calculation.

     Returns:
-        A sorted list of `decoding.pmf.Sample` instances by their utility.
+        A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility.

     Example:
         ```python
@@ -144,7 +144,7 @@ def _risk(c1: T) -> FS:
     return _MBR(d, _risk, parallelize=parallelize)


-def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[Sample[T]]:
+def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[ScoredItem[T]]:
     """
     Calculate the Maximum A Posteriori (MAP) estimator for a given distribution.
@@ -153,7 +153,7 @@ def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[Sample[T]
         parallelize: Whether to parallelize the utility calculation.

     Returns:
-        A sorted list of `decoding.pmf.Sample` instances by their utility.
+        A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility.

     Example:
         ```python
@@ -179,7 +179,7 @@ def SelfConsistency(
     postproc: Callable[[T], T_],
     filt: Callable[[T_], bool],
     parallelize: bool = False,
-) -> list[Sample[T_]]:
+) -> list[ScoredItem[T_]]:
     """
     Calculate the Self-Consistency estimator for a given distribution, after applying
     a post-processing and filtering step.
@@ -191,7 +191,7 @@
         parallelize: Whether to parallelize the utility calculation.

     Returns:
-        A sorted list of `decoding.pmf.Sample` instances by their utility.
+        A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility.

     Example:
         ```python
@@ -207,16 +207,16 @@ def SelfConsistency(
     _postproc = cache(postproc)

     def _aggregate(
-        samples: list[Sample[T]],
+        samples: list[ScoredItem[T]],
         _postproc: Callable[[T], T_],
         filt: Callable[[T_], bool],
-    ) -> list[Sample[T_]]:
+    ) -> list[ScoredItem[T_]]:
         ht = defaultdict(lambda: 0.0)
         for sample in samples:
             c = _postproc(sample.item)
             if filt(c):
-                ht[c] += float(sample.utility)
-        return sort_samples([Sample(item=c, utility=u) for c, u in ht.items()])
+                ht[c] += float(sample.score)
+        return sort_samples([ScoredItem(item=c, score=u) for c, u in ht.items()])

     def _utility(c1: T, c2: T) -> int:
         return int(_postproc(c1) == _postproc(c2))
@@ -227,7 +227,7 @@ def _utility(c1: T, c2: T) -> int:

 def _MBR(
     d: CategoricalLogPMF[T], risk: Callable[[T], FS], *, parallelize: bool = False
-) -> list[Sample[T]]:
+) -> list[ScoredItem[T]]:
     def _calc_utility(logp: FS, c1: T) -> float:
         return -float(risk(c1) * jnp.exp(logp))

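And a hedged sketch of `SelfConsistency` with the post-processing and filtering steps described above; `CategoricalLogPMF.from_samples` is again an assumed constructor:

```python
from decoding.estimators import SelfConsistency
from decoding.pmf import CategoricalLogPMF

d = CategoricalLogPMF.from_samples(["ans: 4", "ans: 4", "ans: five"])
results = SelfConsistency(
    d,
    postproc=lambda s: s.removeprefix("ans: "),  # strip the answer prefix
    filt=lambda s: s.isdigit(),  # drop generations that fail to parse
)
print(results[0].item)  # "4", the majority-voted answer among survivors
```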
18 changes: 9 additions & 9 deletions decoding/experimental.py
@@ -15,7 +15,7 @@
     _TreeSearch,  # type: ignore[reportPrivateUsage]
 )
 from decoding.models import LanguageModel
-from decoding.pmf import CategoricalLogPMF, Sample, make_samples, sort_samples
+from decoding.pmf import CategoricalLogPMF, ScoredItem, make_samples, sort_samples
 from decoding.scorers import Scorer


@@ -40,7 +40,7 @@ def RolloutTreeSearch(  # noqa: PLR0913
     temperature: float = 1.0,
     logits_processors: list[LogitsProcessor] | None = None,
     seed: int | None = None,
-) -> list[Sample[str]]:
+) -> list[ScoredItem[str]]:
     if final_scorer is None:
         final_scorer = step_scorer
     search_params = _SearchParams(
@@ -76,8 +76,8 @@ def _RolloutTreeSearch(
     scorer: Scorer,
     search_params: _SearchParams,
     sampling_params: SamplingParams,
-) -> list[Sample[str]]:
-    def f(d: CategoricalLogPMF[str]) -> list[Sample[str]]:
+) -> list[ScoredItem[str]]:
+    def f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]:
         _search_params = _SearchParams(
             n=1,
             width=search_params.width,
@@ -86,16 +86,16 @@ def f(d: CategoricalLogPMF[str]) -> list[Sample[str]]:
             stop_fail=search_params.stop_fail,
         )
         prompts = list(d.cats)
-        utilities = []
+        scores = []
         for prompt in prompts:
             try:
                 samples = _TreeSearch(
                     [prompt], llm, scorer, _search_params, sampling_params
                 )
             except ValueError:
-                samples = [Sample(item=prompt, utility=-float("inf"))]
-            utilities.append(samples[0].utility)
-        return make_samples(prompts, utilities)
+                samples = [ScoredItem(item=prompt, score=-float("inf"))]
+            scores.append(samples[0].score)
+        return make_samples(prompts, scores)

-    _scorer = Scorer.from_f_catlogpmf_to_batch_sample(f)
+    _scorer = Scorer.from_f_logpmf_to_batch_item(f)
     return _TreeSearch(prompts, llm, _scorer, search_params, sampling_params)
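A hedged sketch of the renamed constructor `Scorer.from_f_logpmf_to_batch_item`, which wraps a batch scoring function shaped like `f` above; the length-based scoring rule here is a stand-in:

```python
from decoding.pmf import CategoricalLogPMF, ScoredItem, make_samples
from decoding.scorers import Scorer

def score_batch(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]:
    cats = list(d.cats)  # candidate strings in the distribution
    return make_samples(cats, [float(len(c)) for c in cats])

scorer = Scorer.from_f_logpmf_to_batch_item(score_batch)
```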
