# prefill decode microbenchmark for Qwen3 (#699)

@@ -0,0 +1,86 @@

# MICROBENCHMARKING IS EXPERIMENTAL AND NOT SUPPORTED FOR ALL MODELS AND FLEXIBLE WORKLOADS

The goal of microbenchmarking is to strip the model call of its vLLM dependencies (the scheduler and the KV cache manager), so that just the model call can be debugged and optimized efficiently.

The current implementation runs on the following version of the main branch. It is not pinned anymore, but the commit is recorded here so runs can be backtracked; as long as the model call API remains unchanged, it should keep working:

```
Commit ID 5797c31acb0010cf8c54ba9218bacf96d8a1260e
```

> ⚠️ The microbenchmarking code **does not support all models and features** and is currently used for debugging and optimizing static workloads.

**The only model validated for microbenchmarking is Qwen3-32B.** (DeepSeek-V3 runs have not yet been verified.)

## Parameters needed by the microbenchmarking code

### `max_seq_len`
The maximum model length, i.e. the maximum supported length of each request. Typically this equals the maximum number of prefill plus decode tokens of a request.

### `phase`
The inference phase. Supported phases are `prefill` and `decode`.

### `decode_offset_from_prefill`
Used primarily in the decode phase. This offset indicates the decode step to profile: a value of 1 corresponds to the first token generated after prefill, and a value of 10 corresponds to profiling the 10th decode step.

### `model_hf_config`
Path to a JSON file where the HF config is saved. Providing the config locally avoids having to download it from Hugging Face on every run.

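If you need to regenerate such a file, a minimal sketch using the `transformers` library is shown below (this is an assumption about how the file can be produced; any method that writes the full HF config as JSON works):

```
# One-time dump of the full HF config to a local JSON file, so that later
# microbenchmark runs never need to reach the Hugging Face hub.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen3-32B")
# use_diff=False writes all fields, not only those differing from defaults.
config.to_json_file(
    "examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json",
    use_diff=False)
```
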
### `num_block_override`
The number of blocks in the KV cache. This is kept as an override because the KV cache needs to be representative; a good value is obtained from `offline_inference.py` runs on the target TPU, since it depends on how much TPU memory is available.

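As a rough guide, the block count is bounded by the HBM left over for the KV cache after weights and activations are allocated. Below is a back-of-the-envelope sketch, assuming a bfloat16 KV cache with one K and one V tensor per layer; the actual layout in tpu_commons may differ, so prefer the value reported by `offline_inference.py`:

```
# Hypothetical estimate of num_block_override from an HBM budget.
def estimate_num_kv_blocks(
    hbm_budget_bytes: int,   # HBM left over for the KV cache
    num_layers: int = 64,    # Qwen3-32B: num_hidden_layers
    num_kv_heads: int = 8,   # Qwen3-32B: num_key_value_heads
    head_dim: int = 128,     # Qwen3-32B: head_dim
    block_size: int = 128,   # --block_size flag default
    dtype_bytes: int = 2,    # bfloat16
) -> int:
    # The factor of 2 accounts for the K and V tensors of every layer.
    bytes_per_block = (2 * num_layers * block_size * num_kv_heads * head_dim *
                       dtype_bytes)
    return hbm_budget_bytes // bytes_per_block

# E.g. with ~64 GiB of HBM left for the KV cache:
print(estimate_num_kv_blocks(64 * 1024**3))  # -> 2048
```
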
### `max_prefill_len`
The maximum length of a prefill sequence.

### `max_num_seq`
The maximum number of sequences supported by the model (passed on the command line as `--max_num_seq`).

### Caveats

In this benchmark, `max_seq_len` effectively plays the role of vLLM's `--max-num-batched-tokens`: it is the maximum total number of tokens processed in a single batch, which is why it bounds the number of sequences.

i) In the prefill phase: `max_num_seq` = `max_seq_len` // `max_prefill_len`

ii) In the decode phase: `max_num_seq` <= `max_seq_len`, since each decoding sequence contributes exactly one token per step.

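A minimal sketch of these relationships, mirroring the comments in `microbenchmark_app.py` (with values taken from the decode example below):

```
# max_seq_len acts like vLLM's --max-num-batched-tokens: a per-batch token budget.
max_seq_len = 4096
max_prefill_len = 512

# Prefill: each sequence contributes up to max_prefill_len tokens to the batch.
max_num_seq_prefill = max_seq_len // max_prefill_len  # -> 8

# Decode: each sequence contributes exactly 1 token per step.
max_num_seq_decode = max_seq_len // 1  # -> up to 4096
```
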
### `model_call_steps`
The number of times the model is called. Note that the first call will typically also include compilation time.

### `block_size`
The block size, which is the same as the `page_size` of the KV cache.

### `additional_config`
This is used to propagate tpu_commons-specific arguments (e.g. sharding and quantization settings). Example:
```
'{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}, "quantization": { "qwix": { "rules": [{ "module_path": ".*", "weight_qtype": "float8_e4m3fn", "act_qtype": "float8_e4m3fn"}]}}}'
```

### `model_config`
The vLLM model configuration, e.g. `--model_config='{"model":"Qwen/Qwen3-32B"}'`.

### `new_model_design`
True if microbenchmarking is done for newer models such as Llama4 and DeepSeek v3. Note that the example app currently enables this via the `NEW_MODEL_DESIGN` environment variable (see the bottom of `microbenchmark_app.py`).

### `trace_dir`
The local directory where traces are stored. The default value is `/tmp/tpu_commons_traces`.

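The traces are written with `jax.profiler`, so they can typically be inspected by pointing TensorBoard's profiler plugin (or a Perfetto-compatible viewer) at this directory.
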
## Example commands to run the microbenchmark

### Decode

```
python examples/microbenchamarking/microbenchmark_app.py --additional_config='{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}}' --model_config='{"model":"Qwen/Qwen3-32B"}' --phase='decode' --max_seq_len=4096 --max_num_seq=2048 --model_hf_config="examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json"
```

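Notice that `max_num_seq = 2048` satisfies caveat (ii) above, since 2048 <= `max_seq_len` = 4096.
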
### Prefill

```
python examples/microbenchamarking/microbenchmark_app.py --additional_config='{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}}' --model_config='{"model":"Qwen/Qwen3-32B"}' --phase='prefill' --max_seq_len=1024 --max_num_seq=2 --max_prefill_len=512 --model_hf_config="examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json"
```

Notice that `max_num_seq = 2`, since at most 2 sequences of `max_prefill_len = 512` fit within `max_seq_len = 1024`.

@@ -0,0 +1,62 @@

{
  "architectures": ["DeepseekV3ForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_deepseek.DeepseekV3Config",
    "AutoModel": "modeling_deepseek.DeepseekV3Model",
    "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
  },
  "bos_token_id": 0,
  "eos_token_id": 1,
  "ep_size": 1,
  "first_k_dense_replace": 3,
  "hidden_act": "silu",
  "hidden_size": 7168,
  "initializer_range": 0.02,
  "intermediate_size": 18432,
  "kv_lora_rank": 512,
  "max_position_embeddings": 163840,
  "model_type": "deepseek_v3",
  "moe_intermediate_size": 2048,
  "moe_layer_freq": 1,
  "n_group": 8,
  "n_routed_experts": 256,
  "n_shared_experts": 1,
  "norm_topk_prob": true,
  "num_attention_heads": 128,
  "num_experts_per_tok": 8,
  "num_hidden_layers": 61,
  "num_key_value_heads": 128,
  "num_nextn_predict_layers": 1,
  "q_lora_rank": 1536,
  "qk_nope_head_dim": 128,
  "qk_rope_head_dim": 64,
  "quantization_config": {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "weight_block_size": [128, 128]
  },
  "rms_norm_eps": 1e-06,
  "rope_scaling": {
    "beta_fast": 32,
    "beta_slow": 1,
    "factor": 40,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
    "original_max_position_embeddings": 4096,
    "type": "yarn"
  },
  "rope_theta": 10000,
  "routed_scaling_factor": 2.5,
  "scoring_func": "sigmoid",
  "tie_word_embeddings": false,
  "topk_group": 4,
  "topk_method": "noaux_tc",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.33.1",
  "use_cache": true,
  "v_head_dim": 128,
  "vocab_size": 129280
}

@@ -0,0 +1,28 @@

{
  "architectures": ["Qwen3ForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 25600,
  "max_position_embeddings": 40960,
  "max_window_layers": 64,
  "model_type": "qwen3",
  "num_attention_heads": 64,
  "num_hidden_layers": 64,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

@@ -0,0 +1,209 @@

# pytype: disable=import-error
# pytype: disable=module-attr
# pytype: disable=attribute-error
# pytype: disable=wrong-arg-types
import json
import os
import time
from typing import Any, Sequence

import jax
from absl import app, flags
from flax import nnx
from microbenchmark_input_utils import InputArgs, InputCreator
from microbenchmark_utils import Sampler, init_mesh

from tpu_commons.logger import init_logger
from tpu_commons.mock.vllm_config_utils import (CacheConfig, ModelConfig,
                                                VllmConfig)
from tpu_commons.models.jax.model_loader import get_model

logger = init_logger(__name__)

_MAX_SEQ_LEN = flags.DEFINE_integer("max_seq_len", 4096,
                                    "Maximum sequence length.")

_PHASE = flags.DEFINE_string("phase", "decode", "Phase to benchmark.")

_DECODE_OFFSET_FROM_PREFILL = flags.DEFINE_integer(
    "decode_offset_from_prefill",
    1,  # An offset of 1 profiles the first decode step after prefill.
    "Offset from the prefill length to start decoding.",
)

_MODEL_NAME = flags.DEFINE_string(
    "model_name", "qwen3-32b",
    "Model name to benchmark. Supported models: qwen3-32b, deepseek_v3")

_MODEL_HF_CONFIG = flags.DEFINE_string(
    "model_hf_config",
    "",
    "Model HF config in json format.",
)

# This has to be overridden, as the calculation is not yet correct on the
# microbenchmark side. A good value depends on the available TPU memory; see
# the README for how to obtain one from offline_inference.py runs.
# TODO(vijaya): Fix the calculation and remove this flag as an override.
_KV_NUM_BLOCK_OVERRIDE = flags.DEFINE_integer(
    "num_block_override",
    2048,
    "Override number of blocks.",
)

_MAX_PREFILL_LEN = flags.DEFINE_integer("max_prefill_len", 512,
                                        "Maximum prefill length.")

# For prefill, max_num_seq = max_seq_len // max_prefill_len.
# For decode, max_num_seq = max_seq_len // 1.
_MAX_NUM_SEQ = flags.DEFINE_integer(
    "max_num_seq",
    2048,
    "Maximum number of sequences to be benchmarked.",
)

_MODEL_CALL_STEPS = flags.DEFINE_integer("model_call_steps", 5,
                                         "Number of model call steps.")

_BLOCK_SIZE = flags.DEFINE_integer("block_size", 128, "Block size.")
_SAMPLER_TYPE = flags.DEFINE_string("sampler_type", "fixed", "Sampler type.")
_SAMPLER_STD = flags.DEFINE_float("sampler_std", 1.0,
                                  "Sampler standard deviation.")
_ADDITIONAL_CONFIG = flags.DEFINE_string(
    "additional_config",
    "",
    "Additional configuration for the model.",
)
_MODEL_CONFIG = flags.DEFINE_string(
    "model_config",
    "",
    "Model configuration for the model.",
)

# NOTE: Currently the app enables the new model design via the
# NEW_MODEL_DESIGN environment variable (see __main__ below) rather than
# reading this flag.
_NEW_MODEL_DESIGN = flags.DEFINE_string(
    "NEW_MODEL_DESIGN",
    "True",
    "Model design to use. If True, uses the new model design which is used "
    "for newer models like DeepseekV3 and Llama4.",
)

_TRACE_DIR = flags.DEFINE_string(
    "trace_dir",
    "/tmp/tpu_commons_traces",
    "Directory to save the trace files.",
)


def get_hf_config_attribute_map(model_hf_config: str):
    # Load the HF config attribute map from a local JSON file.
    with open(model_hf_config, 'r') as file:
        data = json.load(file)
    return data


def validate_command_line_args():
    if _PHASE.value not in ["prefill", "decode"]:
        raise ValueError(
            f"Phase {_PHASE.value} not supported. Choose either 'prefill' or 'decode'."
        )
    if _MAX_SEQ_LEN.value % _BLOCK_SIZE.value != 0:
        raise ValueError(
            f"Max sequence length {_MAX_SEQ_LEN.value} must be divisible by block size {_BLOCK_SIZE.value}."
        )

    if _PHASE.value == "prefill":
        if _MAX_SEQ_LEN.value % _MAX_PREFILL_LEN.value != 0:
            raise ValueError(
                f"Max sequence length {_MAX_SEQ_LEN.value} must be divisible by max prefill length {_MAX_PREFILL_LEN.value}."
            )
        if _MAX_SEQ_LEN.value // _MAX_PREFILL_LEN.value != _MAX_NUM_SEQ.value:
            raise ValueError(
                f"Max number of sequences {_MAX_NUM_SEQ.value} must be equal to max sequence length {_MAX_SEQ_LEN.value} divided by max prefill length {_MAX_PREFILL_LEN.value}."
            )


class Benchmarker:
    """Benchmarks a single model call for a given phase.

    Takes a VllmConfig, model function, mesh, sampler, rng, model_hf_config,
    state and trace_directory, builds phase-specific inputs with the
    InputCreator class, and times the model call under the JAX profiler.
    """

    def __init__(self, vllm_config: VllmConfig, model: Any, mesh: Any,
                 sampler: Sampler, rng: nnx.Rngs, model_hf_config, state,
                 trace_directory):
        self.vllm_config = vllm_config
        self.model = model
        self.mesh = mesh
        self.sampler = sampler
        self.rng = rng
        self.model_hf_config = model_hf_config
        self.state = state
        self.trace_directory = trace_directory

    def benchmark(self, phase: str):
        input_args = InputArgs(
            block_size=_BLOCK_SIZE.value,
            max_num_seq=_MAX_NUM_SEQ.value,
            min_prefill_len=1,
            max_prefill_len=_MAX_PREFILL_LEN.value,
            max_model_len=_MAX_SEQ_LEN.value,
            decode_offset_from_prefill=_DECODE_OFFSET_FROM_PREFILL.value,
            sampler=self.sampler,
            model_hf_config=self.model_hf_config,
            phase=phase,
            num_blocks_override=_KV_NUM_BLOCK_OVERRIDE.value)

        input_creator = InputCreator(input_args=input_args,
                                     sharding=None,
                                     mesh=self.mesh,
                                     rng=self.rng)
        model_input = input_creator.create_input(phase=phase)
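
        # NOTE: The first of the model_call_steps invocations typically also
        # pays the one-time compilation cost; later iterations measure
        # steady-state latency.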
        jax.profiler.start_trace(self.trace_directory)
        start_time = time.time()
        kv_caches, act = self.model(
            self.state,
            model_input.kv_caches,
            model_input.input_ids,
            model_input.attention_metadata,
        )

        # Block on the activations so the wall-clock time covers the full call.
        act.block_until_ready()
        end_time = time.time()
        jax.profiler.stop_trace()
        logger.info(
            f"Time taken for model call in phase {phase}: "
            f"{end_time - start_time} seconds. "
            f"Profile trace is saved in {self.trace_directory}.")


def main(argv: Sequence[str]):
    sampler = Sampler(type=_SAMPLER_TYPE.value, std=_SAMPLER_STD.value)
    rng = nnx.Rngs(params=0)
    vllm_config = VllmConfig(
        additional_config=json.loads(_ADDITIONAL_CONFIG.value),
        model_config=ModelConfig(**json.loads(_MODEL_CONFIG.value)),
        cache_config=CacheConfig(block_size=_BLOCK_SIZE.value),
    )

    validate_command_line_args()

    vllm_config.model_config.hf_config.attribute_map = get_hf_config_attribute_map(
        _MODEL_HF_CONFIG.value)

    mesh = init_mesh(vllm_config, jax.devices())
    model_fn, compute_logits_fn, get_multimodal_embeddings_fn, get_input_embeddings_fn, state = get_model(
        vllm_config,
        rng.params(),
        mesh,
    )

    benchmarker = Benchmarker(vllm_config, model_fn, mesh, sampler, rng,
                              vllm_config.model_config.hf_config, state,
                              _TRACE_DIR.value)
    for _ in range(_MODEL_CALL_STEPS.value):
        benchmarker.benchmark(_PHASE.value)


if __name__ == "__main__":
    # Uncomment the line below only when benchmarking a model that uses the
    # new model design (e.g. DeepseekV3 and Llama4).
    # os.environ['NEW_MODEL_DESIGN'] = 'True'
    os.environ['JAX_RANDOM_WEIGHTS'] = 'True'
    os.environ['TPU_BACKEND_TYPE'] = 'JAX'
    app.run(main)