@sahilsuneja1 sahilsuneja1 commented Apr 2, 2025

What does this PR do?

This PR adds support for using MLPSpeculator models for assisted generation, similar to the existing support in TGI and vLLM.

Model code originally authored by Davis Wertheimer @daviswer

List of already existing speculators here and here

Training recipes for new speculators here and here
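
For readers new to assisted generation, the accept/reject step under greedy decoding can be sketched roughly as below. This is a toy illustration and not the code added by this PR: `verify_draft`, `toy_target_next`, and the integer "tokens" are hypothetical names invented for the sketch, and a real implementation scores all draft tokens in a single forward pass of the base model rather than calling it once per position.

```python
from typing import Callable, List


def verify_draft(
    prefix: List[int],
    draft: List[int],
    target_next: Callable[[List[int]], int],
) -> List[int]:
    """Return the tokens to append after checking a draft against the target.

    Draft tokens are accepted while they match the target's greedy choice.
    At the first mismatch the target's token is used instead, and if every
    draft token matches, one extra target token is appended.
    """
    context = list(prefix)
    accepted: List[int] = []
    for token in draft:
        expected = target_next(context)
        if token != expected:
            # Reject the rest of the draft and fall back to the target's token.
            accepted.append(expected)
            return accepted
        accepted.append(token)
        context.append(token)
    # The whole draft matched, so the target contributes one bonus token.
    accepted.append(target_next(context))
    return accepted


def toy_target_next(context: List[int]) -> int:
    """Hypothetical stand-in for the base model: counts upward modulo 10."""
    return (context[-1] + 1) % 10


print(verify_draft([1, 2], [3, 4, 9], toy_target_next))
print(verify_draft([1, 2], [3, 4, 5], toy_target_next))
```

Because every accepted token equals the base model's own greedy choice, greedy assisted generation is expected to produce the same text as unassisted greedy decoding, only in fewer base-model steps.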

Usage example:

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, MLPSpeculatorPreTrainedModel

# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

def compare_assisted_generation(prompts, checkpoint, assistant_checkpoint):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer(prompts, return_tensors="pt").to(device=device)

    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device=device, dtype=torch.bfloat16)
    assistant_model = MLPSpeculatorPreTrainedModel.from_pretrained(assistant_checkpoint).to(device=device, dtype=torch.bfloat16)
    model.eval()
    assistant_model.eval()

    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    generate_kwargs = {
        "do_sample": False,
        "temperature": None,
        "max_new_tokens": 50,
        "output_hidden_states": True,
    }

    # warmup
    for _ in range(0,2):
        model.generate(**inputs, **generate_kwargs)
        model.generate(**inputs,  assistant_model=assistant_model, **generate_kwargs)

    start_time = time.time()
    outputs = model.generate(**inputs, **generate_kwargs)
    end_time = time.time()
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    print(f"Generation without assistant; Time taken: {end_time-start_time} seconds")

    start_time = time.time()
    outputs = model.generate(**inputs,  assistant_model=assistant_model, **generate_kwargs)
    end_time = time.time()
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    print(f"Generation with assistant; Time taken: {end_time-start_time} seconds")


torch.set_grad_enabled(False)
prompt = "Alice and Bob"
checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
speculator_checkpoint = "ibm-ai-platform/llama3-8b-accelerator"
compare_assisted_generation(prompt, checkpoint, speculator_checkpoint)

Output from the above example on A100:

['Alice and Bob are two friends who are trying to solve a puzzle. They are given a set of numbers, and they need to find the sum of the numbers that are multiples of 3 or 5.\n\nHere is the set of numbers: 1, ']
Generation without assistant; Time taken: 1.150806188583374 seconds
['Alice and Bob are two friends who are trying to solve a puzzle. They are given a set of numbers, and they need to find the sum of the numbers that are multiples of 3 or 5.\n\nHere is the set of numbers: 1, ']
Generation with assistant; Time taken: 0.6626832485198975 seconds
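
As a quick reading of the two timings above, the speedup works out to roughly 1.7x. The snippet below only redoes that arithmetic on the reported numbers. It is a single-prompt, single-run measurement on one GPU, so it should be treated as indicative rather than as a benchmark, and the variable names are illustrative.

```python
# Wall-clock times copied from the output above (seconds).
baseline_s = 1.150806188583374   # generation without assistant
assisted_s = 0.6626832485198975  # generation with assistant

speedup = baseline_s / assisted_s
saved_pct = 100 * (1 - assisted_s / baseline_s)

print(f"speedup: {speedup:.2f}x, time saved: {saved_pct:.1f}%")
```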

Who can review?

@gante

Signed-off-by: Sahil Suneja <sahilsuneja@gmail.com>
MLPSpeculator originally authored by Davis Wertheimer at:
https://github.com/foundation-model-stack/fms-extras/blob/main/fms_extras/models/speculator.py
@github-actions github-actions bot marked this pull request as draft April 2, 2025 23:09

github-actions bot commented Apr 2, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@sahilsuneja1 sahilsuneja1 marked this pull request as ready for review April 2, 2025 23:12
@gante gante requested review from gante and removed request for ArthurZucker and Rocketknight1 April 3, 2025 09:37

gante commented Apr 3, 2025

Hey @sahilsuneja1 👋

We're currently pausing the addition of all non-critical decoding methods, including assisted generation variations. This is because we're designing a new way of adding transformers-compatible decoding methods (see this draft PR)

TL;DR, if the plan goes forward, new decoding methods will live on the hub, and transformers will only hold the core decoding strategies 🤗

Signed-off-by: Sahil Suneja <sahilsuneja@gmail.com>
@sahilsuneja1
Author

Thanks @gante, will track progress on it and revisit when the change is made!

Signed-off-by: Sahil Suneja <sahilsuneja@gmail.com>