Skip to content

Commit

Permalink
Add percentage based threshold function and unit tests (pytorch#1679)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1679

Add percentage based threshold function to mc_modules.

Also:
1. Add unit tests for all threshold functions.
2. Add one liner documentation for all threshold functions.
3. Add loggings for threshold.

Note that regardless of eviction policy, threshold functions only look at eviction counts.

Reviewed By: dstaay-fb

Differential Revision: D53030312

fbshipit-source-id: 52bde031b45a38429fccb7e280705995d41a1ab1
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Feb 5, 2024
1 parent f8f6f61 commit 02771aa
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def dynamic_threshold_filter(
id_counts: torch.Tensor,
threshold_skew_multiplier: float = 10.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Threshold is total_count / num_ids * threshold_skew_multiplier. An id is
added if its count is strictly greater than the threshold.
"""

num_ids = id_counts.numel()
total_count = id_counts.sum()
Expand All @@ -69,6 +73,10 @@ def dynamic_threshold_filter(
def average_threshold_filter(
id_counts: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Threshold is average of id_counts. An id is added if its count is strictly
greater than the mean.
"""
if id_counts.dtype != torch.float:
id_counts = id_counts.float()
threshold = id_counts.mean()
Expand All @@ -77,6 +85,26 @@ def average_threshold_filter(
return threshold_mask, threshold


@torch.no_grad()
def probabilistic_threshold_filter(
id_counts: torch.Tensor,
per_id_probability: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Each id has probability per_id_probability of being added. For example,
if per_id_probability is 0.01 and an id appears 100 times, then it has a 60%
of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count,
and for a randomly generated threshold, the id score is the chance of it being added.
"""
probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float)
id_scores = 1 - torch.pow(probability, id_counts)

threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device)
threshold_mask = id_scores > threshold

return threshold_mask, threshold


class ManagedCollisionModule(nn.Module):
"""
Abstract base class for ManagedCollisionModule.
Expand Down
128 changes: 128 additions & 0 deletions torchrec/modules/tests/test_mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

import torch
from torchrec.modules.mc_modules import (
average_threshold_filter,
DistanceLFU_EvictionPolicy,
dynamic_threshold_filter,
LFU_EvictionPolicy,
LRU_EvictionPolicy,
MCHManagedCollisionModule,
probabilistic_threshold_filter,
)
from torchrec.sparse.jagged_tensor import JaggedTensor

Expand Down Expand Up @@ -215,3 +218,128 @@ def test_distance_lfu_eviction_fast_decay(self) -> None:
self.assertEqual(list(_mch_counts), [1, 1, 1, 1, torch.iinfo(torch.int64).max])
_mch_last_access_iter = mc_module._mch_last_access_iter
self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3])

def test_dynamic_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=lambda tensor: dynamic_threshold_filter(
tensor, threshold_skew_multiplier=0.75
)
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

ids = [5, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1]
# threshold is len(ids) / unique_count(ids) * threshold_skew_multiplier
# = 15 / 5 * 0.5 = 2.25
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}
mc_module.profile(features)

_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(
list(_mch_sorted_raw_ids),
[3, 4, 5, torch.iinfo(torch.int64).max, torch.iinfo(torch.int64).max],
)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [3, 4, 5, 0, torch.iinfo(torch.int64).max])

def test_average_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=average_threshold_filter
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

# insert some values to zch
# we have 10 counts of 4 and 1 count of 5
mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5])
mc_module._mch_counts[0:2] = torch.tensor([10, 1])

ids = [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10]
# threshold is 1.375
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}
mc_module.profile(features)

# empty, empty will be evicted
# 6, 8 will be added
# 7 is not added because it's below the average threshold
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(
list(_mch_sorted_raw_ids), [4, 5, 6, 8, torch.iinfo(torch.int64).max]
)
# count for 4 is not updated since it's below the average threshold
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [10, 1, 3, 2, torch.iinfo(torch.int64).max])

def test_probabilistic_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=lambda tensor: probabilistic_threshold_filter(
tensor,
per_id_probability=0.01,
)
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

unique_ids = [5, 4, 3, 2, 1]
id_counts = [100, 80, 60, 40, 10]
ids = [id for id, count in zip(unique_ids, id_counts) for _ in range(count)]
# chance of being added is [0.63, 0.55, 0.45, 0.33]
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}

torch.manual_seed(42)
for _ in range(10):
mc_module.profile(features)

_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
print(f"henry {mc_module._mch_counts}")
self.assertEqual(
sorted(_mch_sorted_raw_ids.tolist()),
[2, 3, 4, 5, torch.iinfo(torch.int64).max],
)
# _mch_counts is like
# [80, 180, 160, 800, 9223372036854775807]

0 comments on commit 02771aa

Please sign in to comment.