Add fp8-fused gemm kernel #5764
Conversation
One thing that needs to be resolved before merging this: this kernel requires triton 2.3.0.
Hi @jeffra. To clarify, does this kernel require exactly triton 2.3.0?
@sfc-gh-reyazda would know better; I am not sure we've tested with a triton newer than 2.3.0. I have not personally tested that, at least.
It needs that specific version. Unfortunately, triton keeps changing/improving and its APIs change too, so it is hard to track properly. That's another motivation to move to cutlass soon and have a more solid implementation that works independently of other libraries. On the other hand, Triton gives us the flexibility to run on various hardware, so it is always a tradeoff. I think we need to have some more discussion on such dependencies later, in a separate discussion.
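One lightweight way to make such a pin explicit is an import-time version guard before registering the kernel. This is a minimal illustrative sketch, not part of the PR; only the 2.3.0 pin comes from the thread above:

```python
# Hypothetical import-time guard for a triton-version-pinned kernel.
# The 2.3.0 pin comes from the discussion above; the guard itself is
# illustrative, not part of the PR.
from packaging import version

import triton

REQUIRED_TRITON = "2.3.0"


def check_triton_version():
    # Triton's kernel-facing APIs change between releases, so refuse to
    # run against anything other than the validated version.
    if version.parse(triton.__version__) != version.parse(REQUIRED_TRITON):
        raise RuntimeError(
            f"The fp8-fused GEMM kernel is only validated against "
            f"triton=={REQUIRED_TRITON}; found {triton.__version__}."
        )
```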
This is a refresh of `OptimizedLinear` with the following features to improve performance and usability:

* More efficient sharing of base weights using `all_gather_into_tensor`
* Flattened sharded weights
* Selective offloading of frozen weights to CPU
* `deepspeed.linear.Init`, which allows injecting OptimizedLinear during model construction (similar to zero.Init); see the usage sketch after this message
* Support for loading a state dict directly into OptimizedLinear, which allows loading HF model weights correctly into sharded params
* Various bug fixes for the LoRA implementation introduced previously
* Several new unit tests

Builds on top of @RezaYazdaniAminabadi's previous FP8 updates (#5764) to support dense-model fp8 quantization.

Example usage of this to fine-tune llama-3.1-405B on a single node: https://github.com/Snowflake-Labs/snowflake-arctic/tree/main/training/llama3.1

---------

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
Co-authored-by: Reza Yazdani <152926435+sfc-gh-reyazda@users.noreply.github.com>
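A rough usage sketch of the `deepspeed.linear.Init` flow described above. `Init`, `LoRAConfig`, and `QuantizationConfig` are the names used by `deepspeed.linear`, but the specific fields and keyword arguments shown here are assumptions for illustration, not the verified API:

```python
# Hedged sketch of construction-time OptimizedLinear injection, analogous
# to zero.Init. Config field names and Init's keyword arguments are
# assumptions based on the PR description, not the verified API.
import deepspeed.linear
from transformers import AutoModelForCausalLM

lora_cfg = deepspeed.linear.LoRAConfig(lora_r=64, lora_alpha=16)
quant_cfg = deepspeed.linear.QuantizationConfig(q_bits=8)

# Linear layers created inside this context are replaced with
# OptimizedLinear at construction time, so weights are sharded (and
# optionally quantized) as the model is built rather than patched later.
with deepspeed.linear.Init(lora_config=lora_cfg, quant_config=quant_cfg):
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-405B")

# Per the description above, OptimizedLinear's load-state-dict support
# lets HF checkpoint weights load correctly into the sharded params.
```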
This PR adds a new fused kernel for dense GEMM using fp8-quantized weights.
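For intuition, here is a minimal PyTorch emulation of a GEMM with an fp8-quantized weight, not the PR's triton kernel: the weight is stored in fp8 (e4m3) with a per-tensor scale and dequantized to the activation dtype before the matmul. The real kernel fuses the dequantization into the GEMM itself. This assumes PyTorch >= 2.1 for the `torch.float8_e4m3fn` dtype:

```python
import torch


def quantize_fp8(w: torch.Tensor):
    """Per-tensor fp8 (e4m3) quantization: scale so |w|max maps to fp8 max."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w.abs().max().clamp(min=1e-12) / fp8_max
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)
    return w_fp8, scale


def fp8_weight_gemm(x: torch.Tensor, w_fp8: torch.Tensor, scale: torch.Tensor):
    # The fused kernel would perform this dequantize-and-multiply inside
    # a single triton GEMM; here it is emulated in two steps for clarity.
    w = w_fp8.to(x.dtype) * scale
    return x @ w.t()


x = torch.randn(4, 512, dtype=torch.bfloat16)
w = torch.randn(1024, 512, dtype=torch.bfloat16)
w_fp8, scale = quantize_fp8(w.float())
y = fp8_weight_gemm(x, w_fp8, scale.to(torch.bfloat16))  # shape (4, 1024)
```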