
No speed-up of model.generate() with StaticCache + torch.compile in 4.39.3 #30055

@learning-chip

Description


System Info

torch==2.2.2
transformers==4.39.3

Platform: RTX 4090 or RTX A6000, rented on vast.ai

Who can help?

@ArthurZucker @gan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Below is the complete benchmark script. The compile settings follow #29791; the generation settings and benchmark metric follow gpt-fast/generate.py.

import time

import torch
import torch._dynamo.config
import torch._inductor.config

from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

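# Inductor flags mirrored from gpt-fast's setup; fx_graph_cache reuses compiled graphs across runs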
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# set same as gpt-fast default config
compile = True
device = "cuda"
precision = torch.bfloat16
num_samples = 5
max_new_tokens = 200
start = -1 if compile else 0  # extra warm-up iteration (i == -1) to absorb compilation time
torch.manual_seed(1234)

def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass
    else:
        print(f"device={device} is not yet suppported")

model_name = "openlm-research/open_llama_7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map={"": device}, torch_dtype=precision, attn_implementation='sdpa'
)
device_sync(device=device)
t = time.time() - t0
print(f"Time to load model: {t:.02f} seconds")

prompt = "Hello, my name is"
inputs = tokenizer(prompt, return_tensors='pt').to(device)
prompt_length = inputs["input_ids"].size(1)
max_seq_length = max_new_tokens + prompt_length

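# Pre-allocate the static KV cache, sized for prompt + max_new_tokens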
model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=max_seq_length)

if compile:
    model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

aggregate_metrics = {
    'tokens_per_sec': []
}

for i in range(start, num_samples):
    device_sync(device)
    t0 = time.perf_counter()
    output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, top_k=200, temperature=0.8)
    device_sync(device)
    t = time.perf_counter() - t0

    print(tokenizer.decode(output[0]))

    tokens_generated = output.size(1) - prompt_length
    tokens_sec = tokens_generated / t
    if i == -1:
        # warm-up pass: report compilation time and exclude it from the metrics
        print(f"Compilation time: {t:.2f} seconds")
        continue
    
    aggregate_metrics['tokens_per_sec'].append(tokens_sec)
    print(f"tokens generated: {tokens_generated}")
    print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")

print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

With either compile=True or compile=False, the throughput is always ~33 tokens/s, i.e. torch.compile brings no measurable speed-up.
Complete logs on A6000:

For comparison, here is the same benchmark setting using https://github.com/pytorch-labs/gpt-fast:

export MODEL_REPO=openlm-research/open_llama_7b
python ./generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" # eager
python ./generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" # compile

Here the eager version gives ~30 tokens/s while the compiled version gives ~50 tokens/s.
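
One way to narrow this down is to check whether dynamo hits graph breaks or recompiles on every decoding step inside model.generate(), which would erase any benefit from compilation. A minimal diagnostic sketch, assuming the torch._dynamo counters/utilities available in torch 2.2 (exact output format may differ):

import torch._dynamo as dynamo

dynamo.reset()  # clear any previous compilation state
_ = model.generate(**inputs, max_new_tokens=8, do_sample=False)
print(dynamo.utils.counters["graph_break"])  # graph-break reasons, if any
print(dynamo.utils.compile_times())          # per-function compile-time report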

Expected behavior

StaticCache + torch.compile should achieve a ~2x speed-up over eager mode, as gpt-fast does.
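
For reference, a variant that may behave more like gpt-fast (a sketch only, not benchmarked in this report) is to compile just the per-token forward call instead of wrapping the whole model object:

if compile:
    # hypothetical alternative: compile only the decoder forward pass,
    # mirroring gpt-fast, which compiles the per-token decode function
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)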
