
[Misc]: hidden states using vllm #3594

@ra-MANUJ-an

Description


The following small piece of code extracts embeddings (the hidden state of the last prompt token) from selected layers of an LLM:

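import torch

def process_row(prompt: str, model, tokenizer, layers_to_use: list, remove_period: bool):
    """
    Processes a row of data and returns the last-token embedding for each requested layer.
    """
    if remove_period:
        # Strip any trailing periods and spaces from the prompt.
        prompt = prompt.rstrip(". ")
    # Put the inputs on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Generating a single token forces one forward pass over the whole prompt.
        outputs = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, output_hidden_states=True, return_dict_in_generate=True, max_new_tokens=1, min_new_tokens=1)
    embeddings = {}
    for layer in layers_to_use:
        # hidden_states[0] = prompt pass; [layer] = that layer's output (0 is the embedding layer);
        # [0] = the single batch item; [-1] = the last prompt token.
        last_hidden_state = outputs.hidden_states[0][layer][0][-1]
        # Cast to float32 on the CPU so this also works for GPU / half-precision models.
        embeddings[layer] = [last_hidden_state.float().cpu().tolist()]
    return embeddings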

It's a fairly standard approach, but it is slow. Is there any way to use vLLM to make it faster without having to call `generate` every time? I've tried batching, but that is slow too. Any help is appreciated!
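Since only the prompt's hidden states are needed, `generate()` is not strictly required: a single forward pass with `output_hidden_states=True` returns the same per-layer tuple (index 0 is the embedding output, so the layer numbering matches `process_row`). The sketch below assumes a Hugging Face causal LM and a tokenizer that has a pad token; it batches prompts and picks each sequence's last real token via the attention mask, which avoids the `[-1]` indexing pitfall once batches are padded. The helper names `gather_last_token` and `extract_embeddings` are hypothetical, and the approach has not been benchmarked against vLLM.

```python
import torch


def gather_last_token(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the hidden state of the last non-padding token in each sequence.

    hidden: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len) of 0/1.
    Works with both left and right padding.
    """
    seq_len = attention_mask.shape[1]
    positions = torch.arange(seq_len, device=attention_mask.device)
    # Largest position whose mask value is 1 (argmax picks the first maximum).
    last_idx = (positions * attention_mask).argmax(dim=1)
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return hidden[batch_idx, last_idx]


@torch.no_grad()
def extract_embeddings(model, tokenizer, prompts, layers, batch_size=16):
    """Batched last-token embeddings from one forward pass per batch (no generate()).

    Assumes a Hugging Face-style causal LM and a tokenizer with a pad token set.
    Returns {layer: [vector for each prompt, in input order]}.
    """
    device = next(model.parameters()).device
    results = {layer: [] for layer in layers}
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)
        outputs = model(**inputs, output_hidden_states=True)
        for layer in layers:
            vecs = gather_last_token(outputs.hidden_states[layer], inputs["attention_mask"])
            # float32 on the CPU so half-precision GPU tensors convert cleanly.
            results[layer].extend(vecs.float().cpu().tolist())
    return results
```

For Llama tokenizers, which ship without a pad token, something like `tokenizer.pad_token = tokenizer.eos_token` and `tokenizer.padding_side = "right"` would be needed first; with right padding the real tokens keep their original positions, so their hidden states should match the unbatched ones up to numerical noise. The `remove_period` preprocessing is left out and can be applied to `prompts` beforehand.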

One way to get the last hidden state using vLLM is as follows:

from vllm import LLM, SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata


# path_to_llama2: local path or Hugging Face ID of the Llama-2 checkpoint (defined elsewhere).
llm = LLM(model=path_to_llama2)


# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = llm.llm_engine.workers[0].model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = llm.llm_engine.workers[0].scheduler_config.max_num_batched_tokens
max_num_seqs = llm.llm_engine.workers[0].scheduler_config.max_num_seqs
# train: the list of prompts being embedded (defined elsewhere).
prompt = train[0]
prompt_token_ids = llm.llm_engine.tokenizer.encode(prompt)  # e.g. [2, 100, 524, 10]
seqs = []

group_id = 1
seq_data = SequenceData(prompt_token_ids)
seq = SequenceGroupMetadata(
    request_id=str(group_id),
    is_prompt=True,
    seq_data={group_id: seq_data},
    sampling_params=sampling_params,
    block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = llm.llm_engine.workers[0]._prepare_inputs(
    seqs)
prompt_len = len(seq_data.prompt_token_ids)
input_tokens = input_tokens[:prompt_len]
input_positions = input_positions[:prompt_len]
# Execute only the transformer body (model.model), bypassing the LM head and sampler.
num_layers = llm.llm_engine.workers[0].model_config.get_num_layers(llm.llm_engine.workers[0].parallel_config)
tempOut = llm.llm_engine.workers[0].model.model(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=[(None, None)] * num_layers,
    input_metadata=input_metadata,
    cache_events=None,
)
print(tempOut.size())

However, this only returns the output of the final layer, not the hidden states of every layer. Is there any other way to get all of those values faster?
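For the all-layers question, one generic PyTorch option (not a vLLM feature) is to register forward hooks on the decoder layers and record their outputs while the `model.model(...)` forward pass runs. The sketch relies only on `torch.nn.Module.register_forward_hook`; the name `capture_layer_outputs` is hypothetical.

```python
import contextlib
from typing import Dict, Iterable, Iterator

import torch


@contextlib.contextmanager
def capture_layer_outputs(layers: Iterable[torch.nn.Module]) -> Iterator[Dict[int, object]]:
    """Temporarily record the forward output of each module in `layers`.

    Yields a dict {layer_index: output}. Outputs are stored exactly as the module
    returns them: a tensor, or a tuple (assumption: some vLLM decoder layers
    return a (hidden_states, residual) pair).
    """
    captured: Dict[int, object] = {}
    handles = []
    for idx, layer in enumerate(layers):
        # Bind idx as a default argument so each hook keeps its own index.
        def hook(module, args, output, idx=idx):
            captured[idx] = output
        handles.append(layer.register_forward_hook(hook))
    try:
        yield captured
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
```

Applied to vLLM, the layer list might be reachable as `llm.llm_engine.workers[0].model.model.layers`, mirroring the attribute path used in the snippet above; that path is an assumption and depends on the vLLM version. Wrapping the `model.model(...)` call in `with capture_layer_outputs(layers) as per_layer:` would then leave layer `i`'s output in `per_layer[i]`. vLLM's Llama decoder layers, at least in recent versions, return a `(hidden_states, residual)` pair with the residual addition deferred to the next layer, so the full hidden state would be their sum; verify this against the installed version before relying on it.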

Assignees: none
Labels: misc, stale (over 90 days of inactivity)
