86 changes: 86 additions & 0 deletions examples/microbenchamarking/README.md
@@ -0,0 +1,86 @@
# MICROBENCHMARKING IS EXPERIMENTAL AND NOT SUPPORTED FOR ALL MODELS AND FLEXIBLE WORKLOADS

The goal of microbenchmarking is to strip the model call from vLLM dependencies (the scheduler and KV cache manager) so that just the model call can be debugged and optimized efficiently.
> **Reviewer (Collaborator):** nit: VLLM -> vLLM
>
> **Reviewer (Collaborator):** A rewording suggestion: "The goal of microbenchmarking is to strip out the vLLM server layer and focus on just profiling the model calls."


The current implementation **runs on the following commit of the main branch** (noted so changes can be backtracked; as long as the model call API remains unchanged, it should keep working):
> **Reviewer (Collaborator):** "The current implementation runs on the following pinned version of the main branch:"
>
> **Author (@mailvijayasingh, Sep 25, 2025):** Actually, it is not pinned anymore. I will just mention it so we can backtrack, but as long as the model call API remains unchanged, it should work.


```
Commit ID 5797c31acb0010cf8c54ba9218bacf96d8a1260e
```

> ⚠️ The microbenchmarking code **does not support all models and features** and is currently used for debugging and optimizing static workloads.

**The only model validated for microbenchmarking is Qwen3-32B.**
> **Reviewer (Collaborator):** "The only model validated for microbenchmarking is Qwen3-32B." Don't we support DeepSeek as well?
>
> **Author (@mailvijayasingh):** I wanted to keep this, and remove it once we verify the runs together.


## Parameters needed by the microbenchmarking code

### `max_seq_len` -
Maximum model length, i.e. the maximum supported length of each request. Typically this equals the maximum number of prefill + decode tokens across all requests.
> **Reviewer (Collaborator):** "max model len this is the maximum supported length of each request. Typically this equals the maximum number of prefill + decode tokens across all requests."


### `phase` -

The inference phase. Supported phases are `prefill` and `decode`.
> **Reviewer (Collaborator):** "Inference phase - supported phases are 'prefill' and 'decode'."


### `decode_offset_from_prefill` -
Used primarily in decode. This offset indicates the decode step index to profile: a value of 1 means the 1st token after prefill, and e.g. a value of 10 corresponds to profiling the 10th decode step.
> **Reviewer (@gpolovets1, Sep 25, 2025):** "This offset indicates the decode step index to profile. E.g. setting a value of 10 corresponds to profiling the 10th decode step."


### `model_hf_config` -
Path to a JSON file where the HF config is saved. We need this to avoid having to download the model config from Hugging Face every time.
> **Reviewer (Collaborator):** "We need this to avoid having to download the model from huggingface everytime."
>
> **Reviewer (Collaborator):** Actually, if we are just downloading the config, would it be a problem to download from HF?

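If you prefer to regenerate one of these config JSONs yourself, a minimal sketch is shown below. It assumes the `transformers` package is installed and the Hugging Face Hub is reachable; `AutoConfig.from_pretrained` only fetches the small config file, not the model weights.

```
# Sketch (assumes `transformers` is installed): dump a model's HF config to a
# local JSON file so the benchmark can read it without touching the Hub again.
import json

from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen3-32B")
with open("examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json", "w") as f:
    json.dump(config.to_dict(), f, indent=2)
```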

### `num_block_override` -
Number of blocks in the KV cache. This is kept as an override because we need the KV cache setup to be representative; a good value is obtained from
`offline_inference.py` runs (so far determined by running on v7x-2, depending on how much TPU memory is available).

### `max_prefill_len` -
Maximum length of a prefill sequence.

### `max_num_sequence` -
The maximum number of sequences supported by the model.
> **Reviewer (Collaborator):** sequence -> sequences


### Caveats

i) In the prefill phase: `max_num_sequence` = `max_seq_len` // `max_prefill_len`

ii) In the decode phase: `max_num_sequence` < `max_seq_len`

(see the worked sketch after the review notes below)
> **Reviewer (Collaborator):** Why does the maximum sequence length influence the maximum number of allowed sequences (and vice versa)?
>
> **Reviewer (Collaborator):** Would it help if we added `--max-num-batched-tokens` from vLLM? This variable corresponds to the maximum total tokens that we can process in a single batch.

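For intuition, here is a minimal sketch of how the example commands later in this README satisfy these caveats; the numbers mirror those commands, and the check mimics `validate_command_line_args` in `microbenchmark_app.py`.

```
# Sketch: check the max_num_sequence caveats using the numbers from the
# example commands in this README.

# Prefill example: --max_seq_len=1024 --max_prefill_len=512 --max_num_seq=2
max_seq_len = 1024
max_prefill_len = 512
max_num_seq = max_seq_len // max_prefill_len     # caveat (i)
assert max_num_seq == 2                          # matches --max_num_seq=2

# Decode example: --max_seq_len=4096 --max_num_seq=2048
decode_max_seq_len = 4096
decode_max_num_seq = 2048
assert decode_max_num_seq < decode_max_seq_len   # caveat (ii)
```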

### `model_call_steps` -
Number of times the model is called.

### `block_size` -
Same as `page_size` for the KV cache.

### `additional_config` -
Used to propagate tpu_commons-specific arguments (e.g. sharding and quantization settings). Example:
> **Reviewer (@gpolovets1, Sep 25, 2025):** "This is used to propagate tpu_commons-specific arguments (e.g. sharding and quantization settings)."


```
'{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}, "quantization": { "qwix": { "rules": [{ "module_path": ".*", "weight_qtype": "float8_e4m3fn", "act_qtype": "float8_e4m3fn"}]}}}' --model_config='{"model":"Qwen/Qwen3-32B"}'
```

### `model_config` -
Model configuration, e.g. `--model_config='{"model":"Qwen/Qwen3-32B"}'`

### `new_model_design` -
True if microbenchmarking is done for newer models like Llama4 and DeepSeek v3.

> **Reviewer (Collaborator):** Are we using this? It seems like you are setting this via env variables in your code.

### `trace_dir` -

Local directory where traces are stored. The default value is `/tmp/tpu_commons_traces`.

## Example commands to run the microbenchmark
> **Reviewer (Collaborator):** Thanks for adding these!


### Decode

```
python examples/microbenchamarking/microbenchmark_app.py --additional_config='{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}}' --model_config='{"model":"Qwen/Qwen3-32B"}' --phase='decode' --max_seq_len=4096 --max_num_seq=2048 --model_hf_config="examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json"

```

### Prefill

```
python examples/microbenchamarking/microbenchmark_app.py --additional_config='{"sharding": {"sharding_strategy": {"tensor_parallelism": 8, "data_parallelism": 1}}}' --model_config='{"model":"Qwen/Qwen3-32B"}' --phase='prefill' --max_seq_len=1024 --max_num_seq=2 --max_prefill_len=512 --model_hf_config="examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json"

```

Notice that `max_num_seq = 2` in the prefill example, since a maximum of 2 sequences with `max_prefill_len = 512` fit into `max_seq_len = 1024`.
62 changes: 62 additions & 0 deletions examples/microbenchamarking/hf_configs/deepseek_v3_hf_config.json
@@ -0,0 +1,62 @@
{
"architectures": ["DeepseekV3ForCausalLM"],
"attention_bias": "False",
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
"AutoModel": "modeling_deepseek.DeepseekV3Model",
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
},
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": "True",
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128]
},
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": "False",
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.33.1",
"use_cache": "True",
"v_head_dim": 128,
"vocab_size": 129280
}
28 changes: 28 additions & 0 deletions examples/microbenchamarking/hf_configs/qwen3_32b_hf_config.json
@@ -0,0 +1,28 @@
{
"architectures": ["Qwen3ForCausalLM"],
"attention_bias": "False",
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 25600,
"max_position_embeddings": 40960,
"max_window_layers": 64,
"model_type": "qwen3",
"num_attention_heads": 64,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": "None",
"rope_theta": 1000000,
"sliding_window": "None",
"tie_word_embeddings": "False",
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": "True",
"use_sliding_window": "False",
"vocab_size": 151936
}
209 changes: 209 additions & 0 deletions examples/microbenchamarking/microbenchmark_app.py
@@ -0,0 +1,209 @@
# pytype: disable=import-error
# pytype: disable=module-attr
# pytype: disable=attribute-error
# pytype: disable=wrong-arg-types
import json
import os
import time
from typing import Any, Sequence

import jax
from absl import app, flags
from flax import nnx
from microbenchmark_input_utils import InputArgs, InputCreator
from microbenchmark_utils import Sampler, init_mesh

from tpu_commons.logger import init_logger
from tpu_commons.mock.vllm_config_utils import (CacheConfig, ModelConfig,
VllmConfig)
from tpu_commons.models.jax.model_loader import get_model

logger = init_logger(__name__)

_MAX_SEQ_LEN = flags.DEFINE_integer("max_seq_len", 4096,
"Maximum sequence length.")

_PHASE = flags.DEFINE_string("phase", "decode", "Phase to benchmark.")

_DECODE_OFFSET_FROM_PREFILL = flags.DEFINE_integer(
"decode_offset_from_prefill",
0,
> **Reviewer (Collaborator):** According to the readme, does offset of 0 correspond to the last prefill token?
>
> **Reviewer (Collaborator):** At least set it to 1.

"Offset from the prefill length to start decoding.",
)

_MODEL_NAME = flags.DEFINE_string(
"model_name", "qwen3-32b",
"Model name to benchmark. Supported models: qwen3-32b, deepseek_v3")

_MODEL_HF_CONFIG = flags.DEFINE_string(
"model_hf_config",
"",
"Model HF config in json format.",
)

# This has to be overridden because the calculation is not yet correct on the microbenchmark side.
# TODO(vijaya): Fix the calculation and remove this flag as an override.
_KV_NUM_BLOCK_OVERRIDE = flags.DEFINE_integer(
> **Reviewer (Collaborator):** Can you share some guidelines in the readme for how you are calculating this?
>
> **Reviewer (Collaborator):** So far running on v7x-2 (depending on how much TPU memory you have right now) to determine the number of total blocks.

"num_block_override",
2048,
"Override number of blocks.",
)

_MAX_PREFILL_LEN = flags.DEFINE_integer("max_prefill_len", 512,
"Maximum prefill length.")

# for prefill, max_num_seq = max_seq_len // max_prefill_len
# for decode max_num_seq = max_seq_len // 1
_MAX_NUM_SEQ = flags.DEFINE_integer(
"max_num_seq",
2048,
"maximum number of sequences to be benchmarked.",
)

_MODEL_CALL_STEPS = flags.DEFINE_integer("model_call_steps", 5,
"Number of model call steps.")

_BLOCK_SIZE = flags.DEFINE_integer("block_size", 128, "Block size.")
_SAMPLER_TYPE = flags.DEFINE_string("sampler_type", "fixed", "Sampler type.")
_SAMPLER_STD = flags.DEFINE_float("sampler_std", 1.0,
"Sampler standard deviation.")
_ADDITIONAL_CONFIG = flags.DEFINE_string(
"additional_config",
"",
"Additional configuration for the model.",
)
_MODEL_CONFIG = flags.DEFINE_string(
"model_config",
"",
"Model configuration for the model.",
)

NEW_MODEL_DESIGN = flags.DEFINE_string(
> **Reviewer (Collaborator):** Are we using this?

"NEW_MODEL_DESIGN",
"True",
"Model design to use. If True, uses the new model design which is used for newer models like DeepseekV3 and Llama4",
)

_TRACE_DIR = flags.DEFINE_string(
"trace_dir",
"/tmp/tpu_commons_traces",
"Directory to save the trace files.",
)


def get_hf_config_attribute_map(model_hf_config: str):
> **Reviewer (Collaborator):** nit: maybe creating this function is overkill? It's just one extra line of code to convert the string to a json =]

with open(model_hf_config, 'r') as file:
# Load the JSON data from the file
data = json.load(file)
return data


def validate_command_line_args():
if _PHASE.value not in ["prefill", "decode"]:
raise ValueError(
f"Phase {_PHASE.value} not supported. Choose either 'prefill' or 'decode'."
)
if _MAX_SEQ_LEN.value % _BLOCK_SIZE.value != 0:
raise ValueError(
f"Max sequence length {_MAX_SEQ_LEN.value} must be divisible by block size {_BLOCK_SIZE.value}."
)

if _PHASE.value == "prefill":
if _MAX_SEQ_LEN.value % _MAX_PREFILL_LEN.value != 0:
raise ValueError(
f"Max sequence length {_MAX_SEQ_LEN.value} must be divisible by max prefill length {_MAX_PREFILL_LEN.value}."
)
if _MAX_SEQ_LEN.value // _MAX_PREFILL_LEN.value != _MAX_NUM_SEQ.value:
raise ValueError(
f"Max number of sequences {_MAX_NUM_SEQ.value} must be equal to max sequence length {_MAX_SEQ_LEN.value} divided by max prefill length {_MAX_PREFILL_LEN.value}."
)


class Benchmarker:
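"""Benchmarks a single model call for a given phase (prefill or decode):
builds the inputs with InputCreator, runs the model once under the JAX
profiler, and logs the elapsed time."""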

> **Reviewer (Collaborator):** Can you add a quick docstring on what this class does?

def __init__(self, vllm_config: VllmConfig, model: Any, mesh: Any,
sampler: Sampler, rng: nnx.Rngs, model_hf_config, state,
trace_directory):
"""
Takes a VllmConfig, model function, mesh, sampler, rng, model_hf_config, state and trace_directory,
and benchmarks the model for the given phase after creating the input using the InputCreator class.
"""
self.vllm_config = vllm_config
self.model = model
self.mesh = mesh
self.sampler = sampler
> **Reviewer (Collaborator):** Not used?

self.rng = rng
self.model_hf_config = model_hf_config
self.state = state
self.trace_directory = trace_directory

def benchmark(self, phase: str):
input_args = InputArgs(
block_size=_BLOCK_SIZE.value,
max_num_seq=_MAX_NUM_SEQ.value,
min_prefill_len=1,
max_prefill_len=_MAX_PREFILL_LEN.value,
max_model_len=_MAX_SEQ_LEN.value,
decode_offset_from_prefill=_DECODE_OFFSET_FROM_PREFILL.value,
sampler=self.sampler,
model_hf_config=self.model_hf_config,
phase=phase,
num_blocks_override=_KV_NUM_BLOCK_OVERRIDE.value)

input_creator = InputCreator(input_args=input_args,
sharding=None,
mesh=self.mesh,
rng=self.rng)
model_input = input_creator.create_input(phase=phase)

jax.profiler.start_trace(self.trace_directory)
start_time = time.time()
kv_caches, act = self.model(
self.state,
model_input.kv_caches,
model_input.input_ids,
model_input.attention_metadata,
)

act.block_until_ready()
end_time = time.time()
jax.profiler.stop_trace()
logger.info(
f"Time taken for model call in phase {phase}: {end_time - start_time} seconds. and profile trace is saved in {self.trace_directory}"
> **Reviewer (Collaborator):** "seconds. and" -> "seconds\nProfile"

)


def main(argv: Sequence[str]):
sampler = Sampler(type=_SAMPLER_TYPE.value, std=_SAMPLER_STD.value)
rng = nnx.Rngs(params=0)
vllm_config = VllmConfig(
additional_config=json.loads(_ADDITIONAL_CONFIG.value),
model_config=ModelConfig(**json.loads(_MODEL_CONFIG.value)),
cache_config=CacheConfig(block_size=_BLOCK_SIZE.value),
)

validate_command_line_args()

vllm_config.model_config.hf_config.attribute_map = get_hf_config_attribute_map(
_MODEL_HF_CONFIG.value)

mesh = init_mesh(vllm_config, jax.devices())
model_fn, compute_logits_fn, get_multimodal_embeddings_fn, get_input_embeddings_fn, state = get_model(
vllm_config,
rng.params(),
mesh,
)

benchmarker = Benchmarker(vllm_config, model_fn, mesh, sampler, rng,
vllm_config.model_config.hf_config, state,
_TRACE_DIR.value)
for _ in range(_MODEL_CALL_STEPS.value):
benchmarker.benchmark(_PHASE.value)


if __name__ == "__main__":
# uncomment below line to enable new model design
# os.environ['NEW_MODEL_DESIGN'] = 'True'
> **Reviewer (Collaborator):** Remove?
>
> **Author (@mailvijayasingh):** Added a comment to uncomment only when using the new model design.

os.environ['JAX_RANDOM_WEIGHTS'] = 'True'
os.environ['TPU_BACKEND_TYPE'] = 'JAX'
app.run(main)