Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions botorch/utils/probability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,33 @@ def compute_log_prob_feas_from_bounds(
dist_u = (con_upper - means[..., i]) / sigmas[..., i]
log_prob = log_prob + log_prob_normal_in(a=dist_l, b=dist_u).sum(dim=-1)
return log_prob


def percentile_of_score(data: Tensor, score: Tensor, dim: int = -1) -> Tensor:
    """Compute the percentile rank of `score` relative to `data`.

    For example, if this function returns 70 then 70% of the
    values in `data` are below `score`.

    This implementation is based on `scipy.stats.percentileofscore`,
    with `kind='rank'` and `nan_policy='propagate'`, which is the default.

    Args:
        data: A `... x n x output_shape`-dim Tensor of data.
        score: A `... x 1 x output_shape`-dim Tensor of scores.
        dim: The dimension of `data` (of size `n`) along which the ranks are
            computed. `score` is expected to have size 1 along this dimension.

    Returns:
        A `... x output_shape`-dim Tensor of percentile ranks. An entry is NaN
        if the corresponding score is NaN, if any of the `n` data values it is
        compared against is NaN, or if `n` is zero. The result is a floating
        point tensor whose dtype is the promotion of the input dtypes with the
        default dtype, so double-precision inputs yield double-precision ranks.
    """
    n = data.shape[dim]
    # based on scipy.stats.percentileofscore with `kind='rank'`
    left = torch.count_nonzero(data < score, dim=dim)
    right = torch.count_nonzero(data <= score, dim=dim)
    plus1 = left < right
    # The counts are integer tensors; convert them explicitly so the result
    # keeps the precision of the inputs instead of silently falling back to
    # the default dtype (and never drops below the default dtype).
    out_dtype = torch.promote_types(
        torch.result_type(data, score), torch.get_default_dtype()
    )
    # An empty sample has no defined percentile rank; propagate NaN instead of
    # raising a ZeroDivisionError (scipy also returns NaN in this case).
    scale = 50.0 / n if n > 0 else float("nan")
    perct = (left + right + plus1).to(out_dtype) * scale
    # perct shape: `... x output_shape`
    # NaN policy 'propagate': the result is NaN where the score is NaN or
    # where any of the data values it is ranked against is NaN.
    score_is_nan = torch.isnan(score.squeeze(dim))
    data_has_nan = torch.any(torch.isnan(data), dim=dim)
    nan_mask = torch.broadcast_to(score_is_nan | data_has_nan, perct.shape)
    return perct.masked_fill(nan_mask, torch.nan)
70 changes: 70 additions & 0 deletions test/utils/probability/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from __future__ import annotations

import itertools

import numpy as np

import torch
from botorch.utils.probability import ndtr, utils
from botorch.utils.probability.utils import (
Expand All @@ -14,11 +18,13 @@
log_ndtr,
log_phi,
log_prob_normal_in,
percentile_of_score,
phi,
standard_normal_log_hazard,
)
from botorch.utils.testing import BotorchTestCase
from numpy.polynomial.legendre import leggauss as numpy_leggauss
from scipy.stats import percentileofscore as percentile_of_score_scipy


class TestProbabilityUtils(BotorchTestCase):
Expand Down Expand Up @@ -321,3 +327,67 @@ def test_gaussian_probabilities(self) -> None:

with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_ndtr(torch.tensor(1.0, dtype=torch.float16, device=self.device))

def test_percentile_of_score(self) -> None:
    """Check `percentile_of_score` against `scipy.stats.percentileofscore`.

    The scipy reference is evaluated with its default settings
    (`kind='rank'`, `nan_policy='propagate'`) over all combinations of
    dtype, data batch shape, score batch shape and output shape. Shape
    combinations that cannot be broadcast are expected to raise.
    """
    torch.manual_seed(12345)
    n = 10
    candidate_shapes = ((), (1,), (2,), (2, 3))
    configs = itertools.product(
        (torch.float, torch.double),
        candidate_shapes,
        candidate_shapes,
        candidate_shapes,
    )
    for dtype, data_batch_shape, score_batch_shape, output_shape in configs:
        tkwargs = {"dtype": dtype, "device": self.device}
        # the reduction dimension sits right before the output dimensions
        dim = -(len(output_shape) + 1)
        # random inputs, with small values replaced by nan to exercise the
        # nan propagation policy
        data = torch.rand(*data_batch_shape, n, *output_shape, **tkwargs)
        score = torch.rand(*score_batch_shape, 1, *output_shape, **tkwargs)
        data[data < 0.01] = torch.nan
        score[score < 0.01] = torch.nan
        # torch implementation
        try:
            perct_torch = percentile_of_score(data, score, dim=dim).cpu().numpy()
        except RuntimeError:
            # the failure must stem from incompatible batch shapes, which
            # numpy reports as a ValueError
            with self.assertRaises(ValueError):
                np.broadcast_shapes(data_batch_shape, score_batch_shape)
            continue
        # the result has the broadcast batch shape followed by output_shape
        batch_shape = np.broadcast_shapes(data_batch_shape, score_batch_shape)
        self.assertEqual(perct_torch.shape, batch_shape + output_shape)
        # scipy reference: it does not broadcast, so expand the inputs and
        # evaluate one (batch index, output index) pair at a time
        data_np = np.broadcast_to(
            data.cpu().numpy(), batch_shape + (n,) + output_shape
        )
        score_np = np.broadcast_to(
            score.cpu().numpy(), batch_shape + (1,) + output_shape
        )
        perct_scipy = np.zeros_like(perct_torch)
        # materialized because it is re-iterated for every batch index
        output_indices = tuple(itertools.product(*map(range, output_shape)))
        for batch_idx in itertools.product(*map(range, batch_shape)):
            for out_idx in output_indices:
                perct_scipy[batch_idx + out_idx] = percentile_of_score_scipy(
                    data_np[batch_idx + (slice(None),) + out_idx],
                    score_np[batch_idx + (0,) + out_idx],
                )
        self.assertTrue(np.array_equal(perct_torch, perct_scipy, equal_nan=True))