@@ -42,6 +42,9 @@ def __init__(
4242 zch_size : int ,
4343 frequency : int ,
4444 start_bucket : int ,
45+ num_buckets_per_rank : int ,
46+ num_reserved_slots_per_bucket : int ,
47+ device : torch .device ,
4548 disable_fallback : bool ,
4649 log_file_path : str = "" ,
4750 ) -> None :
@@ -57,8 +60,32 @@ def __init__(
5760 self ._zch_size : int = zch_size
5861 self ._frequency : int = frequency
5962 self ._start_bucket : int = start_bucket
63+ self ._num_buckets_per_rank : int = num_buckets_per_rank
64+ self ._num_reserved_slots_per_bucket : int = num_reserved_slots_per_bucket
65+ self ._device : torch .device = device
6066 self ._disable_fallback : bool = disable_fallback
6167
68+ assert (
69+ self ._zch_size % self ._num_buckets_per_rank == 0
70+ ), f"{ self ._zch_size } must be divisible by { self ._num_buckets_per_rank } "
71+ indice_per_bucket = torch .tensor (
72+ [
73+ (self ._zch_size // self ._num_buckets_per_rank ) * bucket
74+ for bucket in range (1 , self ._num_buckets_per_rank + 1 )
75+ ],
76+ dtype = torch .int64 ,
77+ device = self ._device ,
78+ )
79+
80+ self ._opt_in_ranges : torch .Tensor = torch .sub (
81+ indice_per_bucket ,
82+ (
83+ self ._num_reserved_slots_per_bucket
84+ if self ._num_reserved_slots_per_bucket > 0
85+ else 0
86+ ),
87+ )
88+
6289 self ._dtype_checked : bool = False
6390 self ._total_cnt : int = 0
6491 self ._hit_cnt : int = 0
@@ -77,6 +104,13 @@ def __init__(
77104 ) # initialize file handler
78105 self .logger .addHandler (file_handler ) # add file handler to logger
79106
107+ self .logger .info (
108+ f"ScalarLogger: { self ._name = } , { self ._device = } , "
109+ f"{ self ._zch_size = } , { self ._frequency = } , { self ._start_bucket = } , "
110+ f"{ self ._num_buckets_per_rank = } , { self ._num_reserved_slots_per_bucket = } , "
111+ f"{ self ._opt_in_ranges = } , { self ._disable_fallback = } "
112+ )
113+
80114 def should_report (self ) -> bool :
81115 # We only need to report metrics from rank0 (start_bucket = 0)
82116
@@ -95,9 +129,9 @@ def update(
95129 identities_1 : torch .Tensor ,
96130 values : torch .Tensor ,
97131 remapped_ids : torch .Tensor ,
132+ hit_indices : torch .Tensor ,
98133 evicted_emb_indices : Optional [torch .Tensor ],
99134 metadata : Optional [torch .Tensor ],
100- num_reserved_slots : int ,
101135 eviction_config : Optional [HashZchEvictionConfig ] = None ,
102136 ) -> None :
103137 if not self ._dtype_checked :
@@ -125,9 +159,17 @@ def update(
125159 self ._hit_cnt += hit_cnt
126160 self ._collision_cnt += values .numel () - hit_cnt - insert_cnt
127161
128- opt_in_range = self ._zch_size - num_reserved_slots
129- opt_in_ids = torch .lt (remapped_ids , opt_in_range )
130- self ._opt_in_cnt += int (torch .sum (opt_in_ids ).item ())
162+ if self ._disable_fallback :
163+ hit_values = values [hit_indices ]
164+ train_buckets = hit_values % self ._num_buckets_per_rank
165+ else :
166+ train_buckets = values % self ._num_buckets_per_rank
167+
168+ opt_in_ranges = self ._opt_in_ranges .index_select (dim = 0 , index = train_buckets )
169+ opt_in_ids = torch .lt (remapped_ids , opt_in_ranges )
170+ opt_in_ids_cnt = int (torch .sum (opt_in_ids ).item ())
171+ # opt_in_cnt: # of ids assigned to opt-in block
172+ self ._opt_in_cnt += opt_in_ids_cnt
131173
132174 if evicted_emb_indices is not None and evicted_emb_indices .numel () > 0 :
133175 deduped_evicted_indices = torch .unique (evicted_emb_indices )
0 commit comments