Support loading autotuned results from json for cutlass fp4 moe backends #1310
Conversation
Summary of Changes
Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a caching mechanism for autotuned results of the cutlass fp4 moe operation. By saving optimal configurations to a JSON file and allowing them to be loaded, it aims to significantly reduce the overhead of repeated autotuning, especially during inference. This change also includes a new set of pre-tuned configurations for NVIDIA B200 GPUs and minor improvements to the autotuning benchmark script and logging utility.
Highlights
- **Caching Autotuned Results**: Implemented functionality to save and load autotuned configurations for cutlass fp4 moe operations to/from a JSON file, leveraging the `flashinfer.autotuner` module.
- **Pre-tuned Configurations for B200**: Added a new JSON file (`benchmarks/trtllm_fused_moe_NVIDIA_B200.json`) containing pre-computed optimal configurations for NVIDIA B200 GPUs, specifically for `trtllm::fused_moe::gemm1` and `gemm2` operations across various input sizes.
- **Autotuning Process Refinement**: Modified the benchmark script (`benchmarks/bench_cutlass_fused_moe.py`) to enable saving of autotuned results and adjusted the `cutlass_fused_moe` call to use a configurable `tune_max_num_tokens` parameter instead of a hardcoded value.
- **Dynamic Cache Loading**: Enhanced the autotuner to conditionally load cached results from a JSON file based on the `FLASHINFER_AUTOTUNER_LOAD_FROM_FILE` environment variable, preventing redundant autotuning during inference when pre-tuned results are available.
- **Logging Utility Improvement**: Introduced an `info_once` method in the JIT logger (`flashinfer/jit/core.py`) to prevent repetitive log messages, improving log readability during autotuning and cache loading (a minimal sketch of this pattern follows the list).
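The `info_once` helper itself is not shown in this thread; below is a minimal, hypothetical sketch of the deduplicated-logging pattern it describes (emit a given message only once), not the actual `flashinfer/jit/core.py` implementation.

```python
import logging

logging.basicConfig(level=logging.INFO)


class OnceLogger:
    """Sketch of an "info_once"-style wrapper (hypothetical, for illustration)."""

    def __init__(self, name: str):
        self._logger = logging.getLogger(name)
        self._seen: set[str] = set()

    def info_once(self, msg: str) -> None:
        # Remember emitted messages so repeated calls (e.g. one per cached
        # autotuner lookup) do not flood the log.
        if msg not in self._seen:
            self._seen.add(msg)
            self._logger.info(msg)


logger = OnceLogger("flashinfer.jit")
logger.info_once("Loading autotuned MoE configs from JSON cache")
logger.info_once("Loading autotuned MoE configs from JSON cache")  # suppressed
```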
Code Review
This pull request introduces a mechanism to cache autotuning results to a JSON file, which can then be loaded to avoid expensive tuning runs during inference. The changes include modifications to the autotuner, a benchmark script to generate the cache, and a new logger utility. The main feedback is on the implementation of loading from the JSON cache in `flashinfer/autotuner.py`. The current implementation is not robust against missing or corrupt cache files and is inefficient in its caching strategy. I've provided a suggestion to improve this. Overall, this is a valuable feature for improving performance by avoiding repeated autotuning.
flashinfer/autotuner.py
Outdated
```python
def load_from_json(key):
    with open(get_json_path(), "r") as f:
        configs = json.load(f)
    k = str((key[0], key[1], key[3]))
    if k in configs:
        return True, configs[k][0], configs[k][1], None
    return False, 0, -1, None
```
This implementation of `load_from_json` has two potential issues:

1. **Lack of error handling**: It doesn't handle `FileNotFoundError` if the cache file doesn't exist, or `json.JSONDecodeError` if the file is corrupt. This could cause the program to crash.
2. **Inefficient caching**: The `@lru_cache` is applied to `load_from_json`, which takes `key` as an argument. This means the JSON file will be re-read and re-parsed for every unique `key`, which is inefficient since the file content is the same for all keys within a run.

A better approach is to cache the result of reading the file itself. You can introduce a helper function to read and parse the JSON file and apply `lru_cache` to it. This ensures the file is read only once.
```python
from functools import lru_cache
import json


@lru_cache(maxsize=1)
def _read_autotune_cache_file(path):
    """Helper to read and parse the autotune cache file, cached."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def load_from_json(key):
    json_path = get_json_path()
    configs = _read_autotune_cache_file(json_path)
    if configs is None:
        return False, 0, -1, None
    k = str((key[0], key[1], key[3]))
    if k in configs:
        return True, configs[k][0], configs[k][1], None
    return False, 0, -1, None
```
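With this approach, `_read_autotune_cache_file` is keyed only on the path and uses `lru_cache(maxsize=1)`, so the JSON file is parsed at most once per process for a given path; every subsequent `load_from_json` call reduces to a dictionary lookup, and a missing or corrupt file degrades gracefully to the `(False, 0, -1, None)` fallback instead of raising.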
Great work, thanks for bringing cutlass kernel tuning to flashinfer.

My main concern is how we store the best configs. Currently I believe storing them as Python objects under `flashinfer.tuning_configs` (I'm flexible with the naming) is easier for packaging compared to JSON:

- If `tuning_configs` is a Python module, we just need to add the `flashinfer.tuning_configs` module to `packages` (line 40 in 43e08e9: `packages = [`).
- If they are JSON files, we also have to update the package dir (line 61 in 43e08e9: `[tool.setuptools.package-dir]`, and line 67 in 43e08e9: `[tool.setuptools.package-data]`).

So in general I feel like using Python to store them is the most convenient solution. If the tuning configuration is large, we can consider hosting it on an artifactory (left as future work). A hypothetical sketch of such a module layout follows.
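To make that concrete, here is a hypothetical sketch of what a `flashinfer.tuning_configs` module could look like. The module name comes from the comment above; the key/value layout (stringified shape key mapped to a config/tactic pair) is an assumption modeled on the JSON entries handled by `load_from_json`, and the values are placeholders, not real tuned results.

```python
# Hypothetical module: flashinfer/tuning_configs/trtllm_fused_moe_nvidia_b200.py
# Shipping tuned results as plain Python data means packaging only needs the
# module added to `packages` in pyproject.toml, with no package-data entries.

# Placeholder entries for illustration only -- not real tuned results.
BEST_CONFIGS: dict[str, tuple[int, int]] = {
    "(4096, 7168, 2048)": (0, 3),
    "(1024, 7168, 2048)": (1, 5),
}


def lookup(key: str) -> tuple[bool, int, int]:
    """Return (found, config, tactic) for a stringified shape key."""
    if key in BEST_CONFIGS:
        config, tactic = BEST_CONFIGS[key]
        return True, config, tactic
    return False, 0, -1
```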
Is it possible that the autotune pass is called at the profiling stage in the framework, and also …

@yzh119 I think all the comments are addressed. PTAL.

if we have such a stage in FW, we don't need the …
Looks good! Left a minor suggestion
```diff
@@ -195,23 +186,44 @@ def bench_cutlass_fused_moe(
             output=flash_output,
         )
     )
     avg_ms = sum(ms_list) / len(ms_list)
     print("input\tweight1\tweight2\ttime(ms)")
```
These two lines are not aligned and are displayed as:

```
input weight1 weight2 time(ms)
(32, 3584) (32, 4096, 7168) (32, 7168, 2048) 0.19970719873905182
```

```diff
-    print("input\tweight1\tweight2\ttime(ms)")
+    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
+    print(
+        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {avg_ms:.3f}"
+    )
```
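For reference, the `:<15` and `:<20` format specs left-align each value in a fixed-width field, and `:.3f` rounds the time to three decimals, so the header and the data row stay column-aligned regardless of how long the shape strings are.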
Done. PTAL.
This is a typical workflow:

```bash
# update the config file
if [[ "$1" == "update" ]]; then
  FLASHINFER_AUTOTUNER_LOAD_FROM_FILE=0 python bench_cutlass_fused_moe.py --num-tokens 4096 --update-config
  exit 0
fi

# benchmark
export FLASHINFER_AUTOTUNER_LOAD_FROM_FILE=0
for i in 1 2 4 8 16 24 32 48 64 96 128 256 512 1024 1536 2048 3072 4096 8192 16384; do
  python bench_cutlass_fused_moe.py --num-tokens $i --skip-autotune
done

export FLASHINFER_AUTOTUNER_LOAD_FROM_FILE=1
for i in 1 2 4 8 16 24 32 48 64 96 128 256 512 1024 1536 2048 3072 4096 8192 16384; do
  python bench_cutlass_fused_moe.py --num-tokens $i --skip-autotune
done
```
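In this flow, the `update` invocation runs the autotuner and regenerates the JSON config file, while the two benchmark loops pass `--skip-autotune` and compare the default tactics (`FLASHINFER_AUTOTUNER_LOAD_FROM_FILE=0`) against the tactics loaded from the file (`FLASHINFER_AUTOTUNER_LOAD_FROM_FILE=1`).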
LGTM, thank you @kaixih !
This PR adds support for loading autotuned results from JSON files for the Cutlass FP4 MoE backends.
The script `benchmarks/bench_cutlass_fused_moe.py` generates a JSON file at `configs/<flashinfer_version>/trtllm_fused_moe_<device_name>.json`, mapping input shapes to the optimal config/tactic for the GEMMs used in `fused_moe.cutlass_fused_moe`.

At runtime, setting the `FLASHINFER_AUTOTUNER_LOAD_FROM_FILE` environment variable enables loading from this file. If the variable is unset or a matching entry is not found, the autotuner falls back to the default config/tactic. Configs are organized by flashinfer version and GPU device; a small inspection sketch follows.
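As a quick way to see what was tuned, the sketch below loads a generated cache file and prints its entries. The path uses the placeholder version/device values from the description above, and the assumption that each entry stores a config/tactic pair is taken from the `load_from_json` snippet in the review thread.

```python
import json
from pathlib import Path

# Placeholder path; substitute the real flashinfer version and device name.
cache_path = Path("configs/<flashinfer_version>/trtllm_fused_moe_NVIDIA_B200.json")

if cache_path.exists():
    configs = json.loads(cache_path.read_text())
    for shape_key, entry in configs.items():
        # Assumed entry layout: [config, tactic], per the load_from_json snippet.
        print(f"{shape_key}: config={entry[0]}, tactic={entry[1]}")
else:
    print(f"No autotune cache at {cache_path}; default config/tactic will be used.")
```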
cc. @yzh119 @wenscarl @kushanam