Closed as not planned
The following is a small piece of code that extracts embeddings from a given layer of an LLM:

```python
import torch

def process_row(prompt: str, model, tokenizer, layers_to_use: list, remove_period: bool):
    """Processes a row of data and returns the embeddings."""
    if remove_period:
        prompt = prompt.rstrip(". ")
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
            min_new_tokens=1,
        )
    embeddings = {}
    for layer in layers_to_use:
        # hidden_states[0] holds the prompt's hidden states; index into the
        # requested layer, batch item 0, last token.
        last_hidden_state = outputs.hidden_states[0][layer][0][-1]
        embeddings[layer] = [last_hidden_state.numpy().tolist()]
    return embeddings
```

This is a pretty standard approach, but it is slow. Is there any way to use vLLM to make it faster, without needing to call `generate` for every prompt? I've tried batching, but it's slow too. Any help is appreciated!
One way to get the last hidden state values using vLLM is as follows:

```python
from transformers import LlamaModel, LlamaTokenizer

from vllm import LLM, SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

llm = LLM(model=path_to_llama2)

# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = llm.llm_engine.workers[0].model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = llm.llm_engine.workers[0].scheduler_config.max_num_batched_tokens
max_num_seqs = llm.llm_engine.workers[0].scheduler_config.max_num_seqs

prompt = train[0]
prompt_token_ids = llm.llm_engine.tokenizer.encode(prompt)  # e.g. [2, 100, 524, 10]

seqs = []
group_id = 1
seq_data = SequenceData(prompt_token_ids)
seq = SequenceGroupMetadata(
    request_id=str(group_id),
    is_prompt=True,
    seq_data={group_id: seq_data},
    sampling_params=sampling_params,
    block_tables=None,
)
seqs.append(seq)

input_tokens, input_positions, input_metadata = llm.llm_engine.workers[0]._prepare_inputs(seqs)
prompt_len = len(seq_data.prompt_token_ids)
input_tokens = input_tokens[:prompt_len]
input_positions = input_positions[:prompt_len]

# Execute the model.
num_layers = llm.llm_engine.workers[0].model_config.get_num_layers(
    llm.llm_engine.workers[0].parallel_config)
tempOut = llm.llm_engine.workers[0].model.model(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=[(None, None)] * num_layers,
    input_metadata=input_metadata,
    cache_events=None,
)
print(tempOut.size())
```

But this doesn't get me all the hidden state embeddings (of all layers), only the final layer's output. Is there any other way to get such values in a faster manner?
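One generic way to collect every layer's hidden states in a single forward pass, without calling `generate`, is to register forward hooks on each decoder layer. The same pattern should apply to the Llama module held by vLLM's worker, assuming its decoder layers are reachable at `model.model.layers` (an assumption about that model's layout). A toy sketch with a stand-in module:

```python
# Sketch: capture each layer's output with forward hooks. TinyDecoder is a
# stand-in for a stack of transformer decoder layers; with a real model you
# would iterate over its actual layer list instead.
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Stand-in for a decoder: a plain stack of layers applied in order."""
    def __init__(self, hidden=4, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = TinyDecoder()
captured = {}

def make_hook(idx):
    # Forward hooks receive (module, inputs, output); stash the output.
    def hook(module, inputs, output):
        captured[idx] = output.detach()
    return hook

handles = [layer.register_forward_hook(make_hook(i))
           for i, layer in enumerate(model.layers)]
with torch.no_grad():
    out = model(torch.randn(1, 5, 4))  # (batch, seq_len, hidden)
for h in handles:
    h.remove()  # always remove hooks when done

# captured[i] now holds layer i's full output; the last-token embedding for
# layer i is captured[i][0, -1].
print(sorted(captured))  # [0, 1, 2]
```

This keeps the single fast forward pass but yields all intermediate layers at once, rather than only the final output returned by the call above.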