2 changes: 1 addition & 1 deletion scripts/inc_measure_config.json
@@ -7,7 +7,7 @@
"names": []
},
"blacklist": {
-"types": ["KVCache", "VLLMKVCache"],
+"types": [],
"names": ["lm_head"]
},
"quantize_weight": false,
2 changes: 1 addition & 1 deletion scripts/inc_quant_config.json
@@ -7,7 +7,7 @@
"names": []
},
"blacklist": {
-"types": ["KVCache", "VLLMKVCache"],
+"types": [],
"names": ["lm_head"]
},
"dump_stats_path": "./nc_workspace_tmp/inc_measure_output"
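Both INC configs get the same blacklist change: the `KVCache`/`VLLMKVCache` op types are no longer excluded, so the KV cache is now covered by measurement and quantization, while `lm_head` stays blacklisted by name. A minimal sanity-check sketch (run from the repo root; everything outside the shown hunks is assumed unchanged):

```python
import json

# After this change, only lm_head is excluded by name; the KV-cache op types
# are no longer blacklisted for INC measurement/quantization.
for path in ("scripts/inc_measure_config.json", "scripts/inc_quant_config.json"):
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    assert cfg["blacklist"]["types"] == []           # was ["KVCache", "VLLMKVCache"]
    assert cfg["blacklist"]["names"] == ["lm_head"]
```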
276 changes: 276 additions & 0 deletions scripts/n2_prepare_tp8.py
@@ -0,0 +1,276 @@
from vllm import LLM, SamplingParams

import argparse
import os

os.environ["VLLM_MOE_N_SLICE"] = "8"
os.environ["VLLM_EP_SIZE"] = "8"
os.environ["VLLM_TP_SIZE"] = "8"
os.environ["VLLM_SKIP_WARMUP"] = "true"
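# NOTE: the four environment variables above are assumed to configure the HPU
# vLLM build for this run: 8-way tensor and expert parallelism, the MoE layers
# split into 8 slices, and graph warmup skipped. They mirror the --tp_size /
# --ep_size defaults parsed below; exact semantics are fork-specific.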

from typing import Any, List, Tuple
from transformers import (PreTrainedTokenizerBase, AutoTokenizer)
import random
import datasets
from vllm.utils import reset_seed
reset_seed()
# get file location
file_path = os.path.abspath(__file__)
dataset_path = os.path.join(os.path.dirname(file_path), "../benchmarks")

# model_path = "/data/models/DeepSeek-R1/"
# model_path = "/hf/hf_models/DeepSeek-R1"
# model_path = "deepseek-ai/DeepSeek-V2-Lite"
# model_path = "/mnt/disk5/hf_models/DeepSeek-R1-BF16"
model_path = "/software/users/yiliu4/HF_HOME/hub/deepseekv3-bf16-4l-real"
# Parse the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=model_path, help="The model path.")
parser.add_argument("--tokenizer", type=str, default=model_path, help="The tokenizer path.")
parser.add_argument("--tp_size", type=int, default=8, help="Tensor parallel size.")
parser.add_argument("--ep_size", type=int, default=8, help="Expert parallel size.")
parser.add_argument("--dataset", type=str, default=None, help="The dataset to sample prompts from (sonnet or gsm8k).")
parser.add_argument("--isl", type=int, default=1024, help="Input sequence length.")
parser.add_argument("--osl", type=int, default=128, help="Output sequence length.")
parser.add_argument("--nprompts", type=int, default=4, help="The number of prompts.")
parser.add_argument("--random", action="store_true", help="Randomly sample prompts.")
parser.add_argument("--quant", action="store_true", help="Run with INC quantization (inc_q); otherwise run the preparation/measurement pass (inc_p).")
args = parser.parse_args()

# os.environ["VLLM_SKIP_WARMUP"] = "true"
# os.environ["HABANA_VISIBLE_DEVICES"] = "ALL"
# os.environ['HABANA_VISIBLE_MODULES'] ='0,1,2,3,4,5,6,7'
# os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
# os.environ["PT_HPU_WEIGHT_SHARING"] = "0"
# os.environ['PT_HPUGRAPH_DISABLE_TENSOR_CACHE']='1'
# os.environ['GLOO_SOCKET_IFNAME']='eth0'

# os.environ["VLLM_MOE_N_SLICE"] = "1" if args.ep_size > 1 else "4"
# os.environ["VLLM_EP_SIZE"] = f"{args.ep_size}"
# os.environ["VLLM_MLA_DISABLE_REQUANTIZATION"] = "1"

# os.environ["VLLM_RAY_DISABLE_LOG_TO_DRIVER"] = "0"
# os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "0"
# os.environ["RAY_DEDUP_LOGS"] = "1"
# os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

# ==-------------------------------------------------------------------------==
# Calibration parameters
least_tokens = 1024
num_samples = 512
max_new_tokens = 32
seed = 42
# https://github.com/deepseek-ai/DeepSeek-R1/blob/main/README.md#deepseek-r1-evaluation
"""
... benchmarks requiring sampling, we use a temperature of 0.6, a top-p value of 0.95...
"""
temperature = 0.6
temperature = 0  # overrides the 0.6 above: use greedy sampling
top_p = 0.95
# ==-------------------------------------------------------------------------==


def sample_sonnet_requests(
dataset_path: str,
num_requests: int,
input_len: int,
prefix_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> Tuple[List[str], None]:
assert (
input_len > prefix_len
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."

# Load the dataset.
with open(dataset_path, encoding='utf-8') as f:
poem_lines = f.readlines()

# Tokenize the poem lines.
poem_token_ids = tokenizer(poem_lines).input_ids
average_poem_len = sum(
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

# Base prefix for all requests.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_message = [{
"role": "user",
"content": base_prompt,
}]
base_prompt_formatted = tokenizer.apply_chat_template(
base_message, add_generation_prompt=True, tokenize=False)
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

assert (
input_len > base_prompt_offset
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
num_input_lines = round(
(input_len - base_prompt_offset) / average_poem_len)
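# i.e. fill the remaining token budget (input_len minus the chat-template
# overhead of the base prompt) with whole poem lines, using the average
# tokenized line length computed above.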

# First approximately `prefix_len` number of tokens in the
# prompt are fixed poem lines.
assert (
prefix_len > base_prompt_offset
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

num_prefix_lines = round(
(prefix_len - base_prompt_offset) / average_poem_len)
prefix_lines = poem_lines[:num_prefix_lines]

# Sample the rest of lines per request.
sampled_requests: List = []
for _ in range(num_requests):
num_lines_needed = num_input_lines - num_prefix_lines
sampled_lines = "".join(prefix_lines +
random.choices(poem_lines, k=num_lines_needed))


prompt = f"{base_prompt}{sampled_lines}"
message = [
{
"role": "user",
"content": prompt,
},
]
prompt_formatted = tokenizer.apply_chat_template(
message, add_generation_prompt=True, tokenize=False)
sampled_requests.append(prompt_formatted)

return sampled_requests, None

def sample_gsm8k_requests(
num_requests: int, tokenizer: PreTrainedTokenizerBase, do_random: bool = False
) -> Tuple[List[str], List[str]]:
# Load the dataset from huggingface.
dataset = datasets.load_dataset("openai/gsm8k", "main")
prompts = dataset["train"]["question"]
expected_responses = dataset["train"]["answer"]
few_shots = 5
base_prompt = [f"Question: {prompts[i]}\nAnswer: {expected_responses[i]}\n" for i in range(few_shots)]
base_prompt = "\n".join(base_prompt)
base_prompt = f"{base_prompt}\n"

# Sample the requests.
sampled_requests: List = []
sampled_response: List = []
for j in range(num_requests):
i = random.choice(range(len(prompts[few_shots:]))) if do_random else j + few_shots
prompt = f"{base_prompt}Question: {prompts[i]}\nAnswer: "
# message = [
# {
# "role": "user",
# "content": prompt,
# },
# ]
# prompt = tokenizer.apply_chat_template(
# message, add_generation_prompt=True, tokenize=False)
expected_response = expected_responses[i]
sampled_requests.append(prompt)
sampled_response.append(expected_response)

return sampled_requests, sampled_response
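# For reference: each prompt returned above is a 5-shot GSM8K prompt
# (five "Question: ...\nAnswer: ..." exemplars followed by the sampled
# question and a trailing "Answer: "), and sampled_response holds the
# matching reference answer for comparison against the generated text.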

if __name__ == "__main__":

# Sample prompts.

if args.dataset == "sonnet":
# Sample sonnet requests.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
prompts, gt = sample_sonnet_requests(
dataset_path=f"{dataset_path}/sonnet.txt",
num_requests=args.nprompts,
input_len=args.isl,
prefix_len=200,
tokenizer=tokenizer,
)
elif args.dataset == "gsm8k":
# Sample GSM8K requests.
args.osl=128
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
prompts, gt = sample_gsm8k_requests(
num_requests=args.nprompts,
tokenizer=tokenizer,
do_random=args.random,
)
else:
prompts = [
"Hello, my name is",
# "The president of the United States is",
# "The capital of France is",
"The future of AI is",
]

from utils import get_prompts, get_prompt_token_ids, get_pile_prompts

# prompts = get_prompts()
# Got the unseen prompts.
# smoke_num_samples = 10
# prompts = get_pile_prompts(args.model, num_samples * 2)
# smoke_prompts = [
# "Hello, my name is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
# ]

# smoke_prompts = smoke_prompts + prompts[-smoke_num_samples:]
smoke_prompts = get_prompts()
prompt_token_ids = get_prompt_token_ids(
args.model, smoke_prompts, least_tokens
)
gt = None
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_new_tokens,
truncate_prompt_tokens=least_tokens,
)
model = args.model
quantization = "inc_q" if args.quant else "inc_p"
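# Assumption about the INC flow: "inc_p" selects the preparation/measurement
# pass and "inc_q" the quantized run (paired with the fp8_inc KV cache below),
# so the script is run once without --quant to collect stats and once with it.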
if args.quant:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
tensor_parallel_size=args.tp_size,
distributed_executor_backend='mp',
trust_remote_code=True,
quantization=quantization,
kv_cache_dtype="fp8_inc",
max_model_len=16384,
dtype="bfloat16",
)
else:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
tensor_parallel_size=args.tp_size,
distributed_executor_backend='mp',
trust_remote_code=True,
quantization=quantization,
max_model_len=16384,
dtype="bfloat16",
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(
# prompts=smoke_prompts,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids
)
# Print the outputs.
for output_i in range(len(outputs)):
output = outputs[output_i]
gt_i = None if gt is None else gt[output_i]
prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text
print("====================================")
prompt = output.prompt
print(f"prompt: {prompt!r}")
print(f"prompt_token_ids[:10]: {prompt_token_ids[:10]!r}")
print(f"prompt_token_ids[-10:]: {prompt_token_ids[-10:]!r}")
print(f"Generated text: {generated_text!r}")
print(f"Ground truth: {gt_i!r}")
print("====================================")

llm.llm_engine.model_executor.shutdown()
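Given the `--quant` switch above and the `dump_stats_path` in the INC configs, the intended flow appears to be two passes of this script; a hypothetical driver sketch (the ordering is inferred from the diff, not stated in it):

```python
import subprocess
import sys

script = "scripts/n2_prepare_tp8.py"

# Pass 1: preparation/measurement run (quantization="inc_p"); INC stats are
# dumped (the quant config points at ./nc_workspace_tmp/inc_measure_output).
subprocess.run([sys.executable, script], check=True)

# Pass 2: quantized run (quantization="inc_q") with the fp8_inc KV cache.
subprocess.run([sys.executable, script, "--quant"], check=True)
```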
33 changes: 23 additions & 10 deletions vllm/attention/backends/hpu_attn.py
@@ -21,8 +21,8 @@
from vllm.logger import init_logger

logger = init_logger(__name__)


+from vllm.logger import rank_debug
+from habana_frameworks.torch import core as htcore
class HPUAttentionBackend(AttentionBackend):

@staticmethod
@@ -143,7 +143,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
cross_attn_bias: Optional[torch.Tensor] = None


-class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata]):
+class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):

def __init__(
self,
@@ -159,11 +159,11 @@ def __init__(
attn_type: str,
# MLA Specific Arguments
**kwargs) -> None:
-super().__init__(num_heads, head_size, scale, num_kv_heads,
+torch.nn.Module.__init__(self)
+MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**kwargs)

self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
@@ -215,13 +215,16 @@ def forward(
assert hasattr(attn_metadata, "input_positions"), f"attn meta: {attn_metadata}"

if not is_prefill:
+rank_debug(f"decoding hidden_states_or_q_c: {hidden_states_or_q_c.shape}")
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
input_positions = attn_metadata.input_positions.view(-1)
q_pe, k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)
else:
+rank_debug(f"prefill hidden_states_or_q_c: {hidden_states_or_q_c.shape}")
+rank_debug(f"self.q_proj : {self.q_proj}")
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)

@@ -231,7 +234,7 @@
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)

+htcore.mark_step()
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets

@@ -250,14 +253,20 @@
if kv_cache is not None and len(kv_cache) == 2:
# print(f"k cache shape: {kv_cache[0].shape}")
# print(f"v cache shape: {kv_cache[1].shape}")
-# print(f"latent vec k shape: {latent_vec_k.shape}")
-# print(f"latent vec v shape: {latent_vec_v.shape}")
latent_vec_v = latent_vec_k[..., :self.kv_lora_rank]
latent_vec_k = latent_vec_k[..., self.kv_lora_rank:]
+rank_debug(f"latent vec k shape: {latent_vec_k.shape}, dtype {latent_vec_k.dtype}")
+rank_debug(f"latent vec v shape: {latent_vec_v.shape}, dtype {latent_vec_v.dtype}")
+rank_debug(f"v_cache : {kv_cache[0].shape}, dtype: {kv_cache[0].dtype}")
+rank_debug(f"k_cache : {kv_cache[1].shape}, dtype: {kv_cache[1].dtype}")
+htcore.mark_step()
k_cache = self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
block_offsets)
v_cache = self.latent_cache_v(latent_vec_v, kv_cache[1], block_indices,
block_offsets)
+htcore.mark_step()
+rank_debug(f"v_cache : {v_cache.shape}")
+rank_debug(f"k_cache : {k_cache.shape}")
kv_cache = (k_cache, v_cache)

# if torch.distributed.get_rank() == 0:
@@ -312,8 +321,12 @@ def _forward_prefill(
attn_output = out\
.view(batch_size, -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(batch_size, -1, self.num_heads * v.shape[-1])

-return self.o_proj(attn_output)[0]
+rank_debug(f"attn_output shape: {attn_output.shape}")
+htcore.mark_step()
+out = self.o_proj(attn_output)[0]
+htcore.mark_step()
+rank_debug(f"out shape: {out.shape}")
+return out

def _forward_decode(
self,
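`rank_debug` is imported from `vllm.logger`, but its definition is not included in this section of the diff. Purely as an assumption about what it does (per-rank debug logging while chasing TP8 issues), a minimal sketch of such a helper:

```python
import os

import torch.distributed as dist

from vllm.logger import init_logger

logger = init_logger(__name__)


def rank_debug(msg: str) -> None:
    """Hypothetical helper: log a debug message tagged with the distributed rank."""
    rank = dist.get_rank() if dist.is_initialized() else int(os.getenv("RANK", "0"))
    logger.debug("[rank %d] %s", rank, msg)
```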