
Cache updating when use_cache = False #32843

@ciaran-regan-ie

Description


System Info

  • transformers version: 4.44.0
  • Platform: macOS-14.5-arm64-arm-64bit
  • Python version: 3.10.0
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm experimenting with shuffling the layers of a pre-trained model. The layer_idx stored inside each Attention object makes this difficult, as described in this issue. To work around it I set use_cache = False; however, even with use_cache = False an error occurs because past_key_value.update is still called in the Attention forward pass. A simple fix would be to respect use_cache in the forward pass by adding the following and condition:

if past_key_value is not None and use_cache:
    cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
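A possible interim workaround on the user side (not a fix for the bug itself) is to re-assign layer_idx after shuffling so the cache indices match the new layer positions. A minimal sketch, assuming each decoder layer exposes its attention module as self_attn with a layer_idx attribute, as Phi-3's implementation does:

def reindex_layer_idx(model):
    # After shuffling model.model.layers, give every attention module a
    # layer_idx that matches its new position so cache lookups stay aligned.
    for new_idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = new_idx

That said, skipping the cache update entirely when use_cache is False, as in the guard above, still seems like the cleaner behaviour.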

Here is my code to reproduce. The first run-through works because the layers have not yet been shuffled, but the second run fails when the cache attempts to update.

from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import load_qa  # local helper that loads question/answer pairs
import random
import torch

def run_with_custom_order(model, tokenizer, device, prompts, order):
    original_layers = model.model.layers
    layer_dict = {i: layer for i, layer in enumerate(original_layers)}
    shuffled_layers = torch.nn.ModuleList([layer_dict[i] for i in order])
    model.model.layers = shuffled_layers

    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_length = inputs["input_ids"].shape[1]
        outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id, use_cache=False)
        generated_text = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
        print(f"{generated_text}")

    model.model.layers = original_layers

def main():
    device = (
        "cuda" if torch.cuda.is_available() else 
        "mps" if torch.backends.mps.is_available() else 
        "cpu"
    )    
    llm_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    model = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.to(device)
    model.config.pad_token_id = tokenizer.eos_token_id
    num_layers = len(model.model.layers)

    # Load questions and answers
    dataset_type = "mmlu"  # Change this to "default" for capitals dataset
    num_questions = 1
    questions, _ = load_qa(dataset_type, num_questions)
    prompts = [f"<|user|>\n{question}\nChoose A, B, C, or D:<|end|>\n<|assistant|>" for question in questions]
    
    order = list(range(num_layers))
    run_with_custom_order(model, tokenizer, device, prompts, order)

    random.shuffle(order)
    run_with_custom_order(model, tokenizer, device, prompts, order)


if __name__ == "__main__":
    main()

Expected behavior

When use_cache = False, the cache should not be updated at all, right?
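As far as I can tell, a direct forward call with use_cache = False does not build or return a cache, so I'd expect generate to behave the same way. A quick check (sketch, reusing model, tokenizer, and inputs from the script above):

out = model(**inputs, use_cache=False)
print(out.past_key_values)  # expected: None, since no cache should be created or updated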

Happy to help with a PR if you feel it's necessary!
