diff --git a/TUTORIAL.md b/TUTORIAL.md index d1de2bc..9a54a7a 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -99,7 +99,7 @@ from nltk.sem.logic import LogicalExpressionException, LogicParser from decoding.generators import TreeSearch from decoding.estimators import SelfConsistency from decoding.models import LanguageModel -from decoding.pmf import CategoricalLogPMF, Sample +from decoding.pmf import CategoricalLogPMF, ScoredItem from decoding.scorers import Scorer # here's our prompt for the problem we'd like solved @@ -139,26 +139,26 @@ def stop_pass(s: str) -> bool: # let's specify how to score particles at each step # note that compared to the previous example, # here instead of simply returning a float, -# we're returning a `Sample`: a str with an associated utility +# we're returning a `ScoredItem`: a str with an associated score # this will allow us to modify the state of the string -def step_score_fn(s: str) -> Sample[str]: +def step_score_fn(s: str) -> ScoredItem[str]: if stop_pass(s): - return Sample(item=s, utility=float("inf")) + return ScoredItem(item=s, score=float("inf")) lines = s.strip().split("\n") last_line = lines[-1] if last_line.startswith(("P:", "C:")): stmt = last_line[2:] try: parser.parse(stmt) - return Sample(item=s, utility=len(lines)) + return ScoredItem(item=s, score=len(lines)) except LogicalExpressionException: pass backtrack = "\n".join(lines[:-1]) + "\n" - return Sample(item=backtrack, utility=len(lines) - 1) + return ScoredItem(item=backtrack, score=len(lines) - 1) # the logic above is as follows: -# - if a string passes the stop condition, set utility high to keep it +# - if a string passes the stop condition, set the score high to keep it # - for the strings that are not done, try to parse the last line -# - if is parses, keep it and update the utility to the number of passing lines +# - if it parses, keep it and update the score to the number of passing lines # - if it fails, backtrack the string to the last passing line # using a 
very simple (~10 line) step function, # we've implemented a backtracking tree search algorithm @@ -171,7 +171,7 @@ step_scorer = Scorer.from_f_str_to_sample(step_score_fn, parallelize=True) # now let's specify our final score function # to resolve the beam of passing particles -def final_score_fn(gens: CategoricalLogPMF[str]) -> list[Sample[str]]: +def final_score_fn(gens: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: def postproc(gen: str) -> str: try: new = gen[len(prompt) - 2 :] @@ -212,7 +212,7 @@ final_scorer = Scorer.from_f_catlogpmf_to_batch_sample(final_score_fn) # we can access it # finally, let's wrap this all up in a `TreeSearch` generator -def run(prompt: str) -> list[Sample[str]]: +def run(prompt: str) -> list[ScoredItem[str]]: return TreeSearch( prompt=prompt, llm=llm, diff --git a/decoding/__init__.py b/decoding/__init__.py index 6b600d2..c007f8d 100644 --- a/decoding/__init__.py +++ b/decoding/__init__.py @@ -14,13 +14,13 @@ that can be used to sample controlled text from LMs. Supporting modules: -- `decoding.pmf`: Data structures for working with probability mass functions - and methods for calculating information-theoretic quantities. +- `decoding.pmf`: Data structures for probability mass functions and other collections + of measures as well as algorithms for calculating information-theoretic quantities. - `decoding.samplers`: Methods for sampling from distributions. - `decoding.estimators`: Decision rules for deriving point estimates from distributions. Supports a flexible Minimum Bayes Risk (MBR) interface that accepts arbitrary user-defined utility functions. -- `decoding.metrics`: Metrics that may be useful for constructing utility functions. +- `decoding.metrics`: Metrics that may be useful for constructing scoring functions. - `decoding.utils`: Miscellaneous helper functions for the library. 
""" diff --git a/decoding/estimators.py b/decoding/estimators.py index 9ae447a..dfe2d42 100644 --- a/decoding/estimators.py +++ b/decoding/estimators.py @@ -2,9 +2,9 @@ Methods for calculating point estimates from distributions. Estimators in this module operate over an instance of `decoding.pmf.CategoricalLogPMF` -and return a list of `decoding.pmf.Sample` instances sorted by their utility. Each -`decoding.pmf.Sample` instance contains an `item` and `utility` field. More about these -data structures can be found in the `decoding.pmf` module. +and return a list of `decoding.pmf.ScoredItem` instances sorted by their expected +utility. Each `decoding.pmf.ScoredItem` instance contains an `item` and `score` field. +More about these data structures can be found in the `decoding.pmf` module. The estimators in this module reflect variants of the Minimum Bayes Risk (MBR). The MBR is a decision-theoretic approach to point estimation that minimizes the expected loss @@ -12,7 +12,7 @@ for the properties of arbitrary user-provided utility functions. The module also provides a `MAP` estimator, which is a special case of `MBR` where the -utility function is a constant function, and a `SelfConsistency` estimator, which +utility function is the identity function, and a `SelfConsistency` estimator, which applies a post-processing and filtering step before aggregating the resulting samples via a majority voting procedure. """ @@ -24,7 +24,7 @@ import jax.numpy as jnp -from decoding.pmf import CategoricalLogPMF, Sample, make_samples, sort_samples +from decoding.pmf import CategoricalLogPMF, ScoredItem, make_samples, sort_samples from decoding.types import FS, NUM, T_, T @@ -33,7 +33,7 @@ def MBR( *, utility: Callable[[T, T], NUM], parallelize: bool = False, -) -> list[Sample[T]]: +) -> list[ScoredItem[T]]: """ Calculate the Minimum Bayes Risk (MBR) estimator for a given distribution and arbitrary user-provided utility function. 
@@ -44,7 +44,7 @@ def MBR( parallelize: Whether to parallelize the utility calculation. Returns: - A sorted list of `decoding.pmf.Sample` instances by their utility. + A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility. Example: ```python @@ -69,7 +69,7 @@ def commutativeMBR( *, utility: Callable[[T, T], NUM], parallelize: bool = False, -) -> list[Sample[T]]: +) -> list[ScoredItem[T]]: """ Variant of `MBR` for commutative utility functions. By exploiting the commutative property of the utility function, this @@ -81,7 +81,7 @@ def commutativeMBR( parallelize: Whether to parallelize the utility calculation. Returns: - A sorted list of `decoding.pmf.Sample` instances by their utility. + A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility. Example: ```python @@ -112,7 +112,7 @@ def linearMBR( *, utility: Callable[[T], NUM], parallelize: bool = False, -) -> list[Sample[T]]: +) -> list[ScoredItem[T]]: """ Variant of `MBR` for cases that can be executed in linear time. By exploiting utility functions that operate only on individual elements, @@ -124,7 +124,7 @@ def linearMBR( parallelize: Whether to parallelize the utility calculation. Returns: - A sorted list of `decoding.pmf.Sample` instances by their utility. + A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility. Example: ```python @@ -144,7 +144,7 @@ def _risk(c1: T) -> FS: return _MBR(d, _risk, parallelize=parallelize) -def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[Sample[T]]: +def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[ScoredItem[T]]: """ Calculate the Maximum A Posteriori (MAP) estimator for a given distribution. @@ -153,7 +153,7 @@ def MAP(d: CategoricalLogPMF[T], *, parallelize: bool = False) -> list[Sample[T] parallelize: Whether to parallelize the utility calculation. Returns: - A sorted list of `decoding.pmf.Sample` instances by their utility. 
+ A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility. Example: ```python @@ -179,7 +179,7 @@ def SelfConsistency( postproc: Callable[[T], T_], filt: Callable[[T_], bool], parallelize: bool = False, -) -> list[Sample[T_]]: +) -> list[ScoredItem[T_]]: """ Calculate the Self-Consistency estimator for a given distribution, after applying a post-processing and filtering step. @@ -191,7 +191,7 @@ def SelfConsistency( parallelize: Whether to parallelize the utility calculation. Returns: - A sorted list of `decoding.pmf.Sample` instances by their utility. + A sorted list of `decoding.pmf.ScoredItem` instances by their expected utility. Example: ```python @@ -207,16 +207,16 @@ def SelfConsistency( _postproc = cache(postproc) def _aggregate( - samples: list[Sample[T]], + samples: list[ScoredItem[T]], _postproc: Callable[[T], T_], filt: Callable[[T_], bool], - ) -> list[Sample[T_]]: + ) -> list[ScoredItem[T_]]: ht = defaultdict(lambda: 0.0) for sample in samples: c = _postproc(sample.item) if filt(c): - ht[c] += float(sample.utility) - return sort_samples([Sample(item=c, utility=u) for c, u in ht.items()]) + ht[c] += float(sample.score) + return sort_samples([ScoredItem(item=c, score=u) for c, u in ht.items()]) def _utility(c1: T, c2: T) -> int: return int(_postproc(c1) == _postproc(c2)) @@ -227,7 +227,7 @@ def _utility(c1: T, c2: T) -> int: def _MBR( d: CategoricalLogPMF[T], risk: Callable[[T], FS], *, parallelize: bool = False -) -> list[Sample[T]]: +) -> list[ScoredItem[T]]: def _calc_utility(logp: FS, c1: T) -> float: return -float(risk(c1) * jnp.exp(logp)) diff --git a/decoding/experimental.py b/decoding/experimental.py index 0890156..380745b 100644 --- a/decoding/experimental.py +++ b/decoding/experimental.py @@ -15,7 +15,7 @@ _TreeSearch, # type: ignore[reportPrivateUsage] ) from decoding.models import LanguageModel -from decoding.pmf import CategoricalLogPMF, Sample, make_samples, sort_samples +from decoding.pmf import CategoricalLogPMF, 
ScoredItem, make_samples, sort_samples from decoding.scorers import Scorer @@ -40,7 +40,7 @@ def RolloutTreeSearch( # noqa: PLR0913 temperature: float = 1.0, logits_processors: list[LogitsProcessor] | None = None, seed: int | None = None, -) -> list[Sample[str]]: +) -> list[ScoredItem[str]]: if final_scorer is None: final_scorer = step_scorer search_params = _SearchParams( @@ -76,8 +76,8 @@ def _RolloutTreeSearch( scorer: Scorer, search_params: _SearchParams, sampling_params: SamplingParams, -) -> list[Sample[str]]: - def f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: +) -> list[ScoredItem[str]]: + def f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: _search_params = _SearchParams( n=1, width=search_params.width, @@ -86,16 +86,16 @@ def f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: stop_fail=search_params.stop_fail, ) prompts = list(d.cats) - utilities = [] + scores = [] for prompt in prompts: try: samples = _TreeSearch( [prompt], llm, scorer, _search_params, sampling_params ) except ValueError: - samples = [Sample(item=prompt, utility=-float("inf"))] - utilities.append(samples[0].utility) - return make_samples(prompts, utilities) + samples = [ScoredItem(item=prompt, score=-float("inf"))] + scores.append(samples[0].score) + return make_samples(prompts, scores) - _scorer = Scorer.from_f_catlogpmf_to_batch_sample(f) + _scorer = Scorer.from_f_logpmf_to_batch_item(f) return _TreeSearch(prompts, llm, _scorer, search_params, sampling_params) diff --git a/decoding/generators.py b/decoding/generators.py index 34c8bd6..8571bf0 100644 --- a/decoding/generators.py +++ b/decoding/generators.py @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizers import MistralTokenizer from decoding.models import LanguageModel -from decoding.pmf import CategoricalLogPMF, Sample, sort_samples +from decoding.pmf import CategoricalLogPMF, ScoredItem, sort_samples from decoding.scorers import Scorer @@ -50,7 +50,7 @@ def BestOfN( # noqa: PLR0913 temperature: float = 1.0, 
logits_processors: list[LogitsProcessor] | None = None, seed: int | None = None, -) -> list[Sample[str]]: +) -> list[ScoredItem[str]]: """ Generate `n` samples from the language model `llm` using the `scorer` to rank them. See the [`vLLM.SamplingParams`](https://docs.vllm.ai/en/latest/dev/sampling_params.html) @@ -77,7 +77,7 @@ def BestOfN( # noqa: PLR0913 seed: The random seed. Returns: - A list of `decoding.pmf.Sample` objects sorted by the `scorer`. + A list of `decoding.pmf.ScoredItem` objects sorted by the `scorer`. Raises: ValueError: If any of the argument configurations are invalid. @@ -100,8 +100,8 @@ def BestOfN( # noqa: PLR0913 ) assert len(samples) == 20 assert all(s.item.endswith(".") for s in samples) - assert all(s.utility == -len(s.item) for s in samples) - assert samples[0].utility >= samples[-1].utility + assert all(s.score == -len(s.item) for s in samples) + assert samples[0].score >= samples[-1].score ``` """ @@ -144,7 +144,7 @@ def TreeSearch( # noqa: PLR0913 temperature: float = 1.0, logits_processors: list[LogitsProcessor] | None = None, seed: int | None = None, -) -> list[Sample[str]]: +) -> list[ScoredItem[str]]: """ Generate `n` samples from the language model `llm` using the `step_scorer` to rank them at each sync step and the `final_scorer` to rank the final beam. @@ -183,7 +183,7 @@ def TreeSearch( # noqa: PLR0913 seed: The random seed. Returns: - A list of `decoding.pmf.Sample` objects sorted by the `final_scorer`. + A list of `decoding.pmf.ScoredItem` objects sorted by the `final_scorer`. Raises: ValueError: If any of the argument configurations are invalid @@ -194,13 +194,13 @@ def TreeSearch( # noqa: PLR0913 ```python from decoding.generators import TreeSearch from decoding.models import LanguageModel - from decoding.pmf import Sample + from decoding.pmf import ScoredItem from decoding.scorers import Scorer def f(x): if "." in x: x = x.split(".")[0] + "." 
- return Sample(item=x, utility=-len(x)) + return ScoredItem(item=x, score=-len(x)) llm = LanguageModel.from_id("gpt2") scorer = Scorer.from_f_str_to_sample(f) @@ -218,8 +218,8 @@ def f(x): ) assert len(samples) == 3 assert all(s.item.endswith(".") for s in samples) - assert all(s.utility == -len(s.item) for s in samples) - assert samples[0].utility >= samples[-1].utility + assert all(s.score == -len(s.item) for s in samples) + assert samples[0].score >= samples[-1].score ``` """ @@ -256,7 +256,7 @@ def _BestOfN( llm: LanguageModel, scorer: Scorer, sampling_params: SamplingParams, -) -> list[Sample[str]]: +) -> list[ScoredItem[str]]: return scorer(llm(prompts=prompts, params=sampling_params)) @@ -266,8 +266,8 @@ def _TreeSearch( scorer: Scorer, search_params: _SearchParams, sampling_params: SamplingParams, -) -> list[Sample[str]]: - beam = [Sample(item=p, utility=-float("inf")) for p in prompts] +) -> list[ScoredItem[str]]: + beam = [ScoredItem(item=p, score=-float("inf")) for p in prompts] passing = [] for _ in range(search_params.max_steps): stop_pass = [search_params.stop_pass(s.item) for s in beam] @@ -352,7 +352,7 @@ def _guard_positive_int(n: int) -> int: return n -def _handle_failed_beam(passing: list[Sample[str]]) -> list[Sample[str]]: +def _handle_failed_beam(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]: if len(passing) == 0: msg = "All live samples failed before any passed stop conditions." msg += " Check compatibility of stop conditions or expand search." @@ -366,7 +366,7 @@ def _handle_failed_beam(passing: list[Sample[str]]) -> list[Sample[str]]: return sort_samples(passing) -def _handle_maxsteps(passing: list[Sample[str]]) -> list[Sample[str]]: +def _handle_maxsteps(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]: if len(passing) == 0: msg = "Max steps reached, and no samples passed stop conditions." 
raise RuntimeError(msg) diff --git a/decoding/metrics.py b/decoding/metrics.py index 6cfd7cc..cdb06c8 100644 --- a/decoding/metrics.py +++ b/decoding/metrics.py @@ -1,5 +1,5 @@ """ -Miscellaneous metrics that may be useful building blocks for utility functions. +Miscellaneous metrics that may be useful building blocks for scoring functions. """ from collections.abc import Sequence diff --git a/decoding/pmf.py b/decoding/pmf.py index 5292ac6..be881f1 100644 --- a/decoding/pmf.py +++ b/decoding/pmf.py @@ -6,9 +6,9 @@ calculating various information-theoretic quantities, such as `surprise`, `entropy`, `kl_divergence`, `cross_entropy`, etc. -The module also provides a `Sample` dataclass, instances of which are used to -store an `item` and its `utility` (e.g., a score, probability, or other measure). -There are also functions for creating and sorting lists of `Sample` instances. +The module also provides a `ScoredItem` dataclass, instances of which are used to +store an `item` and its `score` (e.g., a utility, probability, or other measure). +There are also functions for creating and sorting lists of `ScoredItem` instances. """ from collections import Counter @@ -25,79 +25,79 @@ @dataclass(frozen=True, kw_only=True) -class Sample(Generic[T]): +class ScoredItem(Generic[T]): """ - Dataclass for storing an item and its utility. + Dataclass for storing an item and its score. Attributes: item: The item to be stored. - utility: The utility of the item + score: The score of the item. Example: ```python - from decoding.pmf import Sample + from decoding.pmf import ScoredItem - s = Sample(item="a", utility=0.5) + s = ScoredItem(item="a", score=0.5) assert s.item == "a" - assert s.utility == 0.5 + assert s.score == 0.5 ``` """ item: T - utility: NUM + score: NUM -def sort_samples(samples: Iterable[Sample[T]]) -> list[Sample[T]]: +def sort_samples(samples: Iterable[ScoredItem[T]]) -> list[ScoredItem[T]]: """ - Sort a list of `Sample` instances by utility in descending order. 
+ Sort a list of `ScoredItem` instances by score in descending order. Args: - samples: An iterable of `Sample` instances. + samples: An iterable of `ScoredItem` instances. Returns: - A list of `Sample` instances sorted by utility in descending order. + A list of `ScoredItem` instances sorted by score in descending order. Example: ```python - from decoding.pmf import Sample, sort_samples + from decoding.pmf import ScoredItem, sort_samples samples = [ - Sample(item="a", utility=0.5), - Sample(item="b", utility=0.3), - Sample(item="c", utility=0.7), + ScoredItem(item="a", score=0.5), + ScoredItem(item="b", score=0.3), + ScoredItem(item="c", score=0.7), ] sorted_samples = sort_samples(samples) - assert sorted_samples[0] == Sample(item="c", utility=0.7) + assert sorted_samples[0] == ScoredItem(item="c", score=0.7) ``` """ - return sorted(samples, key=lambda x: float(x.utility), reverse=True) + return sorted(samples, key=lambda x: float(x.score), reverse=True) -def make_samples(items: Sequence[T], utilities: Sequence[NUM]) -> list[Sample[T]]: +def make_samples(items: Sequence[T], scores: Sequence[NUM]) -> list[ScoredItem[T]]: """ - Create a list of `Sample` instances from a list of items and utilities. + Create a list of `ScoredItem` instances from a list of items and scores. Args: items: A sequence of items to be stored. - utilities: A sequence of utilities for the items. + scores: A sequence of scores for the items. Returns: - A list of `Sample` instances. + A list of `ScoredItem` instances. 
Example: ```python from decoding.pmf import make_samples items = ["a", "b", "c"] - utilities = [0.5, 0.3, 0.7] - samples = make_samples(items, utilities) - assert samples[0] == Sample(item="a", utility=0.5) + scores = [0.5, 0.3, 0.7] + samples = make_samples(items, scores) + assert samples[0] == ScoredItem(item="a", score=0.5) ``` """ - return [Sample(item=i, utility=u) for i, u in zip(items, utilities, strict=True)] + return [ScoredItem(item=i, score=u) for i, u in zip(items, scores, strict=True)] @dataclass(frozen=True, kw_only=True) @@ -207,21 +207,21 @@ def from_logits( @classmethod def from_samples( - cls, samples: Sequence[T] | Sequence[Sample[T]] + cls, samples: Sequence[T] | Sequence[ScoredItem[T]] ) -> "CategoricalLogPMF[T]": """ Create a `CategoricalLogPMF` instance from a list of items - or a list of `Sample` instances. + or a list of `ScoredItem` instances. Args: - samples: A sequence of items or `Sample` instances. + samples: A sequence of items or `ScoredItem` instances. Returns: A `CategoricalLogPMF` instance. 
Example: ```python - from decoding.pmf import CategoricalLogPMF, Sample + from decoding.pmf import CategoricalLogPMF, ScoredItem samples = ["a", "b", "a", "c"] d = CategoricalLogPMF.from_samples(samples) @@ -432,25 +432,25 @@ def js_distance(d_p: CategoricalLogPMF[T], d_q: CategoricalLogPMF[T]) -> FS: return jnp.sqrt(js_divergence(d_p, d_q)) -def _prepare_items(samples: Sequence[T] | Sequence[Sample[T]]) -> Sequence[T]: +def _prepare_items(samples: Sequence[T] | Sequence[ScoredItem[T]]) -> Sequence[T]: if _guard_sample_seq(samples): return [s.item for s in samples] if _guard_item_seq(samples): return samples - msg = "Samples must be `Sequence[T]` or `Sequence[Sample[T]]` and nonempty" + msg = "Samples must be `Sequence[T]` or `Sequence[ScoredItem[T]]` and nonempty" raise ValueError(msg) def _guard_sample_seq( - samples: Sequence[T] | Sequence[Sample[T]], -) -> TypeGuard[Sequence[Sample[T]]]: + samples: Sequence[T] | Sequence[ScoredItem[T]], +) -> TypeGuard[Sequence[ScoredItem[T]]]: if len(samples) == 0: return False - return isinstance(samples[0], Sample) + return isinstance(samples[0], ScoredItem) def _guard_item_seq( - samples: Sequence[T] | Sequence[Sample[T]], + samples: Sequence[T] | Sequence[ScoredItem[T]], ) -> TypeGuard[Sequence[T]]: return len(samples) > 0 diff --git a/decoding/scorers.py b/decoding/scorers.py index 8782730..0bf859e 100644 --- a/decoding/scorers.py +++ b/decoding/scorers.py @@ -1,8 +1,8 @@ """ Scorers are objects that take an instance of `decoding.pmf.CategoricalLogPMF` -and return a list of `decoding.pmf.Sample` instances that are sorted by their -utility values. The utility values are computed by a function that is passed -to the constructor of the `Scorer` object. +and return a list of `decoding.pmf.ScoredItem` instances that are sorted by their +scores. The scores are computed by a function that is passed to the constructor +of the `Scorer` object. The `Scorer` class is a frozen dataclass that wraps this scoring function. 
The class supports constructors that enable the preparation of this scoring function @@ -20,7 +20,7 @@ class supports constructors that enable the preparation of this scoring function from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from decoding.pmf import CategoricalLogPMF, Sample, make_samples, sort_samples +from decoding.pmf import CategoricalLogPMF, ScoredItem, make_samples, sort_samples from decoding.types import NUM @@ -30,25 +30,25 @@ class Scorer: The `Scorer` class wraps and coordinates user-supplied scoring functions. """ - _f: Callable[[CategoricalLogPMF[str]], list[Sample[str]]] + _f: Callable[[CategoricalLogPMF[str]], list[ScoredItem[str]]] - def __call__(self, d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def __call__(self, d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: """ `__call__` is an alias for `score`. """ return self.score(d) - def score(self, d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def score(self, d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: """ Process a `decoding.pmf.CategoricalLogPMF` instance and returns a list - of `decoding.pmf.Sample` instances that are sorted by their utility values. + of `decoding.pmf.ScoredItem` instances that are sorted by their scores. Args: d: A `decoding.pmf.CategoricalLogPMF` instance. Returns: - A list of `decoding.pmf.Sample` instances that are sorted - by their utility values. + A list of `decoding.pmf.ScoredItem` instances that are sorted + by their scores. 
Example: ```python @@ -59,7 +59,7 @@ def score(self, d: CategoricalLogPMF[str]) -> list[Sample[str]]: d = CategoricalLogPMF.from_samples(["a", "bb", "ccc"]) samples = scorer(d) assert samples[0].item == "ccc" - assert samples[0].utility == 3 + assert samples[0].score == 3 ``` """ @@ -92,12 +92,12 @@ def from_f_str_to_num( d = CategoricalLogPMF.from_samples(["a", "bb", "ccc"]) samples = scorer(d) assert samples[-1].item == "a" - assert samples[-1].utility == 1 + assert samples[-1].score == 1 ``` """ - def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: if parallelize: with ThreadPoolExecutor() as e: utilities = list(e.map(f, d.cats)) @@ -132,19 +132,19 @@ def from_f_batch_str_to_batch_num( d = CategoricalLogPMF.from_samples(["a", "bb", "ccc"]) samples = scorer(d) assert samples[0].item == "ccc" - assert samples[0].utility == 3 + assert samples[0].score == 3 ``` """ - def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: utilities = f(d.cats) return make_samples(d.cats, utilities) return cls(_f=_f) @classmethod - def from_f_catlogpmf_to_batch_num( + def from_f_logpmf_to_batch_num( cls, f: Callable[[CategoricalLogPMF[str]], Sequence[NUM]] ) -> "Scorer": """ @@ -167,38 +167,38 @@ def from_f_catlogpmf_to_batch_num( from decoding.scorers import Scorer f = lambda d: [jnp.exp(logp) * len(cat) for logp, cat in d] - scorer = Scorer.from_f_catlogpmf_to_batch_num(f) + scorer = Scorer.from_f_logpmf_to_batch_num(f) d = CategoricalLogPMF.from_samples(["a", "bb", "bb", "ccc"]) samples = scorer(d) assert samples[0].item == "bb" - assert samples[0].utility == 1.0 + assert samples[0].score == 1.0 assert samples[1].item == "ccc" - assert samples[1].utility == 0.75 + assert samples[1].score == 0.75 assert samples[2].item == "a" - assert samples[2].utility == 0.25 + assert samples[2].score == 0.25 ``` """ - def _f(d: CategoricalLogPMF[str]) -> 
list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: utilities = f(d) return make_samples(d.cats, utilities) return cls(_f=_f) @classmethod - def from_f_str_to_sample( - cls, f: Callable[[str], Sample[str]], *, parallelize: bool = False + def from_f_str_to_item( + cls, f: Callable[[str], ScoredItem[str]], *, parallelize: bool = False ) -> "Scorer": """ Construct a `Scorer` object from a function that maps a string to a - `decoding.pmf.Sample` instance. The `Scorer` object will then score a + `decoding.pmf.ScoredItem` instance. The `Scorer` object will then score a `decoding.pmf.CategoricalLogPMF` instance by applying this function to - each of its categories. This allows us to update not only the utility + each of its categories. This allows us to update not only the score values but also the items themselves. Args: - f: A function that maps a string to a `decoding.pmf.Sample` instance. + f: A function that maps a string to a `decoding.pmf.ScoredItem` instance. parallelize: A boolean indicating whether to parallelize the scoring process. 
@@ -207,26 +207,26 @@ def from_f_str_to_sample( Example: ```python - from decoding.pmf import CategoricalLogPMF, Sample + from decoding.pmf import CategoricalLogPMF, ScoredItem from decoding.scorers import Scorer def f(x): if x.endswith("."): - return Sample(item=x[:-1], utility=len(x)-1) - return Sample(item=x, utility=len(x)) + return ScoredItem(item=x[:-1], score=len(x)-1) + return ScoredItem(item=x, score=len(x)) - scorer = Scorer.from_f_str_to_sample(f, parallelize=True) + scorer = Scorer.from_f_str_to_item(f, parallelize=True) d = CategoricalLogPMF.from_samples(["a", "bb.", "ccc"]) samples = scorer(d) assert samples[0].item == "ccc" - assert samples[0].utility == 3 + assert samples[0].score == 3 assert samples[1].item == "bb" - assert samples[1].utility == 2 + assert samples[1].score == 2 ``` """ - def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: if parallelize: with ThreadPoolExecutor() as e: return list(e.map(f, d.cats)) @@ -236,58 +236,58 @@ def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: return cls(_f=_f) @classmethod - def from_f_batch_str_to_batch_sample( - cls, f: Callable[[Sequence[str]], Sequence[Sample[str]]] + def from_f_batch_str_to_batch_item( + cls, f: Callable[[Sequence[str]], Sequence[ScoredItem[str]]] ) -> "Scorer": """ Construct a `Scorer` object from a function that maps a sequence of strings - to a sequence of `decoding.pmf.Sample` instances. The `Scorer` object will + to a sequence of `decoding.pmf.ScoredItem` instances. The `Scorer` object will then score a `decoding.pmf.CategoricalLogPMF` instance by applying this function - to its categories. This allows us to update not only the utility values but + to its categories. This allows us to update not only the score values but also the items themselves. Args: f: A function that maps a sequence of strings to a sequence of - `decoding.pmf.Sample` instances. + `decoding.pmf.ScoredItem` instances. 
Returns: A `Scorer` object. Example: ```python - from decoding.pmf import CategoricalLogPMF, Sample + from decoding.pmf import CategoricalLogPMF, ScoredItem from decoding.scorers import Scorer - f = lambda xs: [Sample(item=x[1:], utility=len(x[1:])) for x in xs] - scorer = Scorer.from_f_batch_str_to_batch_sample(f) + f = lambda xs: [ScoredItem(item=x[1:], score=len(x[1:])) for x in xs] + scorer = Scorer.from_f_batch_str_to_batch_item(f) d = CategoricalLogPMF.from_samples(["_a", "_bb", "_ccc"]) samples = scorer(d) assert samples[0].item == "ccc" - assert samples[0].utility == 3 + assert samples[0].score == 3 ``` """ - def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: return list(f(d.cats)) return cls(_f=_f) @classmethod - def from_f_catlogpmf_to_batch_sample( - cls, f: Callable[[CategoricalLogPMF[str]], Sequence[Sample[str]]] + def from_f_logpmf_to_batch_item( + cls, f: Callable[[CategoricalLogPMF[str]], Sequence[ScoredItem[str]]] ) -> "Scorer": """ Construct a `Scorer` object from a function that maps a `decoding.pmf.CategoricalLogPMF` instance to a sequence of - `decoding.pmf.Sample` instances. This type signature actually + `decoding.pmf.ScoredItem` instances. This type signature actually matches much of the `decoding.estimators` module, so this constructor is particularly useful for building `Scorer` instances based on `decoding.estimators.MBR`, etc. Args: f: A function that maps a `decoding.pmf.CategoricalLogPMF` - instance to a sequence of `decoding.pmf.Sample` instances. + instance to a sequence of `decoding.pmf.ScoredItem` instances. Returns: A `Scorer` object. 
@@ -300,18 +300,18 @@ def from_f_catlogpmf_to_batch_sample( from decoding.scorers import Scorer f = lambda d: MBR(d, utility=lambda x1, x2: x1 < x2) - scorer = Scorer.from_f_catlogpmf_to_batch_sample(f) + scorer = Scorer.from_f_logpmf_to_batch_item(f) d = CategoricalLogPMF.from_samples(["aa", "bb", "cc"]) samples = scorer(d) assert samples[0].item == "aa" - assert jnp.isclose(samples[0].utility, 2/3) + assert jnp.isclose(samples[0].score, 2/3) assert samples[1].item == "bb" - assert jnp.isclose(samples[1].utility, 1/3) + assert jnp.isclose(samples[1].score, 1/3) ``` """ - def _f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: + def _f(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: return list(f(d)) return cls(_f=_f) diff --git a/decoding/utils.py b/decoding/utils.py index aa94261..3f88ac1 100644 --- a/decoding/utils.py +++ b/decoding/utils.py @@ -1,5 +1,5 @@ """ -Miscellaneous utility functions. +Miscellaneous helper functions. """ import secrets diff --git a/examples/thm_proving_treesearch.py b/examples/thm_proving_treesearch.py index af74dd9..a2b382b 100644 --- a/examples/thm_proving_treesearch.py +++ b/examples/thm_proving_treesearch.py @@ -12,7 +12,7 @@ from decoding.estimators import SelfConsistency from decoding.generators import TreeSearch from decoding.models import LanguageModel -from decoding.pmf import CategoricalLogPMF, Sample +from decoding.pmf import CategoricalLogPMF, ScoredItem from decoding.scorers import Scorer llm = LanguageModel.from_id( @@ -24,23 +24,23 @@ prover = TableauProver() -def step_score_fn(s: str) -> Sample[str]: +def step_score_fn(s: str) -> ScoredItem[str]: if stop_pass(s): - return Sample(item=s, utility=float("inf")) + return ScoredItem(item=s, score=float("inf")) lines = s.strip().split("\n") last_line = lines[-1] if last_line.startswith(("P:", "C:")): stmt = last_line[2:] try: parser.parse(stmt) - return Sample(item=s, utility=len(lines)) + return ScoredItem(item=s, score=len(lines)) except LogicalExpressionException: 
pass backtrack = "\n".join(lines[:-1]) + "\n" - return Sample(item=backtrack, utility=len(lines) - 1) + return ScoredItem(item=backtrack, score=len(lines) - 1) -def final_score_fn(d: CategoricalLogPMF[str]) -> list[Sample[str]]: +def final_score_fn(d: CategoricalLogPMF[str]) -> list[ScoredItem[str]]: def postproc(gen: str) -> str: try: new = gen[len(prompt) - 2 :] @@ -67,8 +67,8 @@ def stop_pass(s: str) -> bool: return s.endswith("\n\n") -step_scorer = Scorer.from_f_str_to_sample(step_score_fn, parallelize=True) -final_scorer = Scorer.from_f_catlogpmf_to_batch_sample(final_score_fn) +step_scorer = Scorer.from_f_str_to_item(step_score_fn, parallelize=True) +final_scorer = Scorer.from_f_logpmf_to_batch_item(final_score_fn) def run(prompt: str) -> str: diff --git a/tests/test_generators.py b/tests/test_generators.py index b329d8d..c43dd6e 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -2,7 +2,7 @@ from decoding.generators import BestOfN, TreeSearch from decoding.models import LanguageModel -from decoding.pmf import Sample +from decoding.pmf import ScoredItem from decoding.scorers import Scorer llm = LanguageModel.from_id("EleutherAI/pythia-70m", gpu_memory_utilization=0.2) @@ -11,10 +11,10 @@ def test_bestofn() -> None: start = "The" - def utility(s: str) -> int: + def score(s: str) -> int: return -len(s) - scorer = Scorer.from_f_str_to_num(utility) + scorer = Scorer.from_f_str_to_num(score) sentences = {} for n in [1, 10, 100]: samples = BestOfN( @@ -23,7 +23,7 @@ def utility(s: str) -> int: sentences[n] = samples[0].item assert all(s.startswith("The") for s in sentences.values()) assert all(s.endswith(".") for s in sentences.values()) - assert utility(sentences[100]) > utility(sentences[10]) > utility(sentences[1]) + assert score(sentences[100]) > score(sentences[10]) > score(sentences[1]) msg = "Delimiter must be a single character" with pytest.raises(ValueError, match=msg): @@ -41,12 +41,12 @@ def test_treesearch_basic() -> None: def 
stop(s: str) -> bool: return end in s - def utility(s: str) -> int: + def score(s: str) -> int: if stop(s): return 1 return -len(s) - scorer = Scorer.from_f_str_to_num(utility) + scorer = Scorer.from_f_str_to_num(score) sentence = TreeSearch( llm=llm, step_scorer=scorer, @@ -93,7 +93,7 @@ def test_treesearch_step() -> None: def stop(s: str) -> bool: return end in s - def utility_step(s: str) -> int: + def score_step(s: str) -> int: if stop(s): return 1 ws = s.split(delim) @@ -102,11 +102,11 @@ def utility_step(s: str) -> int: return -len(s) return -(len(ws[-2]) + len(ws[-1])) - def utility_final(s: str) -> int: + def score_final(s: str) -> int: return -len(s) - step_scorer = Scorer.from_f_str_to_num(utility_step) - final_scorer = Scorer.from_f_str_to_num(utility_final) + step_scorer = Scorer.from_f_str_to_num(score_step) + final_scorer = Scorer.from_f_str_to_num(score_final) sentence = TreeSearch( llm=llm, step_scorer=step_scorer, @@ -132,7 +132,7 @@ def test_treesearch_fail() -> None: def stop(s: str) -> bool: return end in s - def utility(s: str) -> int: + def score(s: str) -> int: if stop(s): return 1 return -len(s) @@ -140,9 +140,9 @@ def utility(s: str) -> int: def fail(s: str) -> bool: return len(s) > max_len_constraint - scorer = Scorer.from_f_str_to_num(utility) + scorer = Scorer.from_f_str_to_num(score) - def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]: + def beam_search(n: int, beam_width: int, beam_factor: int) -> list[ScoredItem[str]]: return TreeSearch( llm=llm, step_scorer=scorer, @@ -184,15 +184,15 @@ def test_treesearch_maxsteps() -> None: def stop(s: str) -> bool: return end in s - def utility(s: str) -> int: + def score(s: str) -> int: if stop(s): return 1 return -len(s) - scorer = Scorer.from_f_str_to_num(utility) + scorer = Scorer.from_f_str_to_num(score) n_requested = 3 - def beam_search(max_steps: int) -> list[Sample[str]]: + def beam_search(max_steps: int) -> list[ScoredItem[str]]: return TreeSearch( llm=llm, 
step_scorer=scorer, diff --git a/tests/test_pmf.py b/tests/test_pmf.py index 3b1cf6c..47e7c8a 100644 --- a/tests/test_pmf.py +++ b/tests/test_pmf.py @@ -57,7 +57,7 @@ def test_categoricallogpmf() -> None: with pytest.raises(FrozenInstanceError, match="cannot assign to field 'cats'"): d.cats = None # type: ignore[reportAttributeAccessIssue] msg = re.escape( - "Samples must be `Sequence[T]` or `Sequence[Sample[T]]` and nonempty" + "Samples must be `Sequence[T]` or `Sequence[ScoredItem[T]]` and nonempty" ) with pytest.raises(ValueError, match=msg): CategoricalLogPMF.from_samples(samples=[]) diff --git a/tests/test_scorers.py b/tests/test_scorers.py index 46bc101..8017bc2 100644 --- a/tests/test_scorers.py +++ b/tests/test_scorers.py @@ -3,7 +3,7 @@ import jax.numpy as jnp -from decoding.pmf import CategoricalLogPMF, Sample +from decoding.pmf import CategoricalLogPMF, ScoredItem from decoding.scorers import Scorer @@ -18,13 +18,13 @@ def f(s: str) -> int: t1 = time.time() samples = scorer(d) t1 = time.time() - t1 - assert [s.utility for s in samples] == [3, 2, 1] + assert [s.score for s in samples] == [3, 2, 1] scorer = Scorer.from_f_str_to_num(f, parallelize=True) t2 = time.time() samples = scorer(d) t2 = time.time() - t2 - assert [s.utility for s in samples] == [3, 2, 1] + assert [s.score for s in samples] == [3, 2, 1] max_time = 3e-2 assert t2 < max_time <= t1 @@ -38,66 +38,66 @@ def f(ss: Sequence[str]) -> list[int]: scorer = Scorer.from_f_batch_str_to_batch_num(f) samples = scorer(d) - assert [s.utility for s in samples] == [3, 2, 1] + assert [s.score for s in samples] == [3, 2, 1] -def test_scorer_from_f_catlogpmf_to_batch_num() -> None: +def test_scorer_from_f_logpmf_to_batch_num() -> None: def f(d: CategoricalLogPMF[str]) -> list[float]: return [float(jnp.exp(logp) * len(cat)) for logp, cat in d] d = CategoricalLogPMF.from_samples(["a", "bb", "bb", "bb", "ccc"]) - scorer = Scorer.from_f_catlogpmf_to_batch_num(f) + scorer = Scorer.from_f_logpmf_to_batch_num(f) 
samples = scorer(d) assert [s.item for s in samples] == ["bb", "ccc", "a"] -def test_scorer_from_f_str_to_sample() -> None: - def f(s: str) -> Sample[str]: +def test_scorer_from_f_str_to_item() -> None: + def f(s: str) -> ScoredItem[str]: time.sleep(1e-2) - return Sample(item=s + " ", utility=len(s) + 1) + return ScoredItem(item=s + " ", score=len(s) + 1) d = CategoricalLogPMF.from_samples(["a", "bb", "ccc"]) - scorer = Scorer.from_f_str_to_sample(f) + scorer = Scorer.from_f_str_to_item(f) t1 = time.time() samples = scorer(d) t1 = time.time() - t1 - assert [s.utility for s in samples] == [4, 3, 2] + assert [s.score for s in samples] == [4, 3, 2] assert [s.item for s in samples] == ["ccc ", "bb ", "a "] - scorer = Scorer.from_f_str_to_sample(f, parallelize=True) + scorer = Scorer.from_f_str_to_item(f, parallelize=True) t2 = time.time() samples = scorer(d) t2 = time.time() - t2 - assert [s.utility for s in samples] == [4, 3, 2] + assert [s.score for s in samples] == [4, 3, 2] assert [s.item for s in samples] == ["ccc ", "bb ", "a "] max_time = 3e-2 assert t2 < max_time <= t1 -def test_scorer_from_f_batch_str_to_batch_sample() -> None: - def f(ss: Sequence[str]) -> list[Sample[str]]: - return [Sample(item=s + " ", utility=len(s) + 1) for s in ss] +def test_scorer_from_f_batch_str_to_batch_item() -> None: + def f(ss: Sequence[str]) -> list[ScoredItem[str]]: + return [ScoredItem(item=s + " ", score=len(s) + 1) for s in ss] d = CategoricalLogPMF.from_samples(["a", "bb", "ccc"]) - scorer = Scorer.from_f_batch_str_to_batch_sample(f) + scorer = Scorer.from_f_batch_str_to_batch_item(f) samples = scorer(d) - assert [s.utility for s in samples] == [4, 3, 2] + assert [s.score for s in samples] == [4, 3, 2] assert [s.item for s in samples] == ["ccc ", "bb ", "a "] -def test_scorer_from_f_catlogpmf_to_batch_sample() -> None: - def f(d: CategoricalLogPMF[str]) -> list[Sample[str]]: +def test_scorer_from_f_logpmf_to_batch_item() -> None: + def f(d: CategoricalLogPMF[str]) -> 
list[ScoredItem[str]]: return [ - Sample(item=cat + " ", utility=float(jnp.exp(logp) * (len(cat) + 1))) + ScoredItem(item=cat + " ", score=float(jnp.exp(logp) * (len(cat) + 1))) for logp, cat in d ] d = CategoricalLogPMF.from_samples(["a", "bb", "bb", "bb", "ccc"]) - scorer = Scorer.from_f_catlogpmf_to_batch_sample(f) + scorer = Scorer.from_f_logpmf_to_batch_item(f) samples = scorer(d) assert [s.item for s in samples] == ["bb ", "ccc ", "a "]