prefill decode microbenchmark for QWen3 #699
base: main
Conversation
Force-pushed from 130786a to cca99e4
Can you please add a description of all the args you can pass to the MB script and what they are?
You mean in the README? Sure!
NEW_MODEL_DESIGN = flags.DEFINE_string(
    "NEW_MODEL_DESIGN",
    "True",
    "Model design to use. If True, uses the new model design.",
Maybe note that this is only needed for a few models right now (L4, DSv3)
Sure!
if model_name == "qwen3-32b":
    return qwen3_32b_hf_config
elif model_name == "deepseek_v3":
    return deepseek_v3_hf_config
I thought the README said only Qwen3 is supported?
Yeah, I just checked DeepSeek, and it works out of the box, so I'll keep it for both.
from tpu_commons.utils import make_optimized_mesh

logger = init_logger(__name__)
power_of_two = np.pow(2, np.arange(18))  # up to 128k seq lens
I think this code is also called somewhere else -- maybe move to common utils file?
Will do!
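For reference, a minimal sketch of what a shared helper in a common utils module could look like; the module path, function name, and the next-power-of-two padding policy are my assumptions, not something this PR defines:

```python
import numpy as np

# Pre-computed powers of two, up to 128k sequence lengths (2**17 = 131072).
POWERS_OF_TWO = np.power(2, np.arange(18))

def pad_to_power_of_two(seq_len: int) -> int:
    """Smallest power of two >= seq_len, capped at 128k."""
    idx = int(np.searchsorted(POWERS_OF_TWO, seq_len))
    return int(POWERS_OF_TWO[min(idx, len(POWERS_OF_TWO) - 1)])
```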
def init_mesh(vllm_config, devices) -> None:
    try:
        # TODO: Update override steps.
Not going to be addressed here?
num_kv_heads: int = 32
head_dim: int = 128
vocab_size: int = 32000
model: str = "llama3"
I see you're updating the defaults here -- did you test this won't break anything?
I have verified it, but we can check whether having this is necessary.
Actually, the model config in this format is needed only for the new model design; it is not needed for the consolidated model code. We can remove or modify it once everything is consolidated.
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_total_num_kv_heads()
hf_config = vllm_config.model_config.hf_config
head_size = hf_config.head_dim
Sorry, I realized I said this was safe offline, but can you double check this won't break anything for Llama3 / Qwen3 / DeepSeek?
Sure, I'll check offline_inference.py. Will that be a good check?
Actually, I will remove it for now; GitHub does not need it, only g3 does. I can handle this in the next PR.
Force-pushed from cca99e4 to 52841ee
I've tried this out, but I think a few things can be improved from a user-journey perspective:

- Running the benchmark doesn't really return anything on the console. It took me a while to realize that it was saving profile data in the directory specified by the `trace_dir` argument. Printing even just a simple "the run took xyz seconds to complete" to the console would be super useful; digging through the trace file every time I run the microbenchmark isn't ideal.
- For a better workflow, I wish it could automatically create an xprof link. I believe that's not possible due to a technical limitation, but I can imagine the following features (see the sketch after this comment):
  - Add an argument specifying a GCS bucket directory (e.g., `--gs_dir=gs://...`).
  - The microbenchmark automatically uploads trace files to the user-specified GCS bucket and remembers the trace file's path in the bucket (e.g., `gs://...xplane.pb`).
  - In the console, the microbenchmark prints the command the user needs to run in a g3 workspace to create the xprof link, e.g.: "To create an xprof link, run the following command in g3: blaze run -c opt //cloud/tpu/tools/c2xprof:main -- --alsologtostderr --gcs_path=gs://...xplane.pb"

A few additional minor comments:

- Left a comment in the file, but creating a config for every model we want to support isn't scalable. Please consider reusing the logic that fetches config.json from Hugging Face to create a model config.
- No support for vLLM models?
- This branch didn't work out-of-the-box and I had to make some fixes in `microbenchmark_input_utils.py`.
- Is there a number you can share showing that the number returned here closely matches the number from the e2e run?
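A rough sketch of what that console + GCS workflow could look like; the `--gs_dir` name comes from the comment above, and the timing and upload code is purely illustrative, not part of this PR:

```python
import subprocess
import time

def run_and_report(run_model_fn, trace_dir: str, gs_dir: str = "") -> None:
    """Illustrative wrapper: time the run, optionally upload traces, print next steps."""
    start = time.perf_counter()
    run_model_fn()
    elapsed = time.perf_counter() - start
    print(f"The run took {elapsed:.2f} seconds to complete. Traces saved in {trace_dir}")

    if gs_dir:  # e.g. --gs_dir=gs://...
        # Upload the local trace directory to the user-specified GCS bucket.
        subprocess.run(["gsutil", "cp", "-r", trace_dir, gs_dir], check=True)
        print("To create an xprof link, run the following command in g3:")
        print("  blaze run -c opt //cloud/tpu/tools/c2xprof:main -- "
              f"--alsologtostderr --gcs_path={gs_dir}/<...>.xplane.pb")
```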
Force-pushed from 52841ee to ae522a8
Signed-off-by: Vijaya singh <singhvijaya@google.com>
Force-pushed from ae522a8 to 7797944
@@ -0,0 +1,86 @@
# MICROBENCHAMRKING IS EXPERIMENTAL AND NOT SUPPORTED FOR ALL MODELS AND FLEXIBLE WORKLOADS

The Goal of microbenchmarking is to strip the model call from VLLM Dependencies (Scheduler and KV Cache Manager) for efficient debugging and performance optimization of just model call.
nit: VLLM -> vLLM
A rewording suggestion:
"The goal of microbenchmarking is to strip out the vLLM server layer and focus on just profiling the model calls."
The current version is ** working on pinned main **
"The current implementation runs on the following pinned version of the main branch:"
Actually, it's not pinned anymore. I will just mention the commit so we can backtrack, but as long as the model call API remains unchanged, it should work.
> ⚠️ The microbenchmarking code **does not support all models and features and is currently used for debugging and optimizing static workloads
**Only tested model for microbenchmarking is QWEN3-32B**
"The only model validated for microbenchmarking is Qwen3-32B."
Don't we support DeepSeek as well?
I wanted to keep this and remove it once we verify the runs together.
## Params needed by microbenchmarking code

### `max_seq_len` -
max model len this is length of the model including number of prefill and decode tokens
"max model len this is the maximum supported length of each request. Typically this equals the maximum number of prefill + decode tokens across all requests."
### `phase` -
phase of the model, supported modes are prefill and decode
"Inference phase - supported phases are 'prefill' and 'decode'."
### `decode_offset_from_prefill` -
used in decode primarily, if the value is 1, it means 1st token after prefill
"This offset indicates the decode step index to profile. E.g. setting a value of 10 corresponds to profiling the 10th decode step."
### `model_hf_config` -
path to json file where HFConfig is saved. We need this because we dont want to download from huggingface.
"We need this to avoid having to download the model from huggingface everytime."
Actually, if we are just downloading the config, would it be a problem to download from HF?
max length of prefill sequence

### `max_num_sequence` -
is the maximum number of sequence supported by model.
sequence -> sequences
i) In Prefill phase - `max_num_sequence` = max_seq_len // max_prefill_len

ii) In Decode phase - `max_num_sequence` < `max_seq_len`
Why does the maximum sequence length influence the maximum number of allowed sequences (and vice versa)?
Would it help if we added --max-num-batched-tokens from vLLM? This variable corresponds to the maximum total tokens that we can process in a single batch.
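A small worked example of the relationships i) and ii) quoted above (illustrative values only); a --max-num-batched-tokens-style knob would bound the same per-batch token budget explicitly:

```python
max_seq_len = 8192
max_prefill_len = 1024

# i) Prefill: every sequence contributes max_prefill_len tokens to the batch,
#    so the sequence count is bounded by the per-batch token budget.
max_num_sequence_prefill = max_seq_len // max_prefill_len  # 8

# ii) Decode: every sequence contributes one token per step, so the sequence
#     count only has to stay below max_seq_len.
max_num_sequence_decode = 256
assert max_num_sequence_decode < max_seq_len
```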
or same as `page_size` for KV Cache

### `additional_config` -
example of additional config
"This is used to propagate tpu_commons-specific arguments (e.g. sharding and quantization settings)."
### `model_config` -
--model_config='{"model":"Qwen/Qwen3-32B"}'

### `new_model_design` -
Are we using this? it seems like you are setting this via env variables in your code.
local location where traces are stored. Default value is `/tmp/tpu_commons_traces`

## Example command to run Microbenchmark
Thanks for adding these!
_DECODE_OFFSET_FROM_PREFILL = flags.DEFINE_integer(
    "decode_offset_from_prefill",
    0,
According to the readme, does offset of 0 correspond to the last prefill token?
At least set it to 1.
# this has to be overriden as the calculation is not very correct yet on microbenchmark side.
#TODO: @(vijaya) Fix the calculation and remove this flag as an override.
_KV_NUM_BLOCK_OVERRIDE = flags.DEFINE_integer(
Can you share some guidelines in the readme for how you are calculating this?
So far I've been running on v7x-2 (it depends on how much TPU memory you have free right now) to determine the total number of blocks.
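A back-of-the-envelope sketch of that calculation; the HBM budget and model dimensions below are illustrative assumptions, not measured values for Qwen3-32B on v7x-2:

```python
def estimate_kv_num_blocks(hbm_budget_bytes: int, num_layers: int,
                           num_kv_heads: int, head_dim: int,
                           block_size: int, dtype_bytes: int = 2) -> int:
    """Rough upper bound on the number of KV cache blocks that fit in the budget."""
    # Each block stores K and V for `block_size` tokens in every layer.
    bytes_per_block = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes
    return hbm_budget_bytes // bytes_per_block

# Illustrative only: ~20 GiB free HBM, bf16 KV cache, block_size of 16.
print(estimate_kv_num_blocks(hbm_budget_bytes=20 * 1024**3, num_layers=64,
                             num_kv_heads=8, head_dim=128, block_size=16))
```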
"Model configuration for the model.", | ||
) | ||
|
||
NEW_MODEL_DESIGN = flags.DEFINE_string( |
Are we using this?
)


def get_hf_config_attribute_map(model_hf_config: str):
nit: maybe creating this function is overkill? It's just one extra line of code to convert the string to a json =]
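For comparison, the inline version this comment is pointing at might look roughly like the following; the file path is hypothetical, and whether the flag value is a path or a raw JSON string, the conversion is a one- or two-liner:

```python
import json

# Assuming the flag value is a path to the saved HF config JSON file:
model_hf_config_path = "/tmp/qwen3_32b_hf_config.json"  # hypothetical path
with open(model_hf_config_path) as f:
    hf_config_attrs = json.load(f)

# Or, if the flag value were the JSON string itself:
# hf_config_attrs = json.loads('{"head_dim": 128}')
```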
end_time = time.time()
jax.profiler.stop_trace()
logger.info(
    f"Time taken for model call in phase {phase}: {end_time - start_time} seconds. and profile trace is saved in {self.trace_directory}"
"seconds. and" -> "seconds\nProfile"
axis_names = ("data", "model")
mesh_shape = (dp, tp)
# for deepseekv3
Since deepseek is a new model, I think this is not needed/will be skipped?
type: str
std: float = None

def generate_samples(self, shape: Tuple[int], fill_val: Any) -> np.array:
Is this used anywhere?
self.vllm_config = vllm_config
self.model = model
self.mesh = mesh
self.sampler = sampler
Not used?
])

def _create_mock_block_table(self, random_permute: bool = False):
    block_table = np.arange(self.input_args.num_blocks_override,
Are we assuming that the KV cache is always full? My understanding is that the block table tells you which blocks to use in a batch but it can technically be less than the full KV cache buffer.
Unless the block_table is padded to represent a non-full table.
The implementation is different from tpu_commons.
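To illustrate the point about non-full tables, a sketch of a block table that only references the blocks each sequence actually needs and pads the rest; the shape, padding value, and helper name are assumptions, not this PR's implementation:

```python
import numpy as np

def make_mock_block_table(num_seqs: int, seq_len: int, block_size: int,
                          max_blocks_per_seq: int, pad_id: int = 0) -> np.ndarray:
    """Block table of shape (num_seqs, max_blocks_per_seq); unused slots hold pad_id."""
    blocks_per_seq = -(-seq_len // block_size)  # ceil division
    assert blocks_per_seq <= max_blocks_per_seq
    table = np.full((num_seqs, max_blocks_per_seq), pad_id, dtype=np.int32)
    next_block = 0
    for s in range(num_seqs):
        table[s, :blocks_per_seq] = np.arange(next_block, next_block + blocks_per_seq)
        next_block += blocks_per_seq
    return table
```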
@dataclass
class InputArgs:
    max_num_seq: int
    max_prefill_len: int
This looks like a duplicate.
In prefill phase, all sequences are of length max_prefill_len"""
if phase == 'decode':
    # this means all sequences are in decode phase
    return np.random.randint(1, vocab_size, size=max_num_sequence)
nit: this is probably between 0 and vocab_size-1
phase='decode') -> jnp.ndarray:
"""
Creates sequence lengths based on phase.
In decode phase, all sequences are of length offset_from_prefill (usually 1) + max_prefill_len -1
Isn't offset_from_prefill configurable now?
elif phase == 'prefill':
    # this means all sequences are in prefill phase with same prefill length which is max_prefill_len
    return np.random.randint(max_prefill_len,
                             max_prefill_len + 1,
Why is it max_prefill_len+1?
Is this deprecated?
""" | ||
if phase == 'prefill': | ||
query_start_offsets = np.zeros(max_num_seq + 1, dtype=np.int32) | ||
query_start_offsets[0] = 0 |
This is already true from line 249
offset_from_prefill: int = 1,
phase='decode') -> jnp.ndarray:
"""
Creates input positions based on phase.
Could you explain what input positions is in this context?
How about mentioning it is the input position index?
""" | ||
if phase == 'decode': | ||
# in decode phase, all sequences are of length 1 and the position is immediately after the prefill length | ||
return np.full((max_num_seqs, ), |
Shouldn't this be the sum(seq_lens) (i.e. total scheduled tokens)?
def create_num_blocks(max_model_len: int,
                      block_size: int,
                      num_block_override=0) -> int:
Maybe you should not set a default for num_block_override if right now the implementation requires setting this? You can add a TODO instead to make it dynamic in the future.
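One possible shape for that suggestion (a sketch only; the body below is a placeholder, not the PR's actual implementation):

```python
def create_num_blocks(max_model_len: int,
                      block_size: int,
                      num_block_override: int) -> int:
    # TODO: derive the block count from max_model_len, block_size, and free HBM
    # instead of requiring callers to pass an explicit override.
    return num_block_override
```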
if phase == 'decode':
    # this means all sequences are in decode phase
    return np.array([max_num_sequence, max_num_sequence, max_num_sequence],
Why are there max_num_sequence prefill sequences during decode?
    dtype=np.int32)
elif phase == 'prefill':
    # this means all sequences are in prefill phase
    return np.array([0, 0, max_num_sequence], dtype=np.int32)
According to the docstring, this is setting both prefill & decode to 0 but total number of sequences to max_num_sequence
""" | ||
Creates request distribution based on phase. | ||
request_distribution is of shape (3,) where | ||
request_distribution[0] = number of sequences in decode phase |
Can you update the documentation here? We can also write that empirically this has been confirmed but a TODO is to review the tpu_commons code one more time and confirm.
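For reference, a sketch that restates how the two phases populate the (3,) array per the snippets above; the exact meaning of each slot still needs to be confirmed against the tpu_commons code, as noted in this thread:

```python
import numpy as np

def create_request_distribution(max_num_sequence: int, phase: str) -> np.ndarray:
    """request_distribution[0] = number of sequences in decode phase (per the
    docstring); the remaining slots mirror the snippets above and should be
    double-checked against tpu_commons."""
    if phase == "decode":
        return np.array([max_num_sequence] * 3, dtype=np.int32)
    if phase == "prefill":
        return np.array([0, 0, max_num_sequence], dtype=np.int32)
    raise ValueError(f"Unsupported phase: {phase}")
```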
Description
Start with a short description of what the PR does and how this is a change from
the past.
The rest of the description includes relevant details and context, examples:
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure: