Skip to content

Commit

Permalink
Added batch_size parameter to compute_banzhaf_semivalues, `comput…
Browse files Browse the repository at this point in the history
…e_beta_shapley_semivalues`, `compute_shapley_semivalues` and `compute_generic_semivalues`.
  • Loading branch information
Markus Semmler committed Sep 5, 2023
1 parent e5c117a commit 0914b66
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ randomness.
`pydvl.value.semivalues`. Introduced new type `Seed` and conversion function
`ensure_seed_sequence`.
[PR #396](https://github.com/aai-institute/pyDVL/pull/396)
- Added `batch_size` parameter to `compute_banzhaf_semivalues`,
`compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
`compute_generic_semivalues`.
[PR #428](https://github.com/aai-institute/pyDVL/pull/428)

### Changed

Expand Down Expand Up @@ -240,3 +244,4 @@ It contains:
- Parallelization of computations with Ray
- Documentation
- Notebooks containing examples of different use cases

47 changes: 34 additions & 13 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
import logging
import math
from enum import Enum
from typing import Optional, Protocol, Tuple, Type, TypeVar, cast
from itertools import islice
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast

import numpy as np
import scipy as sp
Expand Down Expand Up @@ -123,23 +124,28 @@ def __call__(self, n: int, k: int) -> float:
MarginalT = Tuple[IndexT, float]


def _marginal(u: Utility, coefficient: SVCoefficient, sample: SampleT) -> MarginalT:
def _marginal(
u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
) -> Tuple[MarginalT, ...]:
"""Computation of marginal utility. This is a helper function for
[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
Args:
u: Utility object with model, data, and scoring function.
coefficient: The semivalue coefficient and sampler weight
sample: A tuple of index and subset of indices to compute a marginal
utility.
samples: A collection of samples. Each sample is a tuple of index and subset of
indices to compute a marginal utility.
Returns:
Tuple with index and its marginal utility.
A collection of marginals. Each marginal is a tuple with index and its marginal
utility.
"""
n = len(u.data)
idx, s = sample
marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s))
return idx, marginal
marginals: List[MarginalT] = []
for idx, s in samples:
marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s))
marginals.append((idx, marginal))
return tuple(marginals)


# @deprecated(
Expand All @@ -153,6 +159,7 @@ def compute_generic_semivalues(
coefficient: SVCoefficient,
done: StoppingCriterion,
*,
batch_size: int = 1,
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
Expand All @@ -164,6 +171,7 @@ def compute_generic_semivalues(
u: Utility object with model, data, and scoring function.
coefficient: The semi-value coefficient
done: Stopping criterion.
batch_size: Number of marginal evaluations per (parallelized) task.
n_jobs: Number of parallel jobs to use.
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
Expand Down Expand Up @@ -210,20 +218,24 @@ def compute_generic_semivalues(

completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED)
for future in completed:
idx, marginal = future.result()
result.update(idx, marginal)
if done(result):
return result
for idx, marginal in future.result():
result.update(idx, marginal)
if done(result):
return result

# Ensure that we always have n_submitted_jobs running
try:
for _ in range(n_submitted_jobs - len(pending)):
samples = tuple(islice(sampler_it, batch_size))
if len(samples) == 0:
raise StopIteration

pending.add(
executor.submit(
_marginal,
u=u,
coefficient=correction,
sample=next(sampler_it),
samples=samples,
)
)
except StopIteration:
Expand Down Expand Up @@ -266,6 +278,7 @@ def compute_shapley_semivalues(
*,
done: StoppingCriterion = MaxUpdates(100),
sampler_t: Type[StochasticSampler] = PermutationSampler,
batch_size: int = 1,
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
Expand All @@ -284,6 +297,7 @@ def compute_shapley_semivalues(
done: Stopping criterion.
sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
for a list.
batch_size: Number of marginal evaluations per (parallelized) task.
n_jobs: Number of parallel jobs to use.
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
Expand All @@ -298,6 +312,7 @@ def compute_shapley_semivalues(
u,
shapley_coefficient,
done,
batch_size=batch_size,
n_jobs=n_jobs,
config=config,
progress=progress,
Expand All @@ -309,6 +324,7 @@ def compute_banzhaf_semivalues(
*,
done: StoppingCriterion = MaxUpdates(100),
sampler_t: Type[StochasticSampler] = PermutationSampler,
batch_size: int = 1,
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
Expand All @@ -325,6 +341,7 @@ def compute_banzhaf_semivalues(
done: Stopping criterion.
sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a
list.
batch_size: Number of marginal evaluations per (parallelized) task.
n_jobs: Number of parallel jobs to use.
seed: Either an instance of a numpy random number generator or a seed for it.
config: Object configuring parallel computation, with cluster address,
Expand All @@ -339,6 +356,7 @@ def compute_banzhaf_semivalues(
u,
banzhaf_coefficient,
done,
batch_size=batch_size,
n_jobs=n_jobs,
config=config,
progress=progress,
Expand All @@ -352,6 +370,7 @@ def compute_beta_shapley_semivalues(
beta: float = 1,
done: StoppingCriterion = MaxUpdates(100),
sampler_t: Type[StochasticSampler] = PermutationSampler,
batch_size: int = 1,
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
Expand All @@ -369,6 +388,7 @@ def compute_beta_shapley_semivalues(
beta: Beta parameter of the Beta distribution.
done: Stopping criterion.
sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list.
batch_size: Number of marginal evaluations per (parallelized) task.
n_jobs: Number of parallel jobs to use.
seed: Either an instance of a numpy random number generator or a seed for it.
config: Object configuring parallel computation, with cluster address, number of
Expand All @@ -383,6 +403,7 @@ def compute_beta_shapley_semivalues(
u,
beta_coefficient(alpha, beta),
done,
batch_size=batch_size,
n_jobs=n_jobs,
config=config,
progress=progress,
Expand Down
27 changes: 27 additions & 0 deletions tests/value/test_semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ def test_shapley(
check_values(values, exact_values, rtol=0.2)


@pytest.mark.parametrize(
"num_samples,sampler,coefficient,batch_size",
[(5, PermutationSampler, beta_coefficient(1, 1), 2)],
)
def test_shapley_batch_size(
num_samples: int,
analytic_shapley,
sampler: Type[PowersetSampler],
coefficient: SVCoefficient,
batch_size: int,
n_jobs: int,
parallel_config: ParallelConfig,
):
u, exact_values = analytic_shapley
criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2))
values = compute_generic_semivalues(
sampler(u.data.indices),
u,
coefficient,
criterion,
n_jobs=n_jobs,
batch_size=batch_size,
config=parallel_config,
)
check_values(values, exact_values, rtol=0.2)


@pytest.mark.parametrize("num_samples", [5])
@pytest.mark.parametrize(
"sampler",
Expand Down

0 comments on commit 0914b66

Please sign in to comment.