15 changes: 14 additions & 1 deletion main.py
@@ -7,6 +7,7 @@

from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
from globals import Decoder
import time



@@ -95,32 +96,44 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=20, ga
    top_p = 0.9

    torch.manual_seed(123)
    start = time.time()
    output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    end = time.time()
    color_print(f"large (target) model autoregressive_sampling: {generated_text}")
    color_print(f"Elapsed time for Large Autoregressive_sampling: {end-start}")

    if use_benchmark:
        benchmark(autoregressive_sampling, "AS_large", use_profiling,
                  input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)

    torch.manual_seed(123)
    start = time.time()
    output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    end = time.time()
    color_print(f"small (approx) model autoregressive_sampling: {generated_text}")
    color_print(f"Elapsed time for Small Autoregressive_sampling: {end-start}")

    if use_benchmark:
        benchmark(autoregressive_sampling, "AS_small", use_profiling,
                  input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)

    torch.manual_seed(123)
    start = time.time()
    output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
color_print(f"deepmind's speculative_sampling: {generated_text}")
    end = time.time()
    color_print(f"deepmind's speculative_sampling: {generated_text}")
    color_print(f"Elapsed time for deepmind's speculative_sampling: {end-start}")

    torch.manual_seed(123)
    start = time.time()
    output = speculative_sampling(input_ids, small_model, large_model, num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    end = time.time()
    color_print(f"google's speculative_sampling: {generated_text}")
    color_print(f"Elapsed time for google's speculative_sampling: {end-start}")

    if use_benchmark:
        benchmark(speculative_sampling, "SP", use_profiling,
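The timing added to main.py repeats the same seed, sample, decode, and print-elapsed sequence for every sampling variant. Below is a minimal sketch of that pattern factored into a helper; `timed_generate` is a hypothetical name, the tokenizer is passed in explicitly, and the commented usage assumes the models and sampling functions already set up in main.py.

```python
import time

import torch


def timed_generate(label, sampling_fn, input_ids, tokenizer, *args, **kwargs):
    """Run one sampling call, decode its output, and report wall-clock time."""
    torch.manual_seed(123)  # main.py reseeds before every variant so outputs stay comparable
    start = time.time()
    output = sampling_fn(input_ids, *args, **kwargs)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    end = time.time()
    print(f"{label}: {generated_text}")
    print(f"Elapsed time for {label}: {end - start:.3f}s")
    return output


# Hypothetical usage with the objects defined in main.py:
# timed_generate("large (target) model autoregressive_sampling",
#                autoregressive_sampling, input_ids, tokenizer,
#                large_model, num_tokens, top_k=top_k, top_p=top_p)
```

If the models run on GPU, decoding pulls the generated tensor back to the CPU, so pending GPU work finishes before `end` is taken and the measured interval covers the full generation.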
59 changes: 36 additions & 23 deletions sampling/kvcache_model.py
@@ -1,5 +1,6 @@
import torch
from typing import Optional
from transformers.cache_utils import DynamicCache

from sampling.utils import norm_logits, sample
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
@@ -33,6 +34,18 @@ def _forward_with_kvcache(self, input_ids : torch.Tensor, use_debug = True) -> t
            self._past_key_values = outputs.past_key_values
            last_q = self._prob_history[:, -1, :]
        else:
            if isinstance(self._past_key_values, DynamicCache):
                # transformers Cache object: it tracks the cached length itself
                cached_len = self._past_key_values.get_seq_length()
            else:
                # legacy tuple cache: read the cached length off the first layer's key tensor
                cached_len = 0
                for kv in self._past_key_values:
                    k, v = kv
                    if k.dim() == 3:
                        # Bloom-style keys: (batch * heads, head_dim, seq_len)
                        cached_len = k.shape[-1]
                    else:
                        # standard keys: (batch, heads, seq_len, head_dim)
                        cached_len = k.shape[-2]
                    break
            # return the last token's logits
            cached_len = 0
            for kv in self._past_key_values:
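For reference, the cache-length branch above collapses to a small standalone helper. This is only a sketch: `cached_seq_len` is an illustrative name, and the tensor layouts are the ones this file already documents for `rollback` (Bloom keys as `(batch * heads, head_dim, seq_len)`, standard keys as `(batch, heads, seq_len, head_dim)`).

```python
import torch
from transformers.cache_utils import DynamicCache


def cached_seq_len(past_key_values) -> int:
    """Number of tokens already stored in a KV cache, for DynamicCache or legacy tuples."""
    if isinstance(past_key_values, DynamicCache):
        return past_key_values.get_seq_length()
    for k, _v in past_key_values:
        if k.dim() == 3:
            # Bloom-style keys: (batch * heads, head_dim, seq_len)
            return k.shape[-1]
        # standard keys: (batch, heads, seq_len, head_dim)
        return k.shape[-2]
    return 0  # empty legacy cache


# Toy check with a one-layer legacy cache holding 7 positions (batch=1, heads=2, head_dim=4):
legacy = [(torch.zeros(1, 2, 7, 4), torch.zeros(1, 2, 7, 4))]
assert cached_seq_len(legacy) == 7
```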
@@ -90,28 +103,28 @@ def generate(self, input : torch.Tensor, gamma : int) -> torch.Tensor:
        return output

    @torch.no_grad()
    def rollback(self, end_pos : int):
        past_key_values_trimmed = []
        assert self._past_key_values
        for kv in self._past_key_values:
            k, v = kv
            # NOTE() the indexing is specific for bloom. This won't work for other models
            # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)

            # Bloom is special one
            if isinstance(self._model, BloomForCausalLM):
                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
                k = k[:, :, :end_pos]
                v = v[:, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)
            else:
                # k, v (batch, head, seq, hidden_dim)
                k = k[:, :, :end_pos, :]
                v = v[:, :, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)
    def rollback(self, end_pos: int):
        if isinstance(self._past_key_values, DynamicCache):
            # Truncate DynamicCache
            new_cache = DynamicCache()
            for layer_idx in range(len(self._past_key_values.key_cache)):
                k = self._past_key_values.key_cache[layer_idx][..., :end_pos, :]
                v = self._past_key_values.value_cache[layer_idx][..., :end_pos, :]
                new_cache.key_cache.append(k)
                new_cache.value_cache.append(v)
            self._past_key_values = new_cache
        else:
            # Original tuple-based handling
            past_key_values_trimmed = []
            for kv in self._past_key_values:
                k, v = kv
                if isinstance(self._model, BloomForCausalLM):
                    k = k[:, :, :end_pos]
                    v = v[:, :end_pos, :]
                else:
                    k = k[..., :end_pos, :]
                    v = v[..., :end_pos, :]
                past_key_values_trimmed.append((k, v))
            self._past_key_values = past_key_values_trimmed

        self._prob_history = self._prob_history[:, :end_pos, :]
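A quick sanity check for the DynamicCache branch of `rollback` could look like the sketch below. It relies only on what the new code itself relies on, namely direct access to the per-layer `key_cache` / `value_cache` lists and `get_seq_length()`, so it assumes a transformers version where those are plain lists; the layer count, shapes, and `end_pos` are made up for illustration.

```python
import torch
from transformers.cache_utils import DynamicCache

# Toy cache: 2 layers, 10 cached positions, shapes (batch=1, heads=2, seq_len=10, head_dim=4).
cache = DynamicCache()
for _ in range(2):
    cache.key_cache.append(torch.randn(1, 2, 10, 4))
    cache.value_cache.append(torch.randn(1, 2, 10, 4))

end_pos = 6  # keep only the first 6 positions, as rollback(6) would

# Same truncation the DynamicCache branch performs:
trimmed = DynamicCache()
for layer_idx in range(len(cache.key_cache)):
    trimmed.key_cache.append(cache.key_cache[layer_idx][..., :end_pos, :])
    trimmed.value_cache.append(cache.value_cache[layer_idx][..., :end_pos, :])

assert trimmed.key_cache[0].shape[-2] == end_pos
assert trimmed.get_seq_length() == end_pos  # get_seq_length() reads the key tensors' length
```

Recent transformers releases also expose `DynamicCache.crop(max_length)` for exactly this kind of truncation; if the pinned version has it, it could replace the manual copy in `rollback`.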