
Commit b9c1e52

Alireza Tehrani authored and meta-codesync[bot] committed
Implement KV-ZCH Benchmark (#3540)
Summary:
Pull Request resolved: #3540

Implements KV-ZCH with the benchmarking platform. Several things were added to make it work with KV-ZCH:

- Added eviction policies.
- Added `KeyValueParams` to pass parameters into the TBE `fused_params`, which is then fed into `SSDTableBatchedEmbeddingBags`. See `_populate_ssd_tbe_params` in batched_embedding_kernel and `add_params_from_parameter_sharding` in distributed/utils.py.
- Added `CacheParams` creation to set `prefetch_pipeline=True`, due to the warning below.

NOTE: The `prefetch_pipeline` attribute of `CacheParams` is set to True because of the following complaint without it: {F1983388476,width=300,height=200}

Update on November 11, 2025:

- The line `pipeline.progress(iter(bench_inputs))` is commented out in `benchmark_train_pipeline.py` due to a conflict with `pipeline.reset()`, which causes an error on the forward pass when using `pipeline="prefetch"` with KV-ZCH.

Reviewed By: TroyGarden

Differential Revision: D86677315

fbshipit-source-id: e5d9ca737c59a589fde5d0e33b27fc9874d18b80
1 parent 2e5701c commit b9c1e52

File tree

4 files changed: +167 −5

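The `prefetch_pipeline` note in the commit message corresponds to a single field on the planner constraint for the virtual table. Below is a minimal, hedged sketch of such a constraint, using the table name and option values from the benchmark YAML later in this diff; it is illustrative only, not the literal code added by this commit.

```python
from torchrec.distributed.planner.types import CacheParams, ParameterConstraints

# Hedged sketch: planner constraint for a KV-ZCH (DRAM virtual table) embedding table.
# Without cache_params.prefetch_pipeline=True, the prefetch train pipeline emits the
# complaint referenced in the commit message.
constraints = {
    "FP16_table": ParameterConstraints(
        sharding_types=["row_wise"],             # KV-ZCH virtual tables: row_wise only
        compute_kernels=["dram_virtual_table"],  # or "ssd_virtual_table"
        cache_params=CacheParams(prefetch_pipeline=True),
    )
}
```

In the benchmark itself these values come from the YAML `PlannerConfig.additional_constraints` section rather than being constructed by hand.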

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -196,7 +196,8 @@ def _func_to_benchmark(
             opt=optimizer,
             device=ctx.device,
         )
-        pipeline.progress(iter(bench_inputs))  # warmup
+        # Commented out due to potential conflict with pipeline.reset()
+        # pipeline.progress(iter(bench_inputs))  # warmup

         run_option.name = (
             type(pipeline).__name__ if run_option.name == "" else run_option.name
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# This is a very basic KV-ZCH (ZCH v.Next) benchmark configuration
+# For guidelines, see document `ZCH v.Next Onboarding Guidelines`
+# KV-ZCH parameters have comments next to them below.
+# Runs on 2 ranks, showing traces with reasonable workloads
+RunOptions:
+  world_size: 2
+  num_batches: 10
+  num_benchmarks: 1
+  num_profiles: 1
+  sharding_type: table_wise
+  profile_dir: "."
+  name: "sparsenn_prefetch_kvzch_dram"
+PipelineConfig:
+  pipeline: "prefetch"
+ModelInputConfig:
+  feature_pooling_avg: 30
+EmbeddingTablesConfig:
+  num_unweighted_features: 10
+  num_weighted_features: 10
+  embedding_feature_dim: 256
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000 # Both feature hashsize and virtual table size
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+        total_num_buckets: 100 # num_embedding should be divisible by total_num_buckets
+        location: "DRAM_VIRTUAL_TABLE" # See sparsenn.configs::LocationType,
+                                       # either SSD_VIRTUAL_TABLE, DRAM_VIRTUAL_TABLE
+        # weight_init_max: 10 # Controls initial Embedding table values
+        # weight_init_min: -10 # Controls initial Embedding table values
+        # virtual_table_eviction_policy: # If want eviction policy
+        #   CountBasedEvictionPolicy:
+        #     training_id_eviction_trigger_count: 10000
+        #     eviction_threshold: 15
+        #     decay_rate: 0.99
+
+      - name: large_table
+        embedding_dim: 2048
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: skipped_table
+        embedding_dim: 128
+        num_embeddings: 100_000
+        feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      sharding_types: [column_wise]
+    FP16_table:
+      sharding_types: [row_wise] # KV-ZCH virtual tables currently only support row_wise sharding
+      compute_kernels: [dram_virtual_table] # Either ['ssd_virtual_table', 'dram_virtual_table'], must match above
+      cache_params:
+        prefetch_pipeline: True # Required for SSD/DRAM virtual tables
+      key_value_params:
+        max_l1_cache_size: 1250 # in MB, check warnings in log to see if it is actually used.
+        l2_cache_size: 64 # in GB
+        gather_ssd_cache_stats: False
+        ssd_rocksdb_shards: 32
+        # Only use if `virtual_table_eviction_policy` is set above.
+        # kvzch_tbe_config: # See fbgemm_gpu/split_table_batched_embeddings_ops_common.py::KVZCHEvictionTBEConfig
+        #   kvzch_eviction_trigger_mode: 2 # 0:disabled, 1:iteration, 2:mem_util, 3:manual, 4:id_count, 5:free_mem
+        #   eviction_free_mem_threshold_gb: 200 # Minimum free memory in GB before eviction
+        #   eviction_free_mem_check_interval_batch: 1000 # Batches between free memory checks
+        #   threshold_calculation_bucket_stride: 0.2 # Feature score bucket width
+        #   threshold_calculation_bucket_num: 1000000 # Total number of feature score buckets
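If the commented `virtual_table_eviction_policy` block above is uncommented, the helper added in `table_config.py` (shown later in this diff) instantiates the named policy class with the nested YAML keys as keyword arguments. A hedged sketch of the resulting object, assuming those keys map one-to-one onto the `CountBasedEvictionPolicy` constructor as the helper's `**policy_params` call implies:

```python
from torchrec.modules.embedding_configs import CountBasedEvictionPolicy

# Values copied from the commented YAML block above; the kwarg names are
# assumed to match the CountBasedEvictionPolicy fields in this codebase.
policy = CountBasedEvictionPolicy(
    training_id_eviction_trigger_count=10_000,
    eviction_threshold=15,
    decay_rate=0.99,
)
# table_config.py attaches this object to the table dict as
# virtual_table_eviction_policy before constructing the EmbeddingBagConfig.
```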

torchrec/distributed/test_utils/sharding_config.py

Lines changed: 23 additions & 2 deletions
@@ -9,14 +9,16 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple, Union

+from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHTBEConfig
+
 from torchrec.distributed.comm import get_local_size

 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.planner.constants import POOLING_FACTOR
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
-from torchrec.distributed.types import ShardingType
+from torchrec.distributed.planner.types import CacheParams, ParameterConstraints
+from torchrec.distributed.types import KeyValueParams, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig


@@ -64,6 +66,25 @@ def table_to_constraint(
     else:
         kwargs = default_kwargs | kwargs

+    # (KVZCH) Convert key_value_params dict to KeyValueParams object if present
+    if "key_value_params" in kwargs:
+        key_value_params = kwargs["key_value_params"]
+        # If eviction policy is set then construct object
+        if (
+            isinstance(key_value_params, dict)
+            and "kvzch_tbe_config" in key_value_params
+        ):
+            key_value_params["kvzch_tbe_config"] = KVZCHTBEConfig(
+                **key_value_params["kvzch_tbe_config"]
+            )
+        # pyre-ignore[6,32]
+        kwargs["key_value_params"] = KeyValueParams(**key_value_params)
+
+    # Convert cache_params dict to CacheParams object if present
+    if "cache_params" in kwargs:
+        # pyre-ignore[6,32]
+        kwargs["cache_params"] = CacheParams(**kwargs["cache_params"])
+
     constraint = ParameterConstraints(**kwargs)  # pyre-ignore [6]
     return table.name, constraint
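For context, here is a hedged, standalone sketch of what this conversion does to the `key_value_params` block of the `FP16_table` constraint in the benchmark YAML above (values copied from that YAML; the real code path goes through `table_to_constraint`):

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHTBEConfig
from torchrec.distributed.types import KeyValueParams

# key_value_params as parsed from the FP16_table block of the YAML config.
key_value_params = {
    "max_l1_cache_size": 1250,  # MB
    "l2_cache_size": 64,  # GB
    "gather_ssd_cache_stats": False,
    "ssd_rocksdb_shards": 32,
    # Uncommenting the YAML's kvzch_tbe_config block would add a nested dict here.
}

# Same conversion as above: an optional nested kvzch_tbe_config dict becomes a
# KVZCHTBEConfig, then the whole dict becomes a KeyValueParams object, which the
# sharder later folds into the TBE fused_params.
if "kvzch_tbe_config" in key_value_params:
    key_value_params["kvzch_tbe_config"] = KVZCHTBEConfig(
        **key_value_params["kvzch_tbe_config"]
    )
params = KeyValueParams(**key_value_params)
```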

torchrec/distributed/test_utils/table_config.py

Lines changed: 75 additions & 2 deletions
@@ -8,12 +8,77 @@
 # pyre-strict

 from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Type
+
+from torchrec.modules.embedding_configs import (
+    CountBasedEvictionPolicy,
+    CountTimestampMixedEvictionPolicy,
+    EmbeddingBagConfig,
+    FeatureScoreBasedEvictionPolicy,
+    NoEvictionPolicy,
+    TimestampBasedEvictionPolicy,
+    VirtualTableEvictionPolicy,
+)

-from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.types import DataType


+def _return_correct_eviction_policy(
+    eviction_str: str,
+) -> Type[VirtualTableEvictionPolicy]:
+    if eviction_str == "CountBasedEvictionPolicy":
+        return CountBasedEvictionPolicy
+    if eviction_str == "TimestampBasedEvictionPolicy":
+        return TimestampBasedEvictionPolicy
+    if eviction_str == "CountTimestampMixedEvictionPolicy":
+        return CountTimestampMixedEvictionPolicy
+    if eviction_str == "FeatureScoreBasedEvictionPolicy":
+        return FeatureScoreBasedEvictionPolicy
+    raise ValueError(f"Could not recognize eviction_str in yaml file: {eviction_str}")
+
+
+def _process_virtual_table_config(config_dict: Dict[str, Any]) -> None:
+    """Converts YAML virtual table fields (location, eviction-policy) to EBC format."""
+    if "location" in config_dict:
+        # config_dict["location"] should match LocationType
+        config_dict["use_virtual_table"] = config_dict["location"] in [
+            "DRAM_VIRTUAL_TABLE",
+            "SSD_VIRTUAL_TABLE",
+        ]
+        del config_dict["location"]  # location not an attribute of EBC
+
+        if config_dict["use_virtual_table"]:
+            assert (
+                config_dict["total_num_buckets"] > 0
+            ), "Should be larger 0 when using SSD_VIRTUAL_TABLE or DRAM_VIRTUAL_TABLE"
+
+            assert (
+                config_dict["num_embeddings"] % config_dict["total_num_buckets"] == 0
+            ), (
+                f"num_embeddings ({config_dict['num_embeddings']}) must be divisible by "
+                f"total_num_buckets ({config_dict['total_num_buckets']})"
+            )
+
+    if "virtual_table_eviction_policy" in config_dict:
+        # Obtain what eviction strategy was chosen
+        eviction = config_dict["virtual_table_eviction_policy"]
+        policy_class_name = next(iter(eviction.keys()))
+        policy_params = eviction[policy_class_name]
+        eviction = _return_correct_eviction_policy(policy_class_name)(
+            **policy_params
+        )
+    else:
+        # Choose standard no eviction policy
+        eviction = NoEvictionPolicy()
+
+    # Initialize the eviction policy
+    data_type = config_dict["data_type"]
+    embedding_dim = config_dict["embedding_dim"]
+    eviction.init_metaheader_config(data_type, embedding_dim)
+
+    config_dict["virtual_table_eviction_policy"] = eviction
+
+
 @dataclass
 class EmbeddingTablesConfig:
     """
@@ -38,13 +103,18 @@ class EmbeddingTablesConfig:
     embedding_feature_dim: int = 128
     base_row_size: int = 100_000
     table_data_type: DataType = DataType.FP32
+    total_num_buckets: Optional[int] = None
     additional_tables: List[List[Dict[str, Any]]] = field(default_factory=list)

     def convert_to_ebconf(self, kwargs: Dict[str, Any]) -> EmbeddingBagConfig:
         if "data_type" in kwargs:
             kwargs["data_type"] = DataType[kwargs["data_type"]]
         else:
             kwargs["data_type"] = self.table_data_type
+
+        # Process configs for KV-ZCH/ZCH v.Next
+        _process_virtual_table_config(kwargs)
+
         return EmbeddingBagConfig(**kwargs)

     def generate_tables(
@@ -70,6 +140,7 @@ def generate_tables(
         two lists - the first for unweighted embedding tables and the second for
         weighted embedding tables.
         """
+
         unweighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=max(i + 1, 100) * self.base_row_size // 100,
@@ -90,6 +161,7 @@ def generate_tables(
             )
             for i in range(self.num_weighted_features)
         ]
+
         tables_list = []
         for idx, adts in enumerate(self.additional_tables):
             if idx == 0:
@@ -100,6 +172,7 @@ def generate_tables(
             tables = []
             for adt in adts:
                 tables.append(self.convert_to_ebconf(adt))
+            tables_list.append(tables)

         if len(tables_list) == 0:
             tables_list.append(unweighted_tables)
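To tie the pieces together, here is a hedged usage sketch of the new `_process_virtual_table_config` helper on a table dict shaped like the `FP16_table` entry from the YAML above, after `convert_to_ebconf` has already mapped `data_type` to the enum; the values are copied from that YAML and the observed behavior follows the helper's code shown in this diff.

```python
from torchrec.distributed.test_utils.table_config import _process_virtual_table_config
from torchrec.types import DataType

table = {
    "name": "FP16_table",
    "embedding_dim": 512,
    "num_embeddings": 100_000,
    "feature_names": ["additional_0_0"],
    "data_type": DataType.FP16,
    "total_num_buckets": 100,  # 100_000 / 100 = 1_000 rows per bucket
    "location": "DRAM_VIRTUAL_TABLE",
}

_process_virtual_table_config(table)

# "location" is replaced by use_virtual_table=True, and because no
# virtual_table_eviction_policy was given, a default NoEvictionPolicy is attached.
assert table["use_virtual_table"]
print(type(table["virtual_table_eviction_policy"]).__name__)  # NoEvictionPolicy
```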
