1 change: 0 additions & 1 deletion .github/workflows/benchmark.yml
@@ -40,7 +40,6 @@ jobs:
         run: python3 -m pip install -r benchmark_v2/requirements.txt kernels
 
       - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
-        working-directory: /transformers
         run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
 
       - name: Run benchmark
8 changes: 8 additions & 0 deletions src/transformers/generation/continuous_batching/cache.py
@@ -335,6 +335,14 @@ def update(
         # Return the new KV values
         return key_states_with_cache, value_states_with_cache
 
+    @traced
+    def close(self):
+        self.key_cache.clear()
+        self.value_cache.clear()
+
+        torch._dynamo.reset()
+        torch._dynamo.reset_code_caches()
+
 
 # TODO: rework computation with the groups and their sizes
 class PagedAttentionMemoryHandler:
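The cache-side close() matters for two reasons: clearing the per-layer key/value lists drops the last Python references to the (potentially multi-GiB) paged KV tensors, and resetting dynamo discards compiled-graph state that would otherwise keep stale guards and tensor references alive. A minimal, self-contained sketch of the same pattern (PagedKVCacheSketch is illustrative, not the transformers class):

import torch

class PagedKVCacheSketch:
    """Illustrative stand-in for the paged KV cache; not the library class."""

    def __init__(self, num_layers: int = 2, num_blocks: int = 16, block_size: int = 32, head_dim: int = 64):
        # One K and one V tensor per layer; these lists are the sole owners.
        self.key_cache = [torch.empty(num_blocks, block_size, head_dim) for _ in range(num_layers)]
        self.value_cache = [torch.empty(num_blocks, block_size, head_dim) for _ in range(num_layers)]

    def close(self):
        # Dropping the list entries releases the cache tensors to the allocator.
        self.key_cache.clear()
        self.value_cache.clear()
        # Compiled graphs captured against these buffers can keep them
        # reachable, so reset dynamo state and its per-code caches as well.
        torch._dynamo.reset()
        torch._dynamo.reset_code_caches()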
33 changes: 33 additions & 0 deletions src/transformers/generation/continuous_batching/continuous_api.py
@@ -708,6 +708,35 @@ def _sample(self, probs: torch.Tensor, do_sample: bool) -> None:
         tokens = next_tokens.size(1)  # Get seq_len dimension
         self.output_ids[:, :tokens].copy_(next_tokens)
 
+    def close(self):
+        self.cache.close()
+        self.requests_in_batch.clear()
+
+        if self._graphs is not None:
+            self._graphs.clear()
+
+        del self.input_ids
+        del self.position_ids
+        del self.cumulative_seqlens_q
+        del self.logits_indices
+        del self.output_ids
+
+        self.cumulative_seqlens_k.clear()
+
+        if self.attention_mask is not None:
+            self.attention_mask.clear()
+            self.attention_mask = None
+
+        self.write_index_storage.clear()
+        self.read_index_storage.clear()
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        import gc
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
 
 # Manager Class (User Interface)
 @attach_tracer()
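Note the ordering at the tail of close(): synchronize first so no in-flight kernel still touches the buffers, collect garbage so reference cycles stop pinning the storages, and only then empty PyTorch's caching allocator, which is what actually returns memory to the driver. A hedged sketch of that sequence in isolation (release_gpu_buffers is a hypothetical helper, not part of this PR):

import gc

import torch

def release_gpu_buffers(buffers: list) -> None:
    # Hypothetical helper mirroring the release order used by close().
    buffers.clear()  # drop the last Python references to the tensors
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for in-flight kernels that may read them
    gc.collect()  # break reference cycles that still pin the storages
    torch.cuda.empty_cache()  # hand cached allocator blocks back to the driver

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    buffers = [torch.empty(1024, 1024, device=device) for _ in range(4)]
    release_gpu_buffers(buffers)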
@@ -826,6 +855,10 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)
 
+        if self.batch_processor is not None:
+            self.batch_processor.close()
+            self.batch_processor = None  # NOTE: this is enough to clear memory after stop, still calling `close()` because it calls torch cache intrinsics
+
 
     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
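Taken together, stop() now leaves the process close to its bare-weights memory footprint, so a manager can be created and torn down repeatedly in one process. A usage sketch, assuming the manager comes from init_continuous_batching() and exposes start(); only stop()'s signature is taken from this diff:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, device_map="cuda")
manager = model.init_continuous_batching()  # assumed entry point
manager.start()
# ... submit requests and consume generated tokens ...
manager.stop(block=True, timeout=30.0)  # joins the worker thread, then closes the batch processor

# With the batch processor closed and dereferenced, allocated memory
# should be back to roughly the model weights alone.
if torch.cuda.is_available():
    print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")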