Add lightllm kernel #925
base: main
Conversation
# Operator optimizations for ViT FP8 w8a8 quantized inference

## New operators
1. rmsnorm_bf16: large performance improvement over the PyTorch implementation
2. pre_tp_norm: fuses the pre-communication part of tp_norm
3. post_tp_norm: fuses the post-communication part of tp_norm
4. pre_token_quant: per-token FP8 quantization; much faster than vLLM's quant kernel and faster than SGLang's (see the reference sketch below)
5. gelu_per_token_quant: fuses the GELU activation with per-token FP8 quantization
6. add_norm_quant: fuses the add/norm/quant sequence between the attention and MLP modules
7. cutlass_scaled_mm_bias_ls: fuses the quantized matmul, dequantization, and optional bias and ls weight application

## ViT integration of the optimized operators
1. rmsnorm: added single-GPU and multi-GPU optimized-operator code paths
2. per_token_quant: added the optimized activation-quantization code path
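For reviewers unfamiliar with per-token FP8 quantization, the following is a minimal PyTorch sketch of the semantics these kernels target. The function and its signature are illustrative, not the kernel API; the scale and epsilon convention mirrors the CUDA code quoted later in this thread.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_token_quant_ref(x: torch.Tensor):
    """Reference: quantize an (M, N) bf16 activation tensor row by row to FP8 E4M3."""
    x_f32 = x.float()
    # One scale per token (row): map the row's max |value| onto the FP8 range.
    scale = x_f32.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX
    inv_scale = 1.0 / (scale + 1e-7)  # epsilon guards against all-zero rows
    q = (x_f32 * inv_scale).to(torch.float8_e4m3fn)
    return q, scale.squeeze(-1)

x = torch.randn(4, 1024, dtype=torch.bfloat16)
q, scales = per_token_quant_ref(x)
dequant = q.float() * scales.unsqueeze(-1)  # approximate reconstruction of x
```

The fused variants (gelu_per_token_quant, add_norm_quant) apply the activation or add+RMSNorm step first and then perform this same per-row quantization within a single kernel.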
Hello @SangChengC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
gemini-code-assist here to provide a summary of this pull request. This PR introduces a new CUDA kernel library, lightllm-kernel, aimed at providing highly optimized GPU operations for the LightLLM inference system. The library includes custom kernels for various critical components of large language model inference, such as attention mechanisms (including INT8 KV cache and flash decoding variants), normalization (RMSNorm and Tensor Parallel RMSNorm), quantization (per-token FP8), fused operations (Add+Norm+Quant, GELU+Quant), matrix multiplication (leveraging CUTLASS 3.x for scaled GEMMs with bias and layer scale fusion), and Mixture-of-Experts (MoE) related operations (Grouped Top-K). The library is designed to be built as a Python extension module and includes build configurations using both CMake/scikit-build and setuptools. The PR also integrates these new kernels into the main LightLLM codebase, specifically updating parts of the Llama and ViT model implementations and the general quantization logic to utilize the new optimized operations where applicable. Comprehensive benchmark scripts are included to compare the performance and accuracy of the new kernels against existing implementations (PyTorch, Triton, vLLM, SGLang).
Highlights
- New CUDA Kernel Library: Introduces lightllm-kernel, a dedicated library for high-performance GPU operations in LightLLM.
- Optimized Operations: Adds custom CUDA kernels for key LLM inference tasks: Attention (INT8 KV, Flash Decoding), RMSNorm, Tensor Parallel Norm, Per-Token FP8 Quantization, Fused Add+Norm+Quant, Fused GELU+Quant, Scaled GEMM (via CUTLASS 3.x), and Grouped Top-K for MoE (a reference sketch of the tensor-parallel norm split follows this list).
- CUTLASS 3.x Integration: Leverages the latest CUTLASS 3.x API for advanced GEMM fusion capabilities, including scaling, bias, and layer scale (LS) application directly within the matrix multiplication kernel.
- LightLLM Core Integration: Updates existing LightLLM model implementations (Llama, ViT) and quantization logic to use the newly added kernels for improved performance.
- Build and Benchmarking: Provides flexible build options (CMake/scikit-build, setuptools) and includes benchmark scripts to validate kernel performance and accuracy.
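To make the tensor-parallel norm split concrete, here is a rough PyTorch reference of what the pre/post stages compute around the all-reduce. This is a sketch under the usual tensor-parallel assumption that each rank holds a slice of the hidden dimension; the function name and signature are illustrative, not the kernel API.

```python
import torch
import torch.distributed as dist

def tp_rmsnorm_ref(x_shard: torch.Tensor, w_shard: torch.Tensor, eps: float = 1e-6):
    """x_shard, w_shard: this rank's slice of the hidden dim (torch.distributed already initialized)."""
    # pre_tp_norm stage: per-token sum of squares over the local shard.
    sq_sum = x_shard.float().pow(2).sum(dim=-1)
    # Communication step the kernels are fused around: accumulate the statistic across TP ranks.
    dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM)
    # post_tp_norm stage: normalize the local shard with the global statistic and apply the weight.
    full_hidden = x_shard.shape[-1] * dist.get_world_size()
    inv_rms = torch.rsqrt(sq_sum / full_hidden + eps).unsqueeze(-1)
    return (x_shard.float() * inv_rms * w_shard.float()).to(x_shard.dtype)
```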
Changelog
- .gitmodules
- Added third-party/cutlass as a submodule.
- docs/CN/source/getting_started/multimodal_model_quickstart.rst
- Added a command example for deploying a multimodal API server with --enable_multimodal.
- lightllm-kernel/
- Added a new directory containing the source code, build files, and documentation for the lightllm-kernel library.
- lightllm-kernel/CMakeLists.txt
- Added CMake build configuration for the lightllm-kernel Python extension module, including finding dependencies (Torch, Python, CUDA) and defining build/install rules.
- lightllm-kernel/LICENSE
- Added the Apache License 2.0 for the lightllm-kernel library.
- lightllm-kernel/Makefile
- Added a Makefile with build and clean targets for the kernel library, including submodule update and pip installation with specific CUDA architectures.
- lightllm-kernel/README-CH.md
- Added a Chinese README file describing the lightllm-kernel library, its features, installation, and contribution guidelines.
- lightllm-kernel/README.md
- Added an English README file describing the lightllm-kernel library, its features, installation, and contribution guidelines.
- lightllm-kernel/benchmark/bench_quant_per_token_bf16_fp8.py
- Added a benchmark script for per-token quantization comparing lightllm_kernel, vllm, and sgl_kernel.
- lightllm-kernel/benchmark/bench_rms_norm.py
- Added a benchmark script for RMS normalization comparing lightllm_kernel, torch, triton, and vllm.
- lightllm-kernel/benchmark/bench_tp_norm.py
- Added a benchmark script for tensor-parallel RMS normalization comparing lightllm_kernel and a PyTorch reference.
- lightllm-kernel/csrc/allgather/all_gather.cu
- Added CUDA implementation for custom Allgather operation using IPC and named barriers.
- lightllm-kernel/csrc/allgather/all_gather.cuh
- Added CUDA header defining structures and kernels for custom Allgather and Allreduce operations.
- lightllm-kernel/csrc/attention/decode_attention_kernel.cu
- Added CUDA kernel for decoding attention with INT8 KV cache.
- lightllm-kernel/csrc/attention/decode_attention_kernel_in8kv_flashdecoding.cu
- Added CUDA kernel for flash decoding attention with INT8 KV cache.
- lightllm-kernel/csrc/cuda_compat.h
- Added CUDA/ROCm compatibility macros for warp shuffle and device attributes.
- lightllm-kernel/csrc/fusion/add_norm_quant.cu
- Added CUDA kernel for fused Add, RMSNorm, and per-token FP8 quantization.
- lightllm-kernel/csrc/fusion/gelu_per_token_quant.cu
- Added CUDA kernel for fused GELU activation and per-token FP8 quantization.
- lightllm-kernel/csrc/fusion/post_tp_norm.cu
- Added CUDA kernel for the second stage of tensor-parallel RMS normalization.
- lightllm-kernel/csrc/fusion/pre_tp_norm.cu
- Added CUDA kernel for the first stage of tensor-parallel RMS normalization.
- lightllm-kernel/csrc/gemm/Epilogues.md
- Added documentation for CUTLASS epilogues used for fused de-quantization and bias/LS.
- lightllm-kernel/csrc/gemm/scaled_mm_c3x.cu
- Added CUDA implementation for scaled matrix multiplication using CUTLASS 3.x (SM90+).
- lightllm-kernel/csrc/gemm/scaled_mm_c3x.cuh
- Added CUDA header defining CUTLASS 3.x GEMM structures and caller function.
- lightllm-kernel/csrc/gemm/scaled_mm_c3x_sm90_fp8_dispatch.cuh
- Added CUDA header for dispatching SM90 FP8 GEMM configurations based on problem shape.
- lightllm-kernel/csrc/gemm/scaled_mm_entry.cu
- Added CUDA entry point for scaled matrix multiplication, including device capability checks.
- lightllm-kernel/csrc/moe/grouped_topk.cu
- Added CUDA kernel for Grouped Top-K selection (for MoE).
- lightllm-kernel/csrc/norm/rmsnorm_bf16.cu
- Added CUDA kernel for RMS normalization (BF16).
- lightllm-kernel/csrc/ops_bindings.cpp
- Added PyBind11 bindings to expose C++/CUDA kernels to Python.
- lightllm-kernel/include/cutlass_extensions/common.hpp
- Added utility functions and macros for CUTLASS/CUDA error checking and device info.
- lightllm-kernel/include/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
- Added custom CUTLASS 3.x epilogue components for broadcasting data.
- lightllm-kernel/include/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
- Added custom CUTLASS 3.x epilogues for fusing scaling, bias, and LS.
- lightllm-kernel/include/ops_common.h
- Added header declaring C++ functions exposed via PyBind11.
- lightllm-kernel/include/reduce/sm70.cuh
- Added CUDA kernels for block-wide sum and max reduction.
- lightllm-kernel/include/utils.h
- Added CUDA utility functions and type definitions.
- lightllm-kernel/lightllm_kernel/__init__.py
- Added Python package initialization.
- lightllm-kernel/lightllm_kernel/ops/__init__.py
- Added Python module initialization for kernel operations, including dynamic compilation logic and Python wrappers.
- lightllm-kernel/lightllm_kernel/ops/allgather.py
- Added Python wrappers for custom Allgather functions.
- lightllm-kernel/lightllm_kernel/ops/attention.py
- Added Python wrappers for INT8KV decode and flash decoding attention kernels.
- lightllm-kernel/lightllm_kernel/ops/fusion.py
- Added Python wrappers for fused operations (TP norm, add_norm_quant, gelu_quant).
- lightllm-kernel/lightllm_kernel/ops/gemm.py
- Added Python wrapper for CUTLASS scaled matrix multiplication.
- lightllm-kernel/lightllm_kernel/ops/moe.py
- Added Python wrapper for grouped top-k kernel.
- lightllm-kernel/lightllm_kernel/ops/norm.py
- Added Python wrapper for RMS normalization kernel.
- lightllm-kernel/lightllm_kernel/ops/quant.py
- Added Python wrapper for per-token quantization kernel.
- lightllm-kernel/pyproject.toml
- Added pyproject.toml for scikit-build-core configuration.
- lightllm-kernel/setup.py
- Added setup.py for building the kernel as a Python extension using torch.utils.cpp_extension.
- lightllm-kernel/test/fusion/add_norm_quant_test.py
- Added unit tests and benchmarks for fused Add+Norm+Quant.
- lightllm-kernel/test/fusion/gelu_per_token_quant_test.py
- Added unit tests and benchmarks for fused GELU+Quant.
- lightllm-kernel/test/fusion/post_tp_norm_test.py
- Added unit tests and benchmarks for post-TP Norm.
- lightllm-kernel/test/fusion/pre_tp_norm_test.py
- Added unit tests and benchmarks for pre-TP Norm.
- lightllm-kernel/test/gemm/cutlass_scaled_mm_test.py
- Added unit tests and benchmarks for CUTLASS scaled GEMM with bias and LS.
- lightllm-kernel/test/norm/rmsnorm_test.py
- Added unit tests and benchmarks for RMS Norm.
- lightllm-kernel/test/quant/quant_test.py
- Added unit tests and benchmarks for per-token quantization.
- lightllm-kernel/test/utils.py
- Added utility functions for benchmarking and error calculation.
- lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
- Modified the mm method to accept an optional ls_weight parameter.
- lightllm/common/quantization/w8a8_quant.py
- Added conditional logic to use lightllm_kernel for scaled FP8 quantization if available.
- Updated the apply method to pass ls_weight to the matrix multiplication call.
- lightllm/models/llama/layer_infer/transformer_layer_infer.py
- Updated decode attention calls to use the new lightllm_kernel.ops.group_int8kv_decode_attention.
- lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
- Updated flash decoding attention calls to use the new lightllm_kernel.ops.group8_int8kv_flashdecoding_stage1.
- lightllm/models/vit/layer_infer/transformer_layer_infer.py
- Added a tp_norm_cuda method utilizing lightllm_kernel TP norm ops.
- Updated _qk_norm to conditionally use tp_norm_cuda.
- Updated the _get_o and _ffn methods to pass ls_weight to the matrix multiplication calls.
- lightllm/models/vit/triton_kernel/flashattention_nopad.py
- Adjusted the argument list in the flash_attention_v3_fwd call by adding a None argument.
- lightllm/utils/light_utils.py
- Modified the import logic for lightllm_kernel to prefer the ops submodule if it exists.
- third-party/cutlass
- Updated the pinned commit for the cutlass submodule.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a new, self-contained lightllm-kernel package, which is a significant and positive architectural change. It brings optimized CUDA kernels for various operations like normalization, quantization, GEMM (leveraging CUTLASS 3.x for Hopper), attention, and custom all-gather/all-reduce. The inclusion of comprehensive build systems (CMake, pyproject.toml, setup.py, Makefile), Python bindings, benchmark scripts, and unit tests is commendable and demonstrates a thorough approach.
The integration into the main lightllm library appears to be well handled, with conditional logic to use these new kernels when available. The overall structure is modular and aims for performance improvements, especially with FP8 support and fused operations.
While the PR is generally in good shape, there are a few areas for improvement, primarily concerning build robustness, documentation placeholders, and error handling in kernels, which are detailed in the specific comments.
Summary of Findings
- Build System: The CMakeLists.txt uses file(GLOB_RECURSE), which can be less robust for detecting new source files compared to explicit listing. This is a medium maintainability concern.
- Documentation: README files (both English and Chinese versions) contain placeholder URLs for the repository. The Chinese README also has an unfinished list item. These are medium severity issues affecting documentation accuracy.
- Error Handling in Kernels: Several CUDA kernels use assert(false) for unhandled cases (e.g., unsupported head dimensions, shared memory limits). This should be replaced with TORCH_CHECK or std::runtime_error for better error reporting. This is a high severity issue.
- Numerical Consistency: The gelu_per_token_quant.cu kernel uses different GeLU approximations (erf-based vs. tanhf-based) depending on the input dimension N. This could lead to minor numerical inconsistencies and should be clarified or made consistent. This is a medium severity issue.
- Path Fragility in Python Bindings: The JIT compilation fallback in lightllm_kernel/ops/__init__.py uses a hardcoded relative path to find the repository root, which could be fragile. This is a medium maintainability concern.
- File Endings (Low Severity - Not Commented): Several new files (LICENSE, Makefile, README.md, and some test/kernel files) are missing a final newline character. This is a minor stylistic issue.
- Chinese Comments in CMake (Low Severity - Not Commented): A comment in lightllm-kernel/CMakeLists.txt (line 4) is in Chinese. While not critical, using English for comments generally improves broader accessibility and maintainability in open-source projects.
Merge Readiness
This pull request introduces a substantial and valuable new kernel library. The code is generally well-structured and includes important components like tests and benchmarks. However, there are several high and medium severity issues identified, particularly concerning error handling in CUDA kernels (assert(false) should be replaced) and placeholder/incomplete documentation. Addressing these issues, especially the error handling, is recommended before merging to ensure robustness and maintainability. I am not authorized to approve pull requests, so further review and approval by maintainers will be necessary after these points are considered.
    );
    break;
default:
    assert(false);
Using assert(false) for unhandled head_dim cases (and similarly on line 501 for shared memory limits) can lead to uninformative crashes in release builds where asserts might be disabled, or cryptic messages otherwise. Would it be better to use TORCH_CHECK(false, "Unsupported head_dim: ", head_dim); or to throw a std::runtime_error to provide clearer error messages that can be propagated to Python?
file(GLOB_RECURSE SRC_CPP CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/csrc/*.cpp")
file(GLOB_RECURSE SRC_CUDA CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/csrc/*.cu")
Using file(GLOB_RECURSE) can be convenient during development, but it's generally discouraged for production CMake builds. If new source files are added, CMake won't automatically detect them unless it's re-run, which can lead to build issues. Explicitly listing source files is more robust. Could you consider if explicitly listing sources would be more appropriate here, or if the current approach is preferred for ease of development within this kernel library?
#### Static Compilation (Recommended)

```bash
git clone https://github.com/YourUsername/lightllm_kernel.git
```
- Register the Python interface: in csrc/ops_bindings.cpp, register the newly added operator to the Python interface via PyBind11, TORCH_LIBRARY, or a similar mechanism.
- Export the operator to the Python module: add the corresponding export code in lightllm_kernel/ops/__init__.py so that the new operator is included in the lightllm_kernel.ops module.
- Test locally: after development, please test your changes locally. You can build and install the new version and write a simple script that calls the new operator to check whether its functionality and performance meet expectations. If the project ships test cases, please also run all tests to make sure no regressions are introduced.
-
            local_bf16[j] = _float22bf162_rn(tmp);
        }

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the max for the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT/2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the thread group
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT/4; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
            fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x * inv_scale,
                x.y * inv_scale,
                y.x * inv_scale,
                y.y * inv_scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


template<int32_t TPB>
__global__ void gelu_per_token_quant_bf16_to_fp8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales for each group
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f;
    constexpr fp32_t coeff = 0.044715f;

    const bf16_t* _input = input + bid * N;  // Input pointer for the group
    fp8_e4m3_t* _output = output + bid * N;  // Output pointer for the group

    fp32_t* _scales;
    _scales = scales + bid;

    extern __shared__ bf16_t workspace_[];

    fp32_t local_max = -FLT_MAX;

    for (int32_t i = tid; i < N; i += TPB) {
        fp32_t tmp = cvt_bf16_f32(_input[i]);
        fp32_t tanh_arg = sqrt_2_over_pi * (tmp + coeff * tmp * tmp * tmp);
        tmp = 0.5f * tmp * (1.0f + tanhf(tanh_arg));
        local_max = fmaxf(local_max, fabsf(tmp));
        workspace_[i] = cvt_f32_bf16(tmp);
    }

    // Reduce the maximum value across the thread group
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid; i < N; i += TPB) {
        // Load the GELU output previously stored in shared memory.
        fp32_t x = cvt_bf16_f32(workspace_[i]);
        // Quantize: scale by inv_scale and convert to FP8.
        fp32_t ret = x * inv_scale;
        _output[i] = fp8_e4m3_t(ret);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}
It appears that different GeLU approximations are used within the dispatch logic of the gelu_per_token_quant_bf16_fp8 C++ binding function:
- The fixed-N templated kernels (e.g., device_gelu_per_token_quant_bf16_to_fp8<TPB, N>) use an erf-based GeLU approximation (lines 45-53).
- The general (_general) and vectorized (_vpt) fallback kernels use a tanhf-based approximation (e.g., lines 131-140 for _vpt, lines 211-216 for _general).
While both are valid approximations, using different ones based on N might lead to subtle numerical inconsistencies. Was this intentional for performance reasons, or would it be preferable to use a consistent GeLU approximation across all paths? If intentional, a comment explaining this choice could be beneficial.
    _C = importlib.import_module(f"{PKG}._C")
except ImportError:
    # raise ImportError("Cannot import compiled extension 'lightllm_kernel.ops'")
    repo_root = Path(__file__).resolve().parents[3]
The line repo_root = Path(__file__).resolve().parents[3] hardcodes the relative path depth used to find the repository root. This can be fragile if the directory structure changes. Could a more robust way to locate the repository root or the necessary include paths be considered for the JIT compilation fallback, perhaps using environment variables or a configuration mechanism, if this JIT path is critical for some workflows?
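One possible hardening, sketched below for illustration only (LIGHTLLM_KERNEL_ROOT is a hypothetical variable name, not something the PR defines): allow an environment-variable override and otherwise walk upward until a file that exists only at the package root, such as the pyproject.toml or CMakeLists.txt added by this PR, is found.

```python
import os
from pathlib import Path

def find_kernel_root() -> Path:
    """Locate the lightllm-kernel source tree for the JIT compilation fallback."""
    env_root = os.environ.get("LIGHTLLM_KERNEL_ROOT")
    if env_root:
        return Path(env_root).resolve()
    # Walk upward from this file until a root marker is found.
    for parent in Path(__file__).resolve().parents:
        if (parent / "pyproject.toml").exists() or (parent / "CMakeLists.txt").exists():
            return parent
    raise FileNotFoundError(
        "Could not locate the lightllm-kernel sources; set LIGHTLLM_KERNEL_ROOT "
        "or install the prebuilt extension."
    )
```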
38d3b47 to e540520