
Conversation

mailvijayasingh (Collaborator)

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.


@mailvijayasingh mailvijayasingh force-pushed the wip_check_vij branch 8 times, most recently from 130786a to cca99e4 Compare September 19, 2025 22:29
Collaborator:

Can you please add a description of all the args you can pass to the MB script and what they are?

Collaborator (Author):

you mean in README? Sure!

NEW_MODEL_DESIGN = flags.DEFINE_string(
"NEW_MODEL_DESIGN",
"True",
"Model design to use. If True, uses the new model design.",
Collaborator:

Maybe note that this is only needed for a few models right now (L4, DSv3)

Collaborator (Author):

Sure!
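
For reference on the flag quoted above, a minimal self-contained sketch of how such an absl flag is typically defined and read. Only the flag definition comes from the quoted snippet; the `main` wrapper and the boolean conversion are assumptions for illustration:

```python
from absl import app, flags

NEW_MODEL_DESIGN = flags.DEFINE_string(
    "NEW_MODEL_DESIGN",
    "True",
    "Model design to use. If True, uses the new model design.",
)


def main(_):
    # DEFINE_string returns a FlagHolder; .value is the parsed string,
    # so the boolean intent has to be recovered by hand.
    use_new_design = NEW_MODEL_DESIGN.value.lower() == "true"
    print(f"new_model_design={use_new_design}")


if __name__ == "__main__":
    app.run(main)
```

It would then be passed on the command line as something like `--NEW_MODEL_DESIGN=False` (the script name and exact invocation are not shown in this PR).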

if model_name == "qwen3-32b":
return qwen3_32b_hf_config
elif model_name == "deepseek_v3":
return deepseek_v3_hf_config
Collaborator:

I thought the README said only Qwen3 is supported?

Collaborator (Author):

yeah, I just checked for DeepSeek, and it just works out of the box, so will keep it for both.

from tpu_commons.utils import make_optimized_mesh

logger = init_logger(__name__)
power_of_two = np.pow(2, np.arange(18)) # up to 128k seq lens
Collaborator:

I think this code is also called somewhere else -- maybe move to common utils file?

Collaborator (Author):

Will do!
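
If that refactor happens, a minimal sketch of what the shared constant could look like (the module placement and the `next_power_of_two` helper are hypothetical; the constant itself mirrors the quoted line, and 2**17 = 131072 is where the "128k seq lens" comment comes from):

```python
# tpu_commons/utils.py (placement assumed)
import numpy as np

# Powers of two up to 2**17 = 131072, i.e. up to 128k sequence lengths.
POWERS_OF_TWO = 2 ** np.arange(18)


def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (assumes n <= 2**17), e.g. for padding sequence lengths."""
    return int(POWERS_OF_TWO[np.searchsorted(POWERS_OF_TWO, n)])
```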


def init_mesh(vllm_config, devices) -> None:
try:
# TODO: Update override steps.
Collaborator:

Not going to be addressed here?

num_kv_heads: int = 32
head_dim: int = 128
vocab_size: int = 32000
model: str = "llama3"
Collaborator:

I see you're updating the defaults here -- did you test this won't break anything?

Collaborator (Author):

I have verified it, but we can check whether having this is necessary.

Collaborator (Author):

Actually, a model config in this format is needed only for the new model design and is not needed for the consolidated model code. We can remove or modify it once everything is consolidated.

head_size = model_config.get_head_size()
num_kv_heads = model_config.get_total_num_kv_heads()
hf_config = vllm_config.model_config.hf_config
head_size = hf_config.head_dim
Collaborator:

Sorry, I realized I said this was safe offline, but can you double check this won't break anything for Llama3 / Qwen3 / DeepSeek?

Collaborator (Author):

Sure, I'll check offline_inference.py. Will that be a good check?

Collaborator (Author):

Actually, I will remove it for now; GitHub does not need it, only g3 does. I can handle this in the next PR.

@kyuyeunk (Collaborator) left a comment:

I've tried this out, but I think a few things can be improved from a user-journey perspective:

  1. Running the benchmark doesn't really return anything on the console. It took me a while to realize that it was saving profile data in the directory specified by the trace_dir argument.
  2. Printing even just a simple "the run took xyz seconds to complete" message to the console would be super useful. Digging through the trace file every time I run the microbenchmark isn't ideal.
  3. For a better workflow, I wish it could automatically create an xprof link. I believe that's not possible due to technical limitations, but I can imagine the following features (see the sketch after this list):
    1. Add an argument specifying a GCS bucket directory (e.g., --gs_dir=gs://...)
    2. The microbenchmark automatically uploads trace files to the user-specified GCS bucket and remembers the trace file's path in the bucket (e.g., gs://...xplane.pb)
    3. In the console, the microbenchmark prints out the command the user needs to run in a g3 workspace to create the xprof link (e.g., To create an xprof link, run the following command in g3: blaze run -c opt //cloud/tpu/tools/c2xprof:main -- --alsologtostderr --gcs_path=gs://...xplane.pb)
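
A rough sketch of what items 2 and 3 could look like in the runner. The flag names, paths, and the `run_and_report` helper are all hypothetical; `gsutil cp` and the quoted `c2xprof` command come from the comment above:

```python
import subprocess
import time


def run_and_report(run_model_fn, trace_dir: str, gs_dir: str | None = None):
    """Hypothetical wrapper: time the run, optionally upload the trace, print next steps."""
    start = time.time()
    run_model_fn()
    elapsed = time.time() - start
    print(f"The run took {elapsed:.2f} seconds to complete; traces saved to {trace_dir}")

    if gs_dir:
        # Upload the local trace directory to the user-specified GCS bucket.
        subprocess.run(["gsutil", "cp", "-r", trace_dir, gs_dir], check=True)
        print(
            "To create an xprof link, run the following command in g3:\n"
            f"blaze run -c opt //cloud/tpu/tools/c2xprof:main -- "
            f"--alsologtostderr --gcs_path={gs_dir}"
        )
```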

A few additional minor comments:

  1. Left a comment in the file, but creating a config for every model we want to support isn't scalable. Please consider reusing logic that fetches config.json from Hugging Face to create a model config (see the sketch after this list).
  2. No support for vLLM models?
  3. This branch didn't work out of the box and I had to make some fixes in microbenchmark_input_utils.py.
  4. Is there a number you can share showing that the number returned here closely matches the number from the e2e run?
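
For item 1, a minimal sketch of reusing the Hugging Face config instead of hand-written per-model configs. This uses the standard `transformers` API; whether it fits the microbenchmark's offline constraint is the open question in a later thread:

```python
from transformers import AutoConfig


def load_hf_config(model_name_or_path: str):
    # Works with a hub id (e.g. "Qwen/Qwen3-32B") or a local directory that
    # already contains config.json, so no weights are downloaded either way.
    return AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
```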

Signed-off-by: Vijaya singh <singhvijaya@google.com>
@@ -0,0 +1,86 @@
# MICROBENCHAMRKING IS EXPERIMENTAL AND NOT SUPPORTED FOR ALL MODELS AND FLEXIBLE WORKLOADS

The Goal of microbenchmarking is to strip the model call from VLLM Dependencies (Scheduler and KV Cache Manager) for efficient debugging and performance optimization of just model call.
Collaborator:

nit: VLLM -> vLLM

@@ -0,0 +1,86 @@
# MICROBENCHAMRKING IS EXPERIMENTAL AND NOT SUPPORTED FOR ALL MODELS AND FLEXIBLE WORKLOADS

The Goal of microbenchmarking is to strip the model call from VLLM Dependencies (Scheduler and KV Cache Manager) for efficient debugging and performance optimization of just model call.
Collaborator:

A rewording suggestion:
"The goal of microbenchmarking is to strip out the vLLM server layer and focus on just profiling the model calls."


The Goal of microbenchmarking is to strip the model call from VLLM Dependencies (Scheduler and KV Cache Manager) for efficient debugging and performance optimization of just model call.

The current version is ** working on pinned main **
Collaborator:

"The current implementation runs on the following pinned version of the main branch:"

@mailvijayasingh (Collaborator, Author), Sep 25, 2025:

Actually, it is not pinned anymore. I will just mention it so we can backtrack, but as long as the model call API remains unchanged, it should work.


> ⚠️ The microbenchmarking code **does not support all models and features and is currently used for debugging and optimizing static workloads
**Only tested model for microbenchmarking is QWEN3-32B**
Collaborator:

"The only model validated for microbenchmarking is Qwen3-32B."

Don't we support DeepSeek as well?

Collaborator (Author):

I wanted to keep this and remove it once we verify the runs together.

## Params needed by microbenchmarking code

### `max_seq_len` -
max model len this is length of the model including number of prefill and decode tokens
Collaborator:

"max model len this is the maximum supported length of each request. Typically this equals the maximum number of prefill + decode tokens across all requests."


### `phase` -

phase of the model, supported modes are prefill and decode
Collaborator:

"Inference phase - supported phases are 'prefill' and 'decode'."

phase of the model, supported modes are prefill and decode

### `decode_offset_from_prefill` -
used in decode primarily, if the value is 1, it means 1st token after prefill
@gpolovets1 (Collaborator), Sep 25, 2025:

"This offset indicates the decode step index to profile. E.g. setting a value of 10 corresponds to profiling the 10th decode step."

used in decode primarily, if the value is 1, it means 1st token after prefill

### `model_hf_config` -
path to json file where HFConfig is saved. We need this because we dont want to download from huggingface.
Collaborator:

"We need this to avoid having to download the model from huggingface everytime."

Collaborator:

Actually, if we are just downloading the config, would it be a problem to download from HF?

max length of prefill sequence

### `max_num_sequence` -
is the maximum number of sequence supported by model.
Collaborator:

sequence -> sequences


i) In Prefill phase - `max_num_sequence` = max_seq_len // max_prefill_len

ii) In Decode phase - `max_num_sequence` < `max_seq_len`
Collaborator:

Why does the maximum sequence length influence the maximum number of allowed sequences (and vice versa)?

Collaborator:

Would it help if we added --max-num-batched-tokens from vLLM? This variable corresponds to the maximum total tokens that we can process in a single batch.
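
For what it's worth, a worked example of the README's stated relationship (the numbers are illustrative only, and the decode-phase comment is one reading of why the constraint holds):

```python
max_seq_len = 2048      # total budget of prefill + decode tokens
max_prefill_len = 512   # fixed prefill length per sequence in the prefill phase

# Prefill phase: how many full prefill sequences fit in the token budget.
max_num_sequence = max_seq_len // max_prefill_len  # 2048 // 512 = 4

# Decode phase: each sequence contributes one token per step, so the README
# only requires max_num_sequence < max_seq_len.
```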

or same as `page_size` for KV Cache

### `additional_config` -
example of additional config
@gpolovets1 (Collaborator), Sep 25, 2025:

"This is used to propagate tpu_commons-specific arguments (e.g. sharding and quantization settings)."

### `model_config` -
--model_config='{"model":"Qwen/Qwen3-32B"}'

### `new_model_design` -
Collaborator:

Are we using this? It seems like you are setting this via environment variables in your code.


local location where traces are stored. Default value is `/tmp/tpu_commons_traces`

## Example command to run Microbenchmark
Collaborator:

Thanks for adding these!


_DECODE_OFFSET_FROM_PREFILL = flags.DEFINE_integer(
"decode_offset_from_prefill",
0,
Collaborator:

According to the README, does an offset of 0 correspond to the last prefill token?

Collaborator:

At least set it to 1.


# this has to be overriden as the calculation is not very correct yet on microbenchmark side.
#TODO: @(vijaya) Fix the calculation and remove this flag as an override.
_KV_NUM_BLOCK_OVERRIDE = flags.DEFINE_integer(
Collaborator:

Can you share some guidelines in the readme for how you are calculating this?

Collaborator:

So far I have been running on v7x-2 (it depends on how much TPU memory you have right now) to determine the total number of blocks.
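
For the README guideline being asked for, a hedged sketch of one common way to estimate the block count from available HBM. The helper and all the example numbers are assumptions for illustration, not what this PR implements; the arithmetic mirrors the usual KV-cache sizing (2 for K and V, per layer, per block of tokens):

```python
def estimate_num_kv_blocks(free_hbm_bytes: int,
                           num_layers: int,
                           num_kv_heads: int,
                           head_dim: int,
                           block_size: int,
                           dtype_bytes: int = 2) -> int:
    # Each block stores K and V for `block_size` tokens in every layer.
    bytes_per_block = 2 * num_layers * num_kv_heads * head_dim * block_size * dtype_bytes
    return free_hbm_bytes // bytes_per_block


# Example with assumed shapes: 32 GiB free HBM, 64 layers, 8 KV heads,
# head_dim 128, block_size 16, bf16 KV cache -> 8192 blocks.
print(estimate_num_kv_blocks(32 * 1024**3, 64, 8, 128, 16, 2))
```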

"Model configuration for the model.",
)

NEW_MODEL_DESIGN = flags.DEFINE_string(
Collaborator:

Are we using this?

)


def get_hf_config_attribute_map(model_hf_config: str):
Collaborator:

nit: maybe creating this function is overkill? It's just one extra line of code to convert the string to a json =]

end_time = time.time()
jax.profiler.stop_trace()
logger.info(
f"Time taken for model call in phase {phase}: {end_time - start_time} seconds. and profile trace is saved in {self.trace_directory}"
Collaborator:

"seconds. and" -> "seconds\nProfile"


axis_names = ("data", "model")
mesh_shape = (dp, tp)
# for deepseekv3
Collaborator:

Since deepseek is a new model, I think this is not needed/will be skipped?

type: str
std: float = None

def generate_samples(self, shape: Tuple[int], fill_val: Any) -> np.array:
Collaborator:

Is this used anywhere?

self.vllm_config = vllm_config
self.model = model
self.mesh = mesh
self.sampler = sampler
Collaborator:

Not used?

])

def _create_mock_block_table(self, random_permute: bool = False):
block_table = np.arange(self.input_args.num_blocks_override,
Collaborator:

Are we assuming that the KV cache is always full? My understanding is that the block table tells you which blocks to use in a batch but it can technically be less than the full KV cache buffer.

Collaborator:

Unless the block_table is padded to represent a non-full table.

Collaborator:

The implementation is different from tpu_commons.

@dataclass
class InputArgs:
max_num_seq: int
max_prefill_len: int
Collaborator:

This looks like a duplicate.

In prefill phase, all sequences are of length max_prefill_len"""
if phase == 'decode':
# this means all sequences are in decode phase
return np.random.randint(1, vocab_size, size=max_num_sequence)
Collaborator:

nit: this is probably between 0 and vocab_size-1
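
A quick illustration of the nit, since `np.random.randint` samples from a half-open interval:

```python
import numpy as np

vocab_size = 32000
# randint(low, high) samples integers in [low, high), i.e. high is exclusive:
#   randint(1, vocab_size) -> values 1 .. vocab_size - 1 (token id 0 never appears)
#   randint(0, vocab_size) -> values 0 .. vocab_size - 1 (covers the full vocabulary)
tokens = np.random.randint(0, vocab_size, size=8)
```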

phase='decode') -> jnp.ndarray:
"""
Creates sequence lengths based on phase.
In decode phase, all sequences are of length offset_from_prefill (usually 1) + max_prefill_len -1
Collaborator:

Isn't offset_from_prefill configurable now?

elif phase == 'prefill':
# this means all sequences are in prefill phase with same prefill length which is max_prefill_len
return np.random.randint(max_prefill_len,
max_prefill_len + 1,
Collaborator:

Why is it max_prefill_len+1?
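
For context on the question above: `np.random.randint` uses an exclusive upper bound, so this particular call is deterministic and equivalent to filling the array directly. A small check (the concrete numbers are illustrative only):

```python
import numpy as np

max_prefill_len, max_num_sequence = 512, 4
# The interval [max_prefill_len, max_prefill_len + 1) contains exactly one value,
# so the call always returns max_prefill_len.
a = np.random.randint(max_prefill_len, max_prefill_len + 1, size=max_num_sequence)
b = np.full((max_num_sequence,), max_prefill_len, dtype=a.dtype)
assert (a == b).all()
```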

Collaborator:

Is this deprecated?

"""
if phase == 'prefill':
query_start_offsets = np.zeros(max_num_seq + 1, dtype=np.int32)
query_start_offsets[0] = 0
Collaborator:

This is already true from line 249

offset_from_prefill: int = 1,
phase='decode') -> jnp.ndarray:
"""
Creates input positions based on phase.
Collaborator:

Could you explain what input positions is in this context?

Collaborator:

How about mentioning it is the input position index?

"""
if phase == 'decode':
# in decode phase, all sequences are of length 1 and the position is immediately after the prefill length
return np.full((max_num_seqs, ),
Collaborator:

Shouldn't this be the sum(seq_lens) (i.e. total scheduled tokens)?


def create_num_blocks(max_model_len: int,
block_size: int,
num_block_override=0) -> int:
Collaborator:

Maybe you should not set a default for num_block_override if right now the implementation requires setting this? You can add a TODO instead to make it dynamic in the future.


if phase == 'decode':
# this means all sequences are in decode phase
return np.array([max_num_sequence, max_num_sequence, max_num_sequence],
Collaborator:

Why are there max_num_sequence prefill sequences during decode?

dtype=np.int32)
elif phase == 'prefill':
# this means all sequences are in prefill phase
return np.array([0, 0, max_num_sequence], dtype=np.int32)
Collaborator:

According to the docstring, this is setting both prefill & decode to 0 but total number of sequences to max_num_sequence

"""
Creates request distribution based on phase.
request_distribution is of shape (3,) where
request_distribution[0] = number of sequences in decode phase
Collaborator:

Can you update the documentation here? We can also write that empirically this has been confirmed but a TODO is to review the tpu_commons code one more time and confirm.
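
For the documentation update, a small sketch of what the two branches quoted above produce. The values follow the quoted code; the reading of the three slots comes from the docstring and the comment above, pending the tpu_commons review mentioned here:

```python
import numpy as np

max_num_sequence = 4
# Decode phase (as quoted): every slot is filled with max_num_sequence.
decode_distribution = np.array([max_num_sequence] * 3, dtype=np.int32)    # [4, 4, 4]
# Prefill phase (as quoted): decode and prefill counts are 0, only the total is set.
prefill_distribution = np.array([0, 0, max_num_sequence], dtype=np.int32)  # [0, 0, 4]
```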
