File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -161,7 +161,13 @@ def determine_available_memory(self) -> int:
161
161
# intermediate activations.
162
162
m = xm .get_memory_info (self .device )
163
163
total_memory_size = m ["bytes_limit" ]
164
- profiled = m ["peak_bytes_used" ] # Weights + intermediate activations.
164
+ current_mem = m ["bytes_used" ]
165
+ # Ideally we would use profiled = m["peak_bytes_used"] to
166
+ # get weights + activations. But there is memory used during
167
+ # compilation / weight loading that impacts the peak and
168
+ # there is no way to reset peak memory in XLA, so we
169
+ # use a heuristic instead: current memory in use plus 2%.
170
+ profiled = current_mem * 1.02
165
171
166
172
# Calculate the TPU KV cache size based on profiling.
167
173
usable_memory_size = int (total_memory_size *
You can’t perform that action at this time.
0 commit comments